summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile13
-rw-r--r--README.md16
-rw-r--r--docs/source/readme.rst41
-rw-r--r--examples/plot_otda_linear_mapping.py138
-rw-r--r--ot/bregman.py30
-rw-r--r--ot/da.py277
-rw-r--r--ot/externals/__init__.py0
-rw-r--r--ot/externals/funcsigs.py817
-rw-r--r--ot/gromov.py4
-rw-r--r--ot/lp/__init__.py3
-rw-r--r--ot/optim.py9
-rw-r--r--ot/utils.py12
-rw-r--r--test/test_da.py46
-rw-r--r--test/test_gromov.py79
-rw-r--r--test/test_utils.py77
15 files changed, 1510 insertions, 52 deletions
diff --git a/Makefile b/Makefile
index 3f19e8a..95714b8 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
-PYTHON=python
+PYTHON=python3
help :
@echo "The following make targets are available:"
@@ -41,14 +41,14 @@ pep8 :
flake8 examples/ ot/ test/
test : FORCE pep8
- python -m py.test -v test/ --cov=ot --cov-report html:cov_html
+ $(PYTHON) -m pytest -v test/ --cov=ot --cov-report html:cov_html
pytest : FORCE
- python -m py.test -v test/ --cov=ot
+ $(PYTHON) -m py.test -v test/ --cov=ot
uploadpypi :
#python setup.py register
- python setup.py sdist upload -r pypi
+ $(PYTHON) setup.py sdist upload -r pypi
rdoc :
pandoc --from=markdown --to=rst --output=docs/source/readme.rst README.md
@@ -56,6 +56,11 @@ rdoc :
notebook :
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
+
+autopep8 :
+ autopep8 -ir test ot examples
+aautopep8 :
+ autopep8 -air test ot examples
FORCE :
diff --git a/README.md b/README.md
index 8a9d7fa..6b7cff0 100644
--- a/README.md
+++ b/README.md
@@ -13,14 +13,14 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:
-* OT solver for the linear program/ Earth Movers Distance [1].
+* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat).
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
-* Joint OT matrix and mapping estimation [8].
+* Linear OT [14] and Joint OT matrix and mapping estimation [8].
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
-* Gromov-Wasserstein distances and barycenters [12]
+* Gromov-Wasserstein distances and barycenters ([13] and regularized [12])
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
@@ -195,7 +195,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). [Generalized conditional gradient: analysis of convergence and applications](https://arxiv.org/pdf/1510.06567.pdf). arXiv preprint arXiv:1510.06567.
-[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS), 2016.
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS).
[9] Schmitzer, B. (2016). [Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems](https://arxiv.org/pdf/1610.06519.pdf). arXiv preprint arXiv:1610.06519.
@@ -203,6 +203,10 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.
-[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML).
-[13] Mémoli, Facundo. [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 (2011): 417-487.
+[13] Mémoli, Facundo (2011). [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 : 417-487.
+
+[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43.
+
+[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) .
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 347bde2..725c207 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -1,8 +1,8 @@
POT: Python Optimal Transport
=============================
-|PyPI version| |Build Status| |Documentation Status| |Anaconda Cloud|
-|License| |Anaconda downloads|
+|PyPI version| |Anaconda Cloud| |Build Status| |Documentation Status|
+|Anaconda downloads| |License|
This open source Python library provide several solvers for optimization
problems related to Optimal Transport for signal, image processing and
@@ -10,7 +10,8 @@ machine learning.
It provides the following solvers:
-- OT solver for the linear program/ Earth Movers Distance [1].
+- OT Network Flow solver for the linear program/ Earth Movers Distance
+ [1].
- Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2]
and stabilized version [9][10] with optional GPU implementation
(required cudamat).
@@ -19,10 +20,11 @@ It provides the following solvers:
regularization [5]
- Conditional gradient [6] and Generalized conditional gradient for
regularized OT [7].
-- Joint OT matrix and mapping estimation [8].
+- Linear OT [14] and Joint OT matrix and mapping estimation [8].
- Wasserstein Discriminant Analysis [11] (requires autograd +
pymanopt).
-- Gromov-Wasserstein distances and barycenters [12]
+- Gromov-Wasserstein distances and barycenters ([13] and regularized
+ [12])
Some demonstrations (both in Python and Jupyter Notebook format) are
available in the examples folder.
@@ -281,10 +283,10 @@ conditional gradient: analysis of convergence and
applications <https://arxiv.org/pdf/1510.06567.pdf>`__. arXiv preprint
arXiv:1510.06567.
-[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, `Mapping estimation
-for discrete optimal
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), `Mapping
+estimation for discrete optimal
transport <http://remi.flamary.com/biblio/perrot2016mapping.pdf>`__,
-Neural Information Processing Systems (NIPS), 2016.
+Neural Information Processing Systems (NIPS).
[9] Schmitzer, B. (2016). `Stabilized Sparse Scaling Algorithms for
Entropy Regularized Transport
@@ -301,25 +303,32 @@ arXiv:1607.05816.
Analysis <https://arxiv.org/pdf/1608.08063.pdf>`__. arXiv preprint
arXiv:1608.08063.
-[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
`Gromov-Wasserstein averaging of kernel and distance
matrices <http://proceedings.mlr.press/v48/peyre16.html>`__
-International Conference on Machine Learning (ICML). 2016.
+International Conference on Machine Learning (ICML).
-[13] Mémoli, Facundo. `Gromov–Wasserstein distances and the metric
-approach to object
+[13] Mémoli, Facundo (2011). `Gromov–Wasserstein distances and the
+metric approach to object
matching <https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf>`__.
-Foundations of computational mathematics 11.4 (2011): 417-487.
+Foundations of computational mathematics 11.4 : 417-487.
+
+[14] Knott, M. and Smith, C. S. (1984).`On the optimal mapping of
+distributions <https://link.springer.com/article/10.1007/BF00934745>`__,
+Journal of Optimization Theory and Applications Vol 43.
+
+[15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal
+Transport <https://arxiv.org/pdf/1803.00567.pdf>`__ .
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
+.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
+ :target: https://anaconda.org/conda-forge/pot
.. |Build Status| image:: https://travis-ci.org/rflamary/POT.svg?branch=master
:target: https://travis-ci.org/rflamary/POT
.. |Documentation Status| image:: https://readthedocs.org/projects/pot/badge/?version=latest
:target: http://pot.readthedocs.io/en/latest/?badge=latest
-.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
+.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg
:target: https://anaconda.org/conda-forge/pot
.. |License| image:: https://anaconda.org/conda-forge/pot/badges/license.svg
:target: https://github.com/rflamary/POT/blob/master/LICENSE
-.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg
- :target: https://anaconda.org/conda-forge/pot
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py
new file mode 100644
index 0000000..7a3b761
--- /dev/null
+++ b/examples/plot_otda_linear_mapping.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Mar 20 14:31:15 2018
+
+@author: rflamary
+"""
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+n = 1000
+d = 2
+sigma = .1
+
+# source samples
+angles = np.random.rand(n, 1) * 2 * np.pi
+xs = np.concatenate((np.sin(angles), np.cos(angles)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xs[:n // 2, 1] += 2
+
+
+# target samples
+anglet = np.random.rand(n, 1) * 2 * np.pi
+xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xt[:n // 2, 1] += 2
+
+
+A = np.array([[1.5, .7], [.7, 1.5]])
+b = np.array([[4, 2]])
+xt = xt.dot(A) + b
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1, (5, 5))
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+
+
+##############################################################################
+# Estimate linear mapping and transport
+# -------------------------------------
+
+Ae, be = ot.da.OT_mapping_linear(xs, xt)
+
+xst = xs.dot(Ae) + be
+
+
+##############################################################################
+# Plot transported samples
+# ------------------------
+
+pl.figure(1, (5, 5))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+pl.plot(xst[:, 0], xst[:, 1], '+')
+
+pl.show()
+
+##############################################################################
+# Load image data
+# ---------------
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+# Loading images
+I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+##############################################################################
+# Estimate mapping and adapt
+# ----------------------------
+
+mapping = ot.da.LinearTransport()
+
+mapping.fit(Xs=X1, Xt=X2)
+
+
+xst = mapping.transform(Xs=X1)
+xts = mapping.inverse_transform(Xt=X2)
+
+I1t = minmax(mat2im(xst, I1.shape))
+I2t = minmax(mat2im(xts, I2.shape))
+
+# %%
+
+
+##############################################################################
+# Plot transformed images
+# -----------------------
+
+pl.figure(2, figsize=(10, 7))
+
+pl.subplot(2, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Im. 1')
+
+pl.subplot(2, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Im. 2')
+
+pl.subplot(2, 2, 3)
+pl.imshow(I1t)
+pl.axis('off')
+pl.title('Mapping Im. 1')
+
+pl.subplot(2, 2, 4)
+pl.imshow(I2t)
+pl.axis('off')
+pl.title('Inverse mapping Im. 2')
diff --git a/ot/bregman.py b/ot/bregman.py
index d63c51d..07b8660 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,7 +11,8 @@ Bregman projections for regularized OT
import numpy as np
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -120,7 +121,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ver
return sink()
-def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
u"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -233,7 +235,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, ve
return sink()
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
+def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
Solve the entropic regularization OT problem with log stabilization
@@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) -
+ beta.reshape((1, nb))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) /
+ reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
# print(np.min(K))
@@ -620,7 +626,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
return get_Gamma(alpha, beta, u, v)
-def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
+ tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((na, 1)) -
+ beta.reshape((1, nb))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -811,7 +819,8 @@ def projC(gamma, q):
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F
return geometricBar(weights, UKv)
-def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False):
+def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
+ stopThr=1e-3, verbose=False, log=False):
"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
diff --git a/ot/da.py b/ot/da.py
index 532dcd2..48b418f 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -10,6 +10,7 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
+import scipy.linalg as linalg
from .bregman import sinkhorn
from .lp import emd
@@ -356,7 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
+ return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
def solve_L(G):
""" solve L problem with fixed G (least square)"""
@@ -556,7 +558,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
@@ -633,6 +636,110 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
+def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ """ return OT linear operator between samples
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark 2.29 in [15].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ reg : float,optional
+ regularization added to the diagonals of convariances (>0)
+ ws : np.ndarray (ns,1), optional
+ weights for the source samples
+ wt : np.ndarray (ns,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d x d) ndarray
+ Linear operator
+ b : (1 x d) ndarray
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+
+ """
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = xs.mean(0, keepdims=True)
+ mxt = xt.mean(0, keepdims=True)
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = np.zeros((1, d))
+ mxt = np.zeros((1, d))
+
+ if ws is None:
+ ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+
+ if wt is None:
+ wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+
+ Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
+ Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+
+ Cs12 = linalg.sqrtm(Cs)
+ Cs_12 = linalg.inv(Cs12)
+
+ M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+
+ A = Cs_12.dot(M0.dot(Cs_12))
+
+ b = mxt - mxs.dot(A)
+
+ if log:
+ log = {}
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ log['Cs12'] = Cs12
+ log['Cs_12'] = Cs_12
+ return A, b, log
+ else:
+ return A, b
+
+
@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
"removed in 0.5"
"\n\tfor standard transport use class EMDTransport instead.")
@@ -1180,6 +1287,172 @@ class BaseTransport(BaseEstimator):
return transp_Xt
+class LinearTransport(BaseTransport):
+ """ OT linear operator between empirical distributions
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in
+ remark 2.29 in [15].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ reg : float,optional
+ regularization added to the daigonals of convariances (>0)
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ """
+
+ def __init__(self, reg=1e-8, bias=True, log=False,
+ distribution_estimation=distribution_estimation_uniform):
+
+ self.bias = bias
+ self.log = log
+ self.reg = reg
+ self.distribution_estimation = distribution_estimation
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
+
+ # coupling estimation
+ returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
+ ws=self.mu_s.reshape((-1, 1)),
+ wt=self.mu_t.reshape((-1, 1)),
+ bias=self.bias, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.A_, self.B_, self.log_ = returned_
+ else:
+ self.A_, self.B_, = returned_
+ self.log_ = dict()
+
+ # re compute inverse mapping
+ self.A1_ = linalg.inv(self.A_)
+ self.B1_ = -self.B_.dot(self.A1_)
+
+ return self
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transport source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ transp_Xs = Xs.dot(self.A_) + self.B_
+
+ return transp_Xs
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+ batch_size=128):
+ """Transports target samples Xt onto target samples Xs
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xt : array-like, shape (n_source_samples, n_features)
+ The transported target samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ transp_Xt = Xt.dot(self.A1_) + self.B1_
+
+ return transp_Xt
+
+
class SinkhornTransport(BaseTransport):
"""Domain Adapatation OT method based on Sinkhorn Algorithm
diff --git a/ot/externals/__init__.py b/ot/externals/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/ot/externals/__init__.py
diff --git a/ot/externals/funcsigs.py b/ot/externals/funcsigs.py
new file mode 100644
index 0000000..c73fdc9
--- /dev/null
+++ b/ot/externals/funcsigs.py
@@ -0,0 +1,817 @@
+# Copyright 2001-2013 Python Software Foundation; All Rights Reserved
+"""Function signature objects for callables
+
+Back port of Python 3.3's function signature tools from the inspect module,
+modified to be compatible with Python 2.7 and 3.2+.
+"""
+from __future__ import absolute_import, division, print_function
+import itertools
+import functools
+import re
+import types
+
+from collections import OrderedDict
+
+__version__ = "0.4"
+
+__all__ = ['BoundArguments', 'Parameter', 'Signature', 'signature']
+
+
+_WrapperDescriptor = type(type.__call__)
+_MethodWrapper = type(all.__call__)
+
+_NonUserDefinedCallables = (_WrapperDescriptor,
+ _MethodWrapper,
+ types.BuiltinFunctionType)
+
+
+def formatannotation(annotation, base_module=None):
+ if isinstance(annotation, type):
+ if annotation.__module__ in ('builtins', '__builtin__', base_module):
+ return annotation.__name__
+ return annotation.__module__ + '.' + annotation.__name__
+ return repr(annotation)
+
+
+def _get_user_defined_method(cls, method_name, *nested):
+ try:
+ if cls is type:
+ return
+ meth = getattr(cls, method_name)
+ for name in nested:
+ meth = getattr(meth, name, meth)
+ except AttributeError:
+ return
+ else:
+ if not isinstance(meth, _NonUserDefinedCallables):
+ # Once '__signature__' will be added to 'C'-level
+ # callables, this check won't be necessary
+ return meth
+
+
+def signature(obj):
+ '''Get a signature object for the passed callable.'''
+
+ if not callable(obj):
+ raise TypeError('{0!r} is not a callable object'.format(obj))
+
+ if isinstance(obj, types.MethodType):
+ sig = signature(obj.__func__)
+ if obj.__self__ is None:
+ # Unbound method: the first parameter becomes positional-only
+ if sig.parameters:
+ first = sig.parameters.values()[0].replace(
+ kind=_POSITIONAL_ONLY)
+ return sig.replace(
+ parameters=(first,) + tuple(sig.parameters.values())[1:])
+ else:
+ return sig
+ else:
+ # In this case we skip the first parameter of the underlying
+ # function (usually `self` or `cls`).
+ return sig.replace(parameters=tuple(sig.parameters.values())[1:])
+
+ try:
+ sig = obj.__signature__
+ except AttributeError:
+ pass
+ else:
+ if sig is not None:
+ return sig
+
+ try:
+ # Was this function wrapped by a decorator?
+ wrapped = obj.__wrapped__
+ except AttributeError:
+ pass
+ else:
+ return signature(wrapped)
+
+ if isinstance(obj, types.FunctionType):
+ return Signature.from_function(obj)
+
+ if isinstance(obj, functools.partial):
+ sig = signature(obj.func)
+
+ new_params = OrderedDict(sig.parameters.items())
+
+ partial_args = obj.args or ()
+ partial_keywords = obj.keywords or {}
+ try:
+ ba = sig.bind_partial(*partial_args, **partial_keywords)
+ except TypeError:
+ msg = 'partial object {0!r} has incorrect arguments'.format(obj)
+ raise ValueError(msg)
+
+ for arg_name, arg_value in ba.arguments.items():
+ param = new_params[arg_name]
+ if arg_name in partial_keywords:
+ # We set a new default value, because the following code
+ # is correct:
+ #
+ # >>> def foo(a): print(a)
+ # >>> print(partial(partial(foo, a=10), a=20)())
+ # 20
+ # >>> print(partial(partial(foo, a=10), a=20)(a=30))
+ # 30
+ #
+ # So, with 'partial' objects, passing a keyword argument is
+ # like setting a new default value for the corresponding
+ # parameter
+ #
+ # We also mark this parameter with '_partial_kwarg'
+ # flag. Later, in '_bind', the 'default' value of this
+ # parameter will be added to 'kwargs', to simulate
+ # the 'functools.partial' real call.
+ new_params[arg_name] = param.replace(default=arg_value,
+ _partial_kwarg=True)
+
+ elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL) and
+ not param._partial_kwarg):
+ new_params.pop(arg_name)
+
+ return sig.replace(parameters=new_params.values())
+
+ sig = None
+ if isinstance(obj, type):
+ # obj is a class or a metaclass
+
+ # First, let's see if it has an overloaded __call__ defined
+ # in its metaclass
+ call = _get_user_defined_method(type(obj), '__call__')
+ if call is not None:
+ sig = signature(call)
+ else:
+ # Now we check if the 'obj' class has a '__new__' method
+ new = _get_user_defined_method(obj, '__new__')
+ if new is not None:
+ sig = signature(new)
+ else:
+ # Finally, we should have at least __init__ implemented
+ init = _get_user_defined_method(obj, '__init__')
+ if init is not None:
+ sig = signature(init)
+ elif not isinstance(obj, _NonUserDefinedCallables):
+ # An object with __call__
+ # We also check that the 'obj' is not an instance of
+ # _WrapperDescriptor or _MethodWrapper to avoid
+ # infinite recursion (and even potential segfault)
+ call = _get_user_defined_method(type(obj), '__call__', 'im_func')
+ if call is not None:
+ sig = signature(call)
+
+ if sig is not None:
+ # For classes and objects we skip the first parameter of their
+ # __call__, __new__, or __init__ methods
+ return sig.replace(parameters=tuple(sig.parameters.values())[1:])
+
+ if isinstance(obj, types.BuiltinFunctionType):
+ # Raise a nicer error message for builtins
+ msg = 'no signature found for builtin function {0!r}'.format(obj)
+ raise ValueError(msg)
+
+ raise ValueError(
+ 'callable {0!r} is not supported by signature'.format(obj))
+
+
+class _void(object):
+ '''A private marker - used in Parameter & Signature'''
+
+
+class _empty(object):
+ pass
+
+
+class _ParameterKind(int):
+ def __new__(self, *args, **kwargs):
+ obj = int.__new__(self, *args)
+ obj._name = kwargs['name']
+ return obj
+
+ def __str__(self):
+ return self._name
+
+ def __repr__(self):
+ return '<_ParameterKind: {0!r}>'.format(self._name)
+
+
+_POSITIONAL_ONLY = _ParameterKind(0, name='POSITIONAL_ONLY')
+_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name='POSITIONAL_OR_KEYWORD')
+_VAR_POSITIONAL = _ParameterKind(2, name='VAR_POSITIONAL')
+_KEYWORD_ONLY = _ParameterKind(3, name='KEYWORD_ONLY')
+_VAR_KEYWORD = _ParameterKind(4, name='VAR_KEYWORD')
+
+
+class Parameter(object):
+ '''Represents a parameter in a function signature.
+
+ Has the following public attributes:
+
+ * name : str
+ The name of the parameter as a string.
+ * default : object
+ The default value for the parameter if specified. If the
+ parameter has no default value, this attribute is not set.
+ * annotation
+ The annotation for the parameter if specified. If the
+ parameter has no annotation, this attribute is not set.
+ * kind : str
+ Describes how argument values are bound to the parameter.
+ Possible values: `Parameter.POSITIONAL_ONLY`,
+ `Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`,
+ `Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`.
+ '''
+
+ __slots__ = ('_name', '_kind', '_default', '_annotation', '_partial_kwarg')
+
+ POSITIONAL_ONLY = _POSITIONAL_ONLY
+ POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD
+ VAR_POSITIONAL = _VAR_POSITIONAL
+ KEYWORD_ONLY = _KEYWORD_ONLY
+ VAR_KEYWORD = _VAR_KEYWORD
+
+ empty = _empty
+
+ def __init__(self, name, kind, default=_empty, annotation=_empty,
+ _partial_kwarg=False):
+
+ if kind not in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD,
+ _VAR_POSITIONAL, _KEYWORD_ONLY, _VAR_KEYWORD):
+ raise ValueError("invalid value for 'Parameter.kind' attribute")
+ self._kind = kind
+
+ if default is not _empty:
+ if kind in (_VAR_POSITIONAL, _VAR_KEYWORD):
+ msg = '{0} parameters cannot have default values'.format(kind)
+ raise ValueError(msg)
+ self._default = default
+ self._annotation = annotation
+
+ if name is None:
+ if kind != _POSITIONAL_ONLY:
+ raise ValueError("None is not a valid name for a "
+ "non-positional-only parameter")
+ self._name = name
+ else:
+ name = str(name)
+ if kind != _POSITIONAL_ONLY and not re.match(
+ r'[a-z_]\w*$', name, re.I):
+ msg = '{0!r} is not a valid parameter name'.format(name)
+ raise ValueError(msg)
+ self._name = name
+
+ self._partial_kwarg = _partial_kwarg
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def default(self):
+ return self._default
+
+ @property
+ def annotation(self):
+ return self._annotation
+
+ @property
+ def kind(self):
+ return self._kind
+
+ def replace(self, name=_void, kind=_void, annotation=_void,
+ default=_void, _partial_kwarg=_void):
+ '''Creates a customized copy of the Parameter.'''
+
+ if name is _void:
+ name = self._name
+
+ if kind is _void:
+ kind = self._kind
+
+ if annotation is _void:
+ annotation = self._annotation
+
+ if default is _void:
+ default = self._default
+
+ if _partial_kwarg is _void:
+ _partial_kwarg = self._partial_kwarg
+
+ return type(self)(name, kind, default=default, annotation=annotation,
+ _partial_kwarg=_partial_kwarg)
+
+ def __str__(self):
+ kind = self.kind
+
+ formatted = self._name
+ if kind == _POSITIONAL_ONLY:
+ if formatted is None:
+ formatted = ''
+ formatted = '<{0}>'.format(formatted)
+
+ # Add annotation and default value
+ if self._annotation is not _empty:
+ formatted = '{0}:{1}'.format(formatted,
+ formatannotation(self._annotation))
+
+ if self._default is not _empty:
+ formatted = '{0}={1}'.format(formatted, repr(self._default))
+
+ if kind == _VAR_POSITIONAL:
+ formatted = '*' + formatted
+ elif kind == _VAR_KEYWORD:
+ formatted = '**' + formatted
+
+ return formatted
+
+ def __repr__(self):
+ return '<{0} at {1:#x} {2!r}>'.format(self.__class__.__name__,
+ id(self), self.name)
+
+ def __hash__(self):
+ msg = "unhashable type: '{0}'".format(self.__class__.__name__)
+ raise TypeError(msg)
+
+ def __eq__(self, other):
+ return (issubclass(other.__class__, Parameter) and
+ self._name == other._name and
+ self._kind == other._kind and
+ self._default == other._default and
+ self._annotation == other._annotation)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class BoundArguments(object):
+ '''Result of `Signature.bind` call. Holds the mapping of arguments
+ to the function's parameters.
+
+ Has the following public attributes:
+
+ * arguments : OrderedDict
+ An ordered mutable mapping of parameters' names to arguments' values.
+ Does not contain arguments' default values.
+ * signature : Signature
+ The Signature object that created this instance.
+ * args : tuple
+ Tuple of positional arguments values.
+ * kwargs : dict
+ Dict of keyword arguments values.
+ '''
+
+ def __init__(self, signature, arguments):
+ self.arguments = arguments
+ self._signature = signature
+
+ @property
+ def signature(self):
+ return self._signature
+
+ @property
+ def args(self):
+ args = []
+ for param_name, param in self._signature.parameters.items():
+ if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or
+ param._partial_kwarg):
+ # Keyword arguments mapped by 'functools.partial'
+ # (Parameter._partial_kwarg is True) are mapped
+ # in 'BoundArguments.kwargs', along with VAR_KEYWORD &
+ # KEYWORD_ONLY
+ break
+
+ try:
+ arg = self.arguments[param_name]
+ except KeyError:
+ # We're done here. Other arguments
+ # will be mapped in 'BoundArguments.kwargs'
+ break
+ else:
+ if param.kind == _VAR_POSITIONAL:
+ # *args
+ args.extend(arg)
+ else:
+ # plain argument
+ args.append(arg)
+
+ return tuple(args)
+
+ @property
+ def kwargs(self):
+ kwargs = {}
+ kwargs_started = False
+ for param_name, param in self._signature.parameters.items():
+ if not kwargs_started:
+ if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or
+ param._partial_kwarg):
+ kwargs_started = True
+ else:
+ if param_name not in self.arguments:
+ kwargs_started = True
+ continue
+
+ if not kwargs_started:
+ continue
+
+ try:
+ arg = self.arguments[param_name]
+ except KeyError:
+ pass
+ else:
+ if param.kind == _VAR_KEYWORD:
+ # **kwargs
+ kwargs.update(arg)
+ else:
+ # plain keyword argument
+ kwargs[param_name] = arg
+
+ return kwargs
+
+ def __hash__(self):
+ msg = "unhashable type: '{0}'".format(self.__class__.__name__)
+ raise TypeError(msg)
+
+ def __eq__(self, other):
+ return (issubclass(other.__class__, BoundArguments) and
+ self.signature == other.signature and
+ self.arguments == other.arguments)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class Signature(object):
+ '''A Signature object represents the overall signature of a function.
+ It stores a Parameter object for each parameter accepted by the
+ function, as well as information specific to the function itself.
+
+ A Signature object has the following public attributes and methods:
+
+ * parameters : OrderedDict
+ An ordered mapping of parameters' names to the corresponding
+ Parameter objects (keyword-only arguments are in the same order
+ as listed in `code.co_varnames`).
+ * return_annotation : object
+ The annotation for the return type of the function if specified.
+ If the function has no annotation for its return type, this
+ attribute is not set.
+ * bind(*args, **kwargs) -> BoundArguments
+ Creates a mapping from positional and keyword arguments to
+ parameters.
+ * bind_partial(*args, **kwargs) -> BoundArguments
+ Creates a partial mapping from positional and keyword arguments
+ to parameters (simulating 'functools.partial' behavior.)
+ '''
+
+ __slots__ = ('_return_annotation', '_parameters')
+
+ _parameter_cls = Parameter
+ _bound_arguments_cls = BoundArguments
+
+ empty = _empty
+
+ def __init__(self, parameters=None, return_annotation=_empty,
+ __validate_parameters__=True):
+ '''Constructs Signature from the given list of Parameter
+ objects and 'return_annotation'. All arguments are optional.
+ '''
+
+ if parameters is None:
+ params = OrderedDict()
+ else:
+ if __validate_parameters__:
+ params = OrderedDict()
+ top_kind = _POSITIONAL_ONLY
+
+ for idx, param in enumerate(parameters):
+ kind = param.kind
+ if kind < top_kind:
+ msg = 'wrong parameter order: {0} before {1}'
+ msg = msg.format(top_kind, param.kind)
+ raise ValueError(msg)
+ else:
+ top_kind = kind
+
+ name = param.name
+ if name is None:
+ name = str(idx)
+ param = param.replace(name=name)
+
+ if name in params:
+ msg = 'duplicate parameter name: {0!r}'.format(name)
+ raise ValueError(msg)
+ params[name] = param
+ else:
+ params = OrderedDict(((param.name, param)
+ for param in parameters))
+
+ self._parameters = params
+ self._return_annotation = return_annotation
+
+ @classmethod
+ def from_function(cls, func):
+ '''Constructs Signature for the given python function'''
+
+ if not isinstance(func, types.FunctionType):
+ raise TypeError('{0!r} is not a Python function'.format(func))
+
+ Parameter = cls._parameter_cls
+
+ # Parameter information.
+ func_code = func.__code__
+ pos_count = func_code.co_argcount
+ arg_names = func_code.co_varnames
+ positional = tuple(arg_names[:pos_count])
+ keyword_only_count = getattr(func_code, 'co_kwonlyargcount', 0)
+ keyword_only = arg_names[pos_count:(pos_count + keyword_only_count)]
+ annotations = getattr(func, '__annotations__', {})
+ defaults = func.__defaults__
+ kwdefaults = getattr(func, '__kwdefaults__', None)
+
+ if defaults:
+ pos_default_count = len(defaults)
+ else:
+ pos_default_count = 0
+
+ parameters = []
+
+ # Non-keyword-only parameters w/o defaults.
+ non_default_count = pos_count - pos_default_count
+ for name in positional[:non_default_count]:
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_POSITIONAL_OR_KEYWORD))
+
+ # ... w/ defaults.
+ for offset, name in enumerate(positional[non_default_count:]):
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_POSITIONAL_OR_KEYWORD,
+ default=defaults[offset]))
+
+ # *args
+ if func_code.co_flags & 0x04:
+ name = arg_names[pos_count + keyword_only_count]
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_VAR_POSITIONAL))
+
+ # Keyword-only parameters.
+ for name in keyword_only:
+ default = _empty
+ if kwdefaults is not None:
+ default = kwdefaults.get(name, _empty)
+
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_KEYWORD_ONLY,
+ default=default))
+ # **kwargs
+ if func_code.co_flags & 0x08:
+ index = pos_count + keyword_only_count
+ if func_code.co_flags & 0x04:
+ index += 1
+
+ name = arg_names[index]
+ annotation = annotations.get(name, _empty)
+ parameters.append(Parameter(name, annotation=annotation,
+ kind=_VAR_KEYWORD))
+
+ return cls(parameters,
+ return_annotation=annotations.get('return', _empty),
+ __validate_parameters__=False)
+
+ @property
+ def parameters(self):
+ try:
+ return types.MappingProxyType(self._parameters)
+ except AttributeError:
+ return OrderedDict(self._parameters.items())
+
+ @property
+ def return_annotation(self):
+ return self._return_annotation
+
+ def replace(self, parameters=_void, return_annotation=_void):
+ '''Creates a customized copy of the Signature.
+ Pass 'parameters' and/or 'return_annotation' arguments
+ to override them in the new copy.
+ '''
+
+ if parameters is _void:
+ parameters = self.parameters.values()
+
+ if return_annotation is _void:
+ return_annotation = self._return_annotation
+
+ return type(self)(parameters,
+ return_annotation=return_annotation)
+
+ def __hash__(self):
+ msg = "unhashable type: '{0}'".format(self.__class__.__name__)
+ raise TypeError(msg)
+
+ def __eq__(self, other):
+ if (not issubclass(type(other), Signature) or
+ self.return_annotation != other.return_annotation or
+ len(self.parameters) != len(other.parameters)):
+ return False
+
+ other_positions = dict((param, idx)
+ for idx, param in enumerate(other.parameters.keys()))
+
+ for idx, (param_name, param) in enumerate(self.parameters.items()):
+ if param.kind == _KEYWORD_ONLY:
+ try:
+ other_param = other.parameters[param_name]
+ except KeyError:
+ return False
+ else:
+ if param != other_param:
+ return False
+ else:
+ try:
+ other_idx = other_positions[param_name]
+ except KeyError:
+ return False
+ else:
+ if (idx != other_idx or
+ param != other.parameters[param_name]):
+ return False
+
+ return True
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def _bind(self, args, kwargs, partial=False):
+ '''Private method. Don't use directly.'''
+
+ arguments = OrderedDict()
+
+ parameters = iter(self.parameters.values())
+ parameters_ex = ()
+ arg_vals = iter(args)
+
+ if partial:
+ # Support for binding arguments to 'functools.partial' objects.
+ # See 'functools.partial' case in 'signature()' implementation
+ # for details.
+ for param_name, param in self.parameters.items():
+ if (param._partial_kwarg and param_name not in kwargs):
+ # Simulating 'functools.partial' behavior
+ kwargs[param_name] = param.default
+
+ while True:
+ # Let's iterate through the positional arguments and corresponding
+ # parameters
+ try:
+ arg_val = next(arg_vals)
+ except StopIteration:
+ # No more positional arguments
+ try:
+ param = next(parameters)
+ except StopIteration:
+ # No more parameters. That's it. Just need to check that
+ # we have no `kwargs` after this while loop
+ break
+ else:
+ if param.kind == _VAR_POSITIONAL:
+ # That's OK, just empty *args. Let's start parsing
+ # kwargs
+ break
+ elif param.name in kwargs:
+ if param.kind == _POSITIONAL_ONLY:
+ msg = '{arg!r} parameter is positional only, ' \
+ 'but was passed as a keyword'
+ msg = msg.format(arg=param.name)
+ raise TypeError(msg)
+ parameters_ex = (param,)
+ break
+ elif (param.kind == _VAR_KEYWORD or
+ param.default is not _empty):
+ # That's fine too - we have a default value for this
+ # parameter. So, lets start parsing `kwargs`, starting
+ # with the current parameter
+ parameters_ex = (param,)
+ break
+ else:
+ if partial:
+ parameters_ex = (param,)
+ break
+ else:
+ msg = '{arg!r} parameter lacking default value'
+ msg = msg.format(arg=param.name)
+ raise TypeError(msg)
+ else:
+ # We have a positional argument to process
+ try:
+ param = next(parameters)
+ except StopIteration:
+ raise TypeError('too many positional arguments')
+ else:
+ if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
+ # Looks like we have no parameter for this positional
+ # argument
+ raise TypeError('too many positional arguments')
+
+ if param.kind == _VAR_POSITIONAL:
+ # We have an '*args'-like argument, let's fill it with
+ # all positional arguments we have left and move on to
+ # the next phase
+ values = [arg_val]
+ values.extend(arg_vals)
+ arguments[param.name] = tuple(values)
+ break
+
+ if param.name in kwargs:
+ raise TypeError('multiple values for argument '
+ '{arg!r}'.format(arg=param.name))
+
+ arguments[param.name] = arg_val
+
+ # Now, we iterate through the remaining parameters to process
+ # keyword arguments
+ kwargs_param = None
+ for param in itertools.chain(parameters_ex, parameters):
+ if param.kind == _POSITIONAL_ONLY:
+ # This should never happen in case of a properly built
+ # Signature object (but let's have this check here
+ # to ensure correct behaviour just in case)
+ raise TypeError('{arg!r} parameter is positional only, '
+ 'but was passed as a keyword'.
+ format(arg=param.name))
+
+ if param.kind == _VAR_KEYWORD:
+ # Memorize that we have a '**kwargs'-like parameter
+ kwargs_param = param
+ continue
+
+ param_name = param.name
+ try:
+ arg_val = kwargs.pop(param_name)
+ except KeyError:
+ # We have no value for this parameter. It's fine though,
+ # if it has a default value, or it is an '*args'-like
+ # parameter, left alone by the processing of positional
+ # arguments.
+ if (not partial and param.kind != _VAR_POSITIONAL and
+ param.default is _empty):
+ raise TypeError('{arg!r} parameter lacking default value'.
+ format(arg=param_name))
+
+ else:
+ arguments[param_name] = arg_val
+
+ if kwargs:
+ if kwargs_param is not None:
+ # Process our '**kwargs'-like parameter
+ arguments[kwargs_param.name] = kwargs
+ else:
+ raise TypeError('too many keyword arguments')
+
+ return self._bound_arguments_cls(self, arguments)
+
+ def bind(self, *args, **kwargs):
+ '''Get a BoundArguments object, that maps the passed `args`
+ and `kwargs` to the function's signature. Raises `TypeError`
+ if the passed arguments can not be bound.
+ '''
+ return self._bind(args, kwargs)
+
+ def bind_partial(self, *args, **kwargs):
+ '''Get a BoundArguments object, that partially maps the
+ passed `args` and `kwargs` to the function's signature.
+ Raises `TypeError` if the passed arguments can not be bound.
+ '''
+ return self._bind(args, kwargs, partial=True)
+
+ def __str__(self):
+ result = []
+ render_kw_only_separator = True
+ for idx, param in enumerate(self.parameters.values()):
+ formatted = str(param)
+
+ kind = param.kind
+ if kind == _VAR_POSITIONAL:
+ # OK, we have an '*args'-like parameter, so we won't need
+ # a '*' to separate keyword-only arguments
+ render_kw_only_separator = False
+ elif kind == _KEYWORD_ONLY and render_kw_only_separator:
+ # We have a keyword-only parameter to render and we haven't
+ # rendered an '*args'-like parameter before, so add a '*'
+ # separator to the parameters list ("foo(arg1, *, arg2)" case)
+ result.append('*')
+ # This condition should be only triggered once, so
+ # reset the flag
+ render_kw_only_separator = False
+
+ result.append(formatted)
+
+ rendered = '({0})'.format(', '.join(result))
+
+ if self.return_annotation is not _empty:
+ anno = formatannotation(self.return_annotation)
+ rendered += ' -> {0}'.format(anno)
+
+ return rendered
diff --git a/ot/gromov.py b/ot/gromov.py
index 2a23873..65b2e29 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -595,7 +595,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
gw, logv = entropic_gromov_wasserstein(
C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
- log['T'] = gw
+ logv['T'] = gw
if log:
return logv['gw_dist'], logv
@@ -613,7 +613,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
The function solves the following optimization problem:
.. math::
- C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+ C = argmin_C\in R^{NxN} \sum_s \lambda_s GW(C,Cs,p,ps)
Where :
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 5c09da2..6371feb 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -107,7 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False):
return G
-def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False):
+def emd2(a, b, M, processes=multiprocessing.cpu_count(),
+ numItermax=100000, log=False, return_matrix=False):
"""Solves the Earth Movers distance problem and returns the loss
.. math::
diff --git a/ot/optim.py b/ot/optim.py
index 1d09adc..f31fae2 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -15,7 +15,8 @@ from .bregman import sinkhorn
# The corresponding scipy function does not work for matrices
-def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99):
+def line_search_armijo(f, xk, pk, gfk, old_fval,
+ args=(), c1=1e-4, alpha0=0.99):
"""
Armijo linesearch function that works with matrices
@@ -71,7 +72,8 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99):
return alpha, fc[0], phi1
-def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False, log=False):
+def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
+ stopThr=1e-9, verbose=False, log=False):
"""
Solve the general regularized OT problem with conditional gradient
@@ -202,7 +204,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False
return G
-def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
+def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
+ numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
"""
Solve the general regularized OT problem with the generalized conditional gradient
diff --git a/ot/utils.py b/ot/utils.py
index 9eab3fc..17983f2 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,10 @@ import numpy as np
from scipy.spatial.distance import cdist
import sys
import warnings
-
+try:
+ from inspect import signature
+except ImportError:
+ from .externals.funcsigs import signature
__time_tic_toc = time.time()
@@ -316,7 +319,7 @@ def _is_deprecated(func):
closures = []
is_deprecated = ('deprecated' in ''.join([c.cell_contents
for c in closures
- if isinstance(c.cell_contents, str)]))
+ if isinstance(c.cell_contents, str)]))
return is_deprecated
@@ -335,10 +338,7 @@ class BaseEstimator(object):
@classmethod
def _get_param_names(cls):
"""Get parameter names for the estimator"""
- try:
- from inspect import signature
- except ImportError:
- from .externals.funcsigs import signature
+
# fetch the constructor or the original constructor before
# deprecation wrapping if any
init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
diff --git a/test/test_da.py b/test/test_da.py
index 593dc53..3022721 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -326,8 +326,8 @@ def test_mapping_transport_class():
"""test_mapping_transport
"""
- ns = 150
- nt = 200
+ ns = 60
+ nt = 120
Xs, ys = get_data_classif('3gauss', ns)
Xt, yt = get_data_classif('3gauss2', nt)
@@ -444,6 +444,48 @@ def test_mapping_transport_class():
assert len(otda.log_.keys()) != 0
+def test_linear_mapping():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ A, b = ot.da.OT_mapping_linear(Xs, Xt)
+
+ Xst = Xs.dot(A) + b
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_linear_mapping_class():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ otmap = ot.da.LinearTransport()
+
+ otmap.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otmap, "A_")
+ assert hasattr(otmap, "B_")
+ assert hasattr(otmap, "A1_")
+ assert hasattr(otmap, "B1_")
+
+ Xst = otmap.transform(Xs=Xs)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
def test_otda():
n_samples = 150 # nb samples
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 625e62a..bb23469 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -36,6 +36,18 @@ def test_gromov():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
def test_entropic_gromov():
n_samples = 50 # nb samples
@@ -64,3 +76,70 @@ def test_entropic_gromov():
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_gromov_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_gromov_entropic_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
diff --git a/test/test_utils.py b/test/test_utils.py
index 1bd37cd..b524ef6 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,6 +7,7 @@
import ot
import numpy as np
+import sys
def test_parmap():
@@ -123,3 +124,79 @@ def test_clean_zeros():
assert len(a) == n - nz
assert len(b) == n - nz2
+
+
+def test_cost_normalization():
+
+ C = np.random.rand(10, 10)
+
+ # does nothing
+ M0 = ot.utils.cost_normalization(C)
+ np.testing.assert_allclose(C, M0)
+
+ M = ot.utils.cost_normalization(C, 'median')
+ np.testing.assert_allclose(np.median(M), 1)
+
+ M = ot.utils.cost_normalization(C, 'max')
+ np.testing.assert_allclose(M.max(), 1)
+
+ M = ot.utils.cost_normalization(C, 'log')
+ np.testing.assert_allclose(M.max(), np.log(1 + C).max())
+
+ M = ot.utils.cost_normalization(C, 'loglog')
+ np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
+
+
+def test_check_params():
+
+ res1 = ot.utils.check_params(first='OK', second=20)
+ assert res1 is True
+
+ res0 = ot.utils.check_params(first='OK', second=None)
+ assert res0 is False
+
+
+def test_deprecated_func():
+
+ @ot.utils.deprecated('deprecated text for fun')
+ def fun():
+ pass
+
+ def fun2():
+ pass
+
+ @ot.utils.deprecated('deprecated text for class')
+ class Class():
+ pass
+
+ if sys.version_info < (3, 5):
+ print('Not tested')
+ else:
+ assert ot.utils._is_deprecated(fun) is True
+
+ assert ot.utils._is_deprecated(fun2) is False
+
+
+def test_BaseEstimator():
+
+ class Class(ot.utils.BaseEstimator):
+
+ def __init__(self, first='spam', second='eggs'):
+
+ self.first = first
+ self.second = second
+
+ cl = Class()
+
+ names = cl._get_param_names()
+ assert 'first' in names
+ assert 'second' in names
+
+ params = cl.get_params()
+ assert 'first' in params
+ assert 'second' in params
+
+ params['first'] = 'spam again'
+ cl.set_params(**params)
+
+ assert cl.first == 'spam again'