 README.md                                |   6
 data/cross.png                           | Bin 0 -> 230 bytes
 data/square.png                          | Bin 0 -> 168 bytes
 data/star.png                            | Bin 0 -> 225 bytes
 data/triangle.png                        | Bin 0 -> 254 bytes
 examples/da/plot_otda_semi_supervised.py | 147
 examples/plot_gromov.py                  |  90
 examples/plot_gromov_barycenter.py       | 248
 ot/__init__.py                           |   6
 ot/da.py                                 |  74
 ot/gromov.py                             | 472
 ot/lp/EMD.h                              |   5
 ot/lp/EMD_wrapper.cpp                    |  74
 ot/lp/__init__.py                        |  80
 ot/lp/emd_wrap.pyx                       | 100
 ot/lp/network_simplex_simple.h           |  98
 test/test_da.py                          | 326
 test/test_gromov.py                      |  37
 test/test_ot.py                          | 114
 19 files changed, 1520 insertions(+), 357 deletions(-)
diff --git a/README.md b/README.md
index adccae7..d340068 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ It provides the following solvers:
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Joint OT matrix and mapping estimation [8].
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
-
+* Gromov-Wasserstein distances and barycenters [12]
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
@@ -138,12 +138,12 @@ The contributors to this library are:
* [Léo Gautheron](https://github.com/aje) (GPU implementation)
* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1)
* [Stanislas Chambon](https://slasnista.github.io/)
+* [Antoine Rolet](https://arolet.github.io/)
This toolbox has benefited a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD)
-* [Antoine Rolet](https://arolet.github.io/) ( Mex file for EMD )
* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
@@ -184,3 +184,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816.
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.
diff --git a/data/cross.png b/data/cross.png
new file mode 100755
index 0000000..1c1a068
--- /dev/null
+++ b/data/cross.png
Binary files differ
diff --git a/data/square.png b/data/square.png
new file mode 100755
index 0000000..45ff0ef
--- /dev/null
+++ b/data/square.png
Binary files differ
diff --git a/data/star.png b/data/star.png
new file mode 100755
index 0000000..3f511a6
--- /dev/null
+++ b/data/star.png
Binary files differ
diff --git a/data/triangle.png b/data/triangle.png
new file mode 100755
index 0000000..ca36d09
--- /dev/null
+++ b/data/triangle.png
Binary files differ
diff --git a/examples/da/plot_otda_semi_supervised.py b/examples/da/plot_otda_semi_supervised.py
new file mode 100644
index 0000000..8095c4d
--- /dev/null
+++ b/examples/da/plot_otda_semi_supervised.py
@@ -0,0 +1,147 @@
+# -*- coding: utf-8 -*-
+"""
+============================================
+OTDA unsupervised vs semi-supervised setting
+============================================
+
+This example introduces semi-supervised domain adaptation in a 2D setting.
+It spells out the problem of semi-supervised domain adaptation and
+introduces some optimal transport approaches to solve it.
+
+Quantities such as optimal couplings, the largest coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# generate data
+##############################################################################
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
+
+
+##############################################################################
+# Transport source samples onto target samples
+##############################################################################
+
+# unsupervised domain adaptation
+ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
+
+# semi-supervised domain adaptation
+ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
+transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
+
+# semi-supervised DA uses the available labeled target samples to modify the
+# cost matrix involved in the OT problem. The cost of transporting a source
+# sample of class A onto a target sample of class B != A is set to infinity,
+# or to a very large value
+
+# note that in the present case we consider that all the target samples are
+# labeled. In practical applications, some target samples might not have
+# labels; in this case, the elements of yt corresponding to these samples
+# should be set to -1.
+
+# Warning: we recall that -1 cannot be used as a class label
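+
+# a minimal sketch of a partially-labeled setting (hypothetical variables,
+# not used in the figures below): keep labels for half of the target
+# samples and mark the rest as unlabeled with -1
+#
+# yt_partial = yt.copy()
+# yt_partial[n_samples_target // 2:] = -1
+# ot_sinkhorn_part = ot.da.SinkhornTransport(reg_e=1e-1)
+# ot_sinkhorn_part.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt_partial)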
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+##############################################################################
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - unsupervised DA')
+
+pl.subplot(2, 2, 4)
+pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - semisupervised DA')
+
+pl.tight_layout()
+
+# the optimal coupling in the semi-supervised DA case will exhibit a shape
+# similar to the cost matrix (block-diagonal structure)
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+##############################################################################
+
+pl.figure(2, figsize=(8, 4))
+
+pl.subplot(1, 2, 1)
+pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nUnsupervised DA')
+
+pl.subplot(1, 2, 2)
+pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSemi-supervised DA')
+
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+##############################################################################
+
+# display transported samples
+pl.figure(3, figsize=(8, 4))
+pl.subplot(1, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nUnsupervised DA')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSemi-supervised DA')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
new file mode 100644
index 0000000..dce66c4
--- /dev/null
+++ b/examples/plot_gromov.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Gromov-Wasserstein example
+==========================
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import scipy as sp
+import scipy.linalg  # noqa: makes sp.linalg.sqrtm available
+import scipy.spatial  # noqa: makes sp.spatial.distance available
+import numpy as np
+import matplotlib.pylab as pl
+from mpl_toolkits.mplot3d import Axes3D  # noqa: enables projection='3d'
+
+import ot
+
+
+"""
+Sample two Gaussian distributions (2D and 3D)
+=============================================
+The Gromov-Wasserstein distance allows computing distances between samples
+that do not belong to the same metric space. For demonstration purposes, we
+sample two Gaussian distributions in 2- and 3-dimensional spaces.
+"""
+
+n_samples = 30 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4, 4])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+
+
+"""
+Plotting the distributions
+==========================
+"""
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+"""
+Compute distance kernels, normalize them and then display
+=========================================================
+"""
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+C1 /= C1.max()
+C2 /= C2.max()
+
+pl.figure()
+pl.subplot(121)
+pl.imshow(C1)
+pl.subplot(122)
+pl.imshow(C2)
+pl.show()
+
+"""
+Compute Gromov-Wasserstein plans and distance
+=============================================
+"""
+
+p = ot.unif(n_samples)
+q = ot.unif(n_samples)
+
+gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+print('Gromov-Wasserstein distance between the distributions: ' + str(gw_dist))
+
+pl.figure()
+pl.imshow(gw, cmap='jet')
+pl.colorbar()
+pl.show()
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
new file mode 100755
index 0000000..52f4966
--- /dev/null
+++ b/examples/plot_gromov_barycenter.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+import scipy.spatial  # noqa: makes sp.spatial.distance available
+
+import scipy.ndimage as spi
+import matplotlib.pylab as pl
+from sklearn import manifold
+from sklearn.decomposition import PCA
+
+import ot
+
+"""
+
+Smacof MDS
+==========
+This function allows to find an embedding of points given a dissimilarity matrix
+that will be given by the output of the algorithm
+"""
+
+
+def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+ Returns an interpolated point cloud following the dissimilarity matrix C,
+ using SMACOF multidimensional scaling (MDS) in a target space of the
+ chosen dimension
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+ eps : float
+ relative tolerance w.r.t. stress to declare convergence
+
+ Returns
+ -------
+ npos : ndarray, shape (ns, dim)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
+ """
+
+ rng = np.random.RandomState(seed=3)
+
+ mds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=eps,
+ dissimilarity='precomputed',
+ n_init=1)
+ pos = mds.fit(C).embedding_
+
+ nmds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=eps,
+ dissimilarity="precomputed",
+ random_state=rng,
+ n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
+
+"""
+Data preparation
+================
+The four distributions are constructed from 4 simple images
+"""
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+
+shapes = [square, cross, triangle, star]
+
+S = 4
+xs = [[] for i in range(S)]
+
+
+for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
+
+
+"""
+Barycenter computation
+======================
+The four distributions are constructed from 4 simple images
+"""
+ns = [len(xs[s]) for s in range(S)]
+n_samples = 30
+
+"""Compute all distances matrices for the four shapes"""
+Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+Cs = [cs / cs.max() for cs in Cs]
+
+ps = [ot.unif(ns[s]) for s in range(S)]
+p = ot.unif(n_samples)
+
+
+lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+Ct01 = [0 for i in range(2)]
+for i in range(2):
+ Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
+ [ps[0], ps[1]], p, lambdast[i],
+ 'square_loss', 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct02 = [0 for i in range(2)]
+for i in range(2):
+ Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
+ [ps[0], ps[2]], p, lambdast[i],
+ 'square_loss', 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct13 = [0 for i in range(2)]
+for i in range(2):
+ Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
+ [ps[1], ps[3]], p, lambdast[i],
+ 'square_loss', 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct23 = [0 for i in range(2)]
+for i in range(2):
+ Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
+ [ps[2], ps[3]], p, lambdast[i],
+ 'square_loss', 5e-4,
+ max_iter=100, tol=1e-3)
+
+"""
+Visualization
+=============
+"""
+
+"""The PCA helps in getting consistency between the rotations"""
+
+clf = PCA(n_components=2)
+npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+fig = pl.figure(figsize=(10, 10))
+
+ax1 = pl.subplot2grid((4, 4), (0, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ax2 = pl.subplot2grid((4, 4), (0, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ax3 = pl.subplot2grid((4, 4), (0, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ax4 = pl.subplot2grid((4, 4), (0, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ax5 = pl.subplot2grid((4, 4), (1, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ax6 = pl.subplot2grid((4, 4), (1, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ax7 = pl.subplot2grid((4, 4), (2, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ax8 = pl.subplot2grid((4, 4), (2, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ax9 = pl.subplot2grid((4, 4), (3, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ax10 = pl.subplot2grid((4, 4), (3, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ax11 = pl.subplot2grid((4, 4), (3, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ax12 = pl.subplot2grid((4, 4), (3, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
diff --git a/ot/__init__.py b/ot/__init__.py
index 6d4c4c6..a295e1b 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -5,6 +5,7 @@
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -17,11 +18,13 @@ from . import utils
from . import datasets
from . import plot
from . import da
+from . import gromov
# OT functions
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
+from .gromov import gromov_wasserstein, gromov_wasserstein2
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -30,4 +33,5 @@ __version__ = "0.3.1"
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
- 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
+ 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
+ 'gromov_wasserstein', 'gromov_wasserstein2']
diff --git a/ot/da.py b/ot/da.py
index 564c7b7..1d3d0ba 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -966,8 +966,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -989,7 +993,7 @@ class BaseTransport(BaseEstimator):
# assumes labeled source samples occupy the first rows
# and labeled target samples occupy the first columns
- classes = np.unique(ys)
+ classes = [c for c in np.unique(ys) if c != -1]
for c in classes:
idx_s = np.where((ys != c) & (ys != -1))
idx_t = np.where(yt == c)
@@ -1023,8 +1027,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1045,8 +1053,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
batch_size : int, optional (default=128)
The batch size for out of sample inverse transform
@@ -1110,8 +1122,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
batch_size : int, optional (default=128)
The batch size for out of sample inverse transform
@@ -1241,8 +1257,12 @@ class SinkhornTransport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1333,8 +1353,12 @@ class EMDTransport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1434,8 +1458,12 @@ class SinkhornLpl1Transport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1545,8 +1573,12 @@ class SinkhornL1l2Transport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1662,8 +1694,12 @@ class MappingTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
diff --git a/ot/gromov.py b/ot/gromov.py
new file mode 100644
index 0000000..20bf7ee
--- /dev/null
+++ b/ot/gromov.py
@@ -0,0 +1,472 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein transport method
+===================================
+
+
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+
+from .bregman import sinkhorn
+from .utils import dist
+
+
+def square_loss(a, b):
+ """
+ Returns the value of L(a,b)=(1/2)*|a-b|^2
+ """
+
+ return 0.5 * (a - b)**2
+
+
+def kl_loss(a, b):
+ """
+ Returns the value of L(a,b)=a*log(a/b)-a+b
+ """
+
+ return a * np.log(a / b) - a + b
+
+
+def tensor_square_loss(C1, C2, T):
+ """
+ Returns the value of \mathcal{L}(C1,C2) \otimes T with the square loss
+ function as the loss function of the Gromov-Wasserstein discrepancy.
+
+ Where :
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ T : A coupling between those two spaces
+
+ The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ f1(a)=(a^2)/2
+ f2(b)=(b^2)/2
+ h1(a)=a
+ h2(b)=b
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ T : ndarray, shape (ns, nt)
+ Coupling between source and target spaces
+
+ Returns
+ -------
+ tens : ndarray, shape (ns, nt)
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+ T = np.asarray(T, dtype=np.float64)
+
+ def f1(a):
+ return (a**2) / 2
+
+ def f2(b):
+ return (b**2) / 2
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return b
+
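+ # only the cross term -h1(C1) T h2(C2)^T is computed: since the marginals
+ # of T are fixed to p and q, the f1/f2 terms are constant with respect to
+ # T; shifting by the minimum keeps the tensor nonnegative for sinkhorn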
+ tens = -np.dot(h1(C1), T).dot(h2(C2).T)
+ tens -= tens.min()
+
+ return tens
+
+
+def tensor_kl_loss(C1, C2, T):
+ """
+ Returns the value of \mathcal{L}(C1,C2) \otimes T with the KL loss
+ function as the loss function of the Gromov-Wasserstein discrepancy.
+
+ Where :
+
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ T : A coupling between those two spaces
+
+ The KL loss function L(a,b)=a*log(a/b)-a+b is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ f1(a)=a*log(a)-a
+ f2(b)=b
+ h1(a)=a
+ h2(b)=log(b)
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ T : ndarray, shape (ns, nt)
+ Coupling between source and target spaces
+
+ Returns
+ -------
+ tens : ndarray, shape (ns, nt)
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+ T = np.asarray(T, dtype=np.float64)
+
+ def f1(a):
+ return a * np.log(a + 1e-15) - a
+
+ def f2(b):
+ return b
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return np.log(b + 1e-15)
+
+ tens = -np.dot(h1(C1), T).dot(h2(C2).T)
+ tens -= tens.min()
+
+ return tens
+
+
+def update_square_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 loss kernel, using the S couplings Ts
+ computed at the current iteration
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices
+
+ Returns
+ ----------
+ C : ndarray, shape (N, N)
+ updated C matrix
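+
+ Notes
+ -----
+ A sketch of the update implemented below, where the division is
+ entrywise:
+
+ .. math::
+ C = \frac{\sum_s \lambda_s T_s^T C_s T_s}{p p^T}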
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
+
+def update_kl_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the KL loss kernel, using the S couplings Ts
+ computed at the current iteration
+
+
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ weights in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices
+
+ Returns
+ ----------
+ C : ndarray, shape (N, N)
+ updated C matrix
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.exp(np.divide(tmpsum, ppt))
+
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ """
+ Returns the Gromov-Wasserstein coupling between the two measured
+ similarity matrices (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \GW = \arg\min_T \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) T_{i,j} T_{k,l} - \epsilon H(T)
+
+ s.t. \GW 1 = p
+
+ \GW^T 1 = q
+
+ \GW \geq 0
+
+ Where :
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ p : distribution in the source space
+ q : distribution in the target space
+ L : loss function to account for the misfit between the similarity matrices
+ H : entropy
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ T : ndarray, shape (ns, nt)
+ coupling between the two spaces that minimizes :
+ \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) T_{i,j} T_{k,l} - \epsilon H(T)
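+
+ Examples
+ --------
+ A minimal usage sketch (illustration only), with small normalized random
+ cost matrices:
+
+ >>> import numpy as np
+ >>> import ot
+ >>> C1 = ot.dist(np.random.randn(5, 2))
+ >>> C2 = ot.dist(np.random.randn(5, 2))
+ >>> C1 /= C1.max()
+ >>> C2 /= C2.max()
+ >>> T = gromov_wasserstein(C1, C2, ot.unif(5), ot.unif(5),
+ ... 'square_loss', epsilon=5e-4) # doctest: +SKIP
+ >>> T.shape # doctest: +SKIP
+ (5, 5)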
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+
+ T = np.outer(p, q) # Initialization
+
+ cpt = 0
+ err = 1
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Tprev = T
+
+ if loss_fun == 'square_loss':
+ tens = tensor_square_loss(C1, C2, T)
+
+ elif loss_fun == 'kl_loss':
+ tens = tensor_kl_loss(C1, C2, T)
+
+ T = sinkhorn(p, q, tens, epsilon)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking the error only every
+ # 10th iteration
+ err = np.linalg.norm(T - Tprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ return T, log
+ else:
+ return T
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ """
+ Returns the Gromov-Wasserstein discrepancy between the two measured
+ similarity matrices (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \GW_{Dist} = \min_T \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) T_{i,j} T_{k,l} - \epsilon H(T)
+
+ Where :
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ p : distribution in the source space
+ q : distribution in the target space
+ L : loss function to account for the misfit between the similarity matrices
+ H : entropy
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
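+
+ Examples
+ --------
+ A minimal usage sketch (illustration only), in the same setting as the
+ gromov_wasserstein example:
+
+ >>> gw_dist = gromov_wasserstein2(C1, C2, ot.unif(5), ot.unif(5),
+ ... 'square_loss', epsilon=5e-4) # doctest: +SKIP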
+ """
+
+ if log:
+ gw, logv = gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log)
+ else:
+ gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
+ epsilon, max_iter, tol, verbose, log)
+
+ if loss_fun == 'square_loss':
+ gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
+
+ elif loss_fun == 'kl_loss':
+ gw_dist = np.sum(gw * tensor_kl_loss(C1, C2, gw))
+
+ if log:
+ return gw_dist, logv
+ else:
+ return gw_dist
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ """
+ Returns the Gromov-Wasserstein barycenter of S measured similarity
+ matrices (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem:
+
+ .. math::
+ C = \arg\min_{C \in \mathbb{R}^{N \times N}} \sum_s \lambda_s GW(C, C_s, p, p_s)
+
+
+ Where :
+
+ Cs : metric cost matrix
+ ps : distribution
+
+ Parameters
+ ----------
+ N : Integer
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray(ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray(ns,)
+ sample weights in the S spaces
+ p : ndarray, shape(N,)
+ weights in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ loss_fun : string
+ loss function used for the solver, either 'square_loss' or 'kl_loss';
+ it also selects the kernel (L2 or KL) used to update C from the S
+ couplings Ts at each iteration
+ epsilon : float
+ Regularization term >0
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ init_C : ndarray, shape (N, N), optional
+ initial value for the C matrix provided by the user (a random SPD
+ matrix is used if None)
+
+ Returns
+ -------
+ C : ndarray, shape (N, N)
+ Similarity matrix in the barycenter space (permuted arbitrarily)
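+
+ Examples
+ --------
+ A minimal usage sketch (illustration only): a size-4 barycenter of two
+ small normalized random similarity matrices, with equal weights:
+
+ >>> import numpy as np
+ >>> import ot
+ >>> C1 = ot.dist(np.random.randn(5, 2))
+ >>> C2 = ot.dist(np.random.randn(6, 2))
+ >>> C1 /= C1.max()
+ >>> C2 /= C2.max()
+ >>> C = gromov_barycenters(4, [C1, C2], [ot.unif(5), ot.unif(6)],
+ ... ot.unif(4), [0.5, 0.5], 'square_loss',
+ ... 5e-4, max_iter=100, tol=1e-3) # doctest: +SKIP
+ >>> C.shape # doctest: +SKIP
+ (4, 4)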
+ """
+
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while(err > tol and cpt < max_iter):
+ Cprev = C
+
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
+ max_iter, 1e-5, verbose, log=False) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking the error only every
+ # 10th iteration
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ return C, {'err': error}
+ return C
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index aa92441..f42e222 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -26,9 +26,10 @@ typedef unsigned int node_id_type;
enum ProblemType {
INFEASIBLE,
OPTIMAL,
- UNBOUNDED
+ UNBOUNDED,
+ MAX_ITER_REACHED
};
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter);
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
#endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index c8c2eb3..fc7ca63 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -15,62 +15,60 @@
#include "EMD.h"
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) {
+int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int maxIter) {
// beware: M and C are stored in row-major C style!
- int n, m, i,cur;
- double max;
+ int n, m, i, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
// Get the number of non zero coordinates for r and c
n=0;
- for (node_id_type i=0; i<n1; i++) {
+ for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
n++;
- }
+ }else if(val<0){
+ return INFEASIBLE;
+ }
}
m=0;
- for (node_id_type i=0; i<n2; i++) {
+ for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
m++;
- }
+ }else if(val<0){
+ return INFEASIBLE;
+ }
}
-
// Define the graph
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
// Set supply and demand, don't account for 0 values (faster)
- max=0;
cur=0;
- for (node_id_type i=0; i<n1; i++) {
+ for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
- weights1[ di.nodeFromId(cur) ] = val;
- max+=val;
+ weights1[ cur ] = val;
indI[cur++]=i;
}
}
// Demand is actually negative supply...
- max=0;
cur=0;
- for (node_id_type i=0; i<n2; i++) {
+ for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
- weights2[ di.nodeFromId(cur) ] = -val;
+ weights2[ cur ] = -val;
indJ[cur++]=i;
-
- max-=val;
}
}
@@ -78,14 +76,10 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
- max=0;
- for (node_id_type i=0; i<n; i++) {
- for (node_id_type j=0; j<m; j++) {
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(i*m+j), val);
- if (val>max) {
- max=val;
- }
}
}
@@ -93,26 +87,20 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
// Solve the problem with the network simplex algorithm
int ret=net.run();
- if (ret!=(int)net.OPTIMAL) {
- if (ret==(int)net.INFEASIBLE) {
- std::cout << "Infeasible problem";
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
+ *cost = 0;
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ int i = di.source(a);
+ int j = di.target(a);
+ double flow = net.flow(a);
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+ *(G+indI[i]*n2+indJ[j-n]) = flow;
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
}
- if (ret==(int)net.UNBOUNDED)
- {
- std::cout << "Unbounded problem";
- }
- } else
- {
- for (node_id_type i=0; i<n; i++)
- {
- for (node_id_type j=0; j<m; j++)
- {
- *(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j));
- }
- };
- *cost = net.totalCost();
-
- };
+
+ }
return ret;
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index de91e74..5c09da2 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -7,14 +7,16 @@ Solvers for the original linear program OT problem
#
# License: MIT License
+import multiprocessing
+
import numpy as np
+
# import compiled emd
-from .emd_wrap import emd_c, emd2_c
+from .emd_wrap import emd_c, check_result
from ..utils import parmap
-import multiprocessing
-def emd(a, b, M, numItermax=100000):
+def emd(a, b, M, numItermax=100000, log=False):
"""Solves the Earth Movers distance problem and returns the OT matrix
@@ -42,11 +44,17 @@ def emd(a, b, M, numItermax=100000):
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation matrix.
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
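+
+ A minimal sketch of the log usage (assuming a, b and M are already
+ defined)::
+
+ G, log = ot.emd(a, b, M, log=True)
+ cost, u, v = log['cost'], log['u'], log['v']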
Examples
@@ -82,14 +90,24 @@ def emd(a, b, M, numItermax=100000):
# if empty array given then use uniform distributions
if len(a) == 0:
- a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
-
- return emd_c(a, b, M, numItermax)
-
-
-def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ result_code_string = check_result(result_code)
+ if log:
+ log = {}
+ log['cost'] = cost
+ log['u'] = u
+ log['v'] = v
+ log['warning'] = result_code_string
+ log['result_code'] = result_code
+ return G, log
+ return G
+
+
+def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False):
"""Solves the Earth Movers distance problem and returns the loss
.. math::
@@ -116,11 +134,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost and dual
+ variables. Otherwise returns only the optimal transportation cost.
+ return_matrix: boolean, optional (default=False)
+ If True, returns the optimal transportation matrix in the log.
Returns
-------
W: float
Optimal transportation loss for the given parameters
+ log: dict
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
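+
+ A minimal sketch of the log/return_matrix usage (assuming a, b and M are
+ already defined)::
+
+ cost, log = ot.emd2(a, b, M, log=True, return_matrix=True)
+ G = log['G']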
Examples
@@ -156,17 +182,31 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
# if empty array given then use uniform distributions
if len(a) == 0:
- a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- if len(b.shape) == 1:
- return emd2_c(a, b, M, numItermax)
+ if log or return_matrix:
+ def f(b):
+ G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
+ result_code_string = check_result(resultCode)
+ log = {}
+ if return_matrix:
+ log['G'] = G
+ log['u'] = u
+ log['v'] = v
+ log['warning'] = result_code_string
+ log['result_code'] = resultCode
+ return [cost, log]
else:
- nb = b.shape[1]
- # res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
-
def f(b):
- return emd2_c(a, b, M, numItermax)
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
- return np.array(res)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ check_result(result_code)
+ return cost
+
+ if len(b.shape) == 1:
+ return f(b)
+ nb = b.shape[1]
+
+ res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ return res
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 26d3330..83ee6aa 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -12,81 +12,33 @@ cimport numpy as np
cimport cython
+import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter)
- cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter)
+ cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
+def check_result(result_code):
+ if result_code == OPTIMAL:
+ return None
-@cython.boundscheck(False)
-@cython.wraparound(False)
-def emd_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
- """
- Solves the Earth Movers distance problem and returns the optimal transport matrix
-
- gamm=emd(a,b,M)
-
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
-
- s.t. \gamma 1 = a
-
- \gamma^T 1= b
-
- \gamma\geq 0
- where :
+ if result_code == INFEASIBLE:
+ message = "Problem infeasible. Check that a and b are in the simplex"
+ elif result_code == UNBOUNDED:
+ message = "Problem unbounded"
+ elif result_code == MAX_ITER_REACHED:
+ message = "numItermax reached before optimality. Try to increase numItermax."
+ warnings.warn(message)
+ return message
- - M is the metric cost matrix
- - a and b are the sample weights
-
- Parameters
- ----------
- a : (ns,) ndarray, float64
- source histogram
- b : (nt,) ndarray, float64
- target histogram
- M : (ns,nt) ndarray, float64
- loss matrix
- max_iter : int
- The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
-
-
- Returns
- -------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
-
- """
- cdef int n1= M.shape[0]
- cdef int n2= M.shape[1]
-
- cdef float cost=0
- cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
-
- if not len(a):
- a=np.ones((n1,))/n1
-
- if not len(b):
- b=np.ones((n2,))/n2
-
- # calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
- if resultSolver != OPTIMAL:
- if resultSolver == INFEASIBLE:
- print("Problem infeasible. Try to increase numItermax.")
- elif resultSolver == UNBOUNDED:
- print("Problem unbounded")
-
- return G
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mode="c"] b,np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
"""
- Solves the Earth Movers distance problem and returns the optimal transport loss
+ Solves the Earth Movers distance problem and returns the optimal transport matrix, the dual variables and the solver status
gamm=emd(a,b,M)
@@ -125,8 +77,11 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
cdef int n1= M.shape[0]
cdef int n2= M.shape[1]
- cdef float cost=0
+ cdef double cost=0
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([n1, n2])
+ cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1)
+ cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2)
+
if not len(a):
a=np.ones((n1,))/n1
@@ -135,17 +90,6 @@ def emd2_c( np.ndarray[double, ndim=1, mode="c"] a,np.ndarray[double, ndim=1, mo
b=np.ones((n2,))/n2
# calling the function
- cdef int resultSolver = EMD_wrap(n1,n2,<double*> a.data,<double*> b.data,<double*> M.data,<double*> G.data,<double*> &cost, max_iter)
- if resultSolver != OPTIMAL:
- if resultSolver == INFEASIBLE:
- print("Problem infeasible. Try to inscrease numItermax.")
- elif resultSolver == UNBOUNDED:
- print("Problem unbounded")
-
- cost=0
- for i in range(n1):
- for j in range(n2):
- cost+=G[i,j]*M[i,j]
-
- return cost
+ cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ return G, cost, alpha, beta, result_code
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 64856a0..7c6a4ce 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -28,7 +28,14 @@
#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H
#define LEMON_NETWORK_SIMPLEX_SIMPLE_H
#define DEBUG_LVL 0
-#define EPSILON 10*2.2204460492503131e-016
+
+#if DEBUG_LVL>0
+#include <iomanip>
+#endif
+
+
+#define EPSILON 2.2204460492503131e-15
+#define _EPSILON 1e-8
#define MAX_DEBUG_ITER 100000
@@ -43,6 +50,7 @@
#include <vector>
#include <limits>
#include <algorithm>
+#include <cstdio>
#ifdef HASHMAP
#include <hash_map>
#else
@@ -220,7 +228,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,double maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -253,7 +261,9 @@ namespace lemon {
/// The objective function of the problem is unbounded, i.e.
/// there is a directed cycle having negative total cost and
/// infinite upper bound.
- UNBOUNDED
+ UNBOUNDED,
+ /// The maximum number of iteration has been reached
+ MAX_ITER_REACHED
};
/// \brief Constants for selecting the type of the supply constraints.
@@ -278,7 +288,7 @@ namespace lemon {
private:
- double max_iter;
+ int max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
@@ -676,14 +686,12 @@ namespace lemon {
/// \see resetParams(), reset()
ProblemType run() {
#if DEBUG_LVL>0
- mexPrintf("OPTIMAL = %d\nINFEASIBLE = %d\nUNBOUNDED = %d\n",OPTIMAL,INFEASIBLE,UNBOUNDED);
- mexEvalString("drawnow;");
+ std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED\n";
#endif
if (!init()) return INFEASIBLE;
#if DEBUG_LVL>0
- mexPrintf("Init done, starting iterations\n");
- mexEvalString("drawnow;");
+ std::cout << "Init done, starting iterations\n";
#endif
return start();
}
@@ -936,15 +944,15 @@ namespace lemon {
// Initialize internal data structures
bool init() {
if (_node_num == 0) return false;
- /*
+
// Check the sum of supply values
_sum_supply = 0;
for (int i = 0; i != _node_num; ++i) {
_sum_supply += _supply[i];
}
- if ( !((_stype == GEQ && _sum_supply <= _epsilon ) ||
- (_stype == LEQ && _sum_supply >= -_epsilon )) ) return false;
- */
+ if ( fabs(_sum_supply) > _EPSILON ) return false;
+
+ _sum_supply = 0;
// Initialize artifical cost
Cost ART_COST;
@@ -1411,27 +1419,26 @@ namespace lemon {
ProblemType start() {
PivotRuleImpl pivot(*this);
double prevCost=-1;
+ ProblemType retVal = OPTIMAL;
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
-#if DEBUG_LVL>0
- int niter=0;
-#endif
int iter_number=0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
- if(++iter_number>=max_iter&&max_iter>0){
+ if(max_iter > 0 && ++iter_number>=max_iter){
char errMess[1000];
- // sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher",iter_number );
- // mexWarnMsgTxt(errMess);
+ sprintf( errMess, "RESULT MIGHT BE INACCURATE\nMax number of iterations reached, currently %d. Sometimes iterations go on in a cycle even though the solution has been reached; to check if that is the case here, have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution; if not, try setting the maximum number of iterations a bit higher\n",iter_number );
+ std::cerr << errMess;
+ retVal = MAX_ITER_REACHED;
break;
}
#if DEBUG_LVL>0
- if(niter>MAX_DEBUG_ITER)
+ if(iter_number>MAX_DEBUG_ITER)
break;
- if(++niter%1000==0||niter%1000==1){
+ if(iter_number%1000==0||iter_number%1000==1){
double curCost=totalCost();
double sumFlow=0;
double a;
@@ -1440,12 +1447,13 @@ namespace lemon {
for (int i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
- mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f\n",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a));
- mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc]));
- mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]);
-
- mexPrintf("%.30f\n%.30f\n%.30f\n%.30f\n%",_cost[in_arc],_pi[_source[in_arc]],_pi[_target[in_arc]],a);
- mexEvalString("drawnow;");
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+ std::cout << _cost[in_arc] << "\n";
+ std::cout << _pi[_source[in_arc]] << "\n";
+ std::cout << _pi[_target[in_arc]] << "\n";
+ std::cout << a << "\n";
}
#endif
@@ -1459,11 +1467,11 @@ namespace lemon {
}
#if DEBUG_LVL>0
else{
- mexPrintf("No change\n");
+ std::cout << "No change\n";
}
#endif
#if DEBUG_LVL>1
- mexPrintf("Arc in = (%d,%d)\n",_source[in_arc],_target[in_arc]);
+ std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n";
#endif
}
@@ -1478,34 +1486,36 @@ namespace lemon {
for (int i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
- mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\nReduced cost=%.30f\nPrecision =%.30f",sumFlow,niter, curCost,_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]), -EPSILON*(a));
- mexPrintf("Arc in = (%d,%d)\n",_node_id(_source[in_arc]),_node_id(_target[in_arc]));
- mexPrintf("Supplies = (%f,%f)\n",_supply[_source[in_arc]],_supply[_target[in_arc]]);
+
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
- mexEvalString("drawnow;");
#endif
#if DEBUG_LVL>1
- double sumFlow=0;
+ sumFlow=0;
for (int i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
if (_state[i]==STATE_TREE) {
- mexPrintf("Non zero value at (%d,%d)\n",_node_num+1-_source[i],_node_num+1-_target[i]);
+ std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n";
}
}
- mexPrintf("Sum of the flow %.100f\n%d iterations, current cost=%.20f\n",sumFlow,niter, totalCost());
- mexEvalString("drawnow;");
+ std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n";
#endif
// Check feasibility
- for (int e = _search_arc_num; e != _all_arc_num; ++e) {
- if (_flow[e] != 0){
- if (abs(_flow[e]) > EPSILON)
- return INFEASIBLE;
- else
- _flow[e]=0;
+ if( retVal == OPTIMAL){
+ for (int e = _search_arc_num; e != _all_arc_num; ++e) {
+ if (_flow[e] != 0){
+ if (abs(_flow[e]) > EPSILON)
+ return INFEASIBLE;
+ else
+ _flow[e]=0;
+ }
}
- }
+ }
// Shift potentials to meet the requirements of the GEQ/LEQ type
// optimality conditions
@@ -1531,7 +1541,7 @@ namespace lemon {
}
}
- return OPTIMAL;
+ return retVal;
}
}; //class NetworkSimplexSimple
diff --git a/test/test_da.py b/test/test_da.py
index 104a798..593dc53 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -22,60 +22,68 @@ def test_sinkhorn_lpl1_transport_class():
Xs, ys = get_data_classif('3gauss', ns)
Xt, yt = get_data_classif('3gauss2', nt)
- clf = ot.da.SinkhornLpl1Transport()
+ otda = ot.da.SinkhornLpl1Transport()
# test it is computed
- clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- assert hasattr(clf, "cost_")
- assert hasattr(clf, "coupling_")
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
# test dimensions of coupling
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
Xs_new, _ = get_data_classif('3gauss', ns + 1)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
- transp_Xt = clf.inverse_transform(Xt=Xt)
+ transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
- transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
- transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
+ transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
- # test semi supervised mode
- clf = ot.da.SinkhornLpl1Transport()
- clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(clf.cost_)
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornLpl1Transport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
- # test semi supervised mode
- clf = ot.da.SinkhornLpl1Transport()
- clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.cost_)
+ otda_semi = ot.da.SinkhornLpl1Transport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+ # check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
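+    # (in semi-supervised mode, cost entries between source and target
+    # samples with inconsistent labels are set to limit_max, so the
+    # optimal coupling should put no mass on them)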
+ mass_semi = np.sum(
+ otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
+ assert mass_semi == 0, "semisupervised mode not working"
+
def test_sinkhorn_l1l2_transport_class():
"""test_sinkhorn_transport
@@ -87,65 +95,75 @@ def test_sinkhorn_l1l2_transport_class():
Xs, ys = get_data_classif('3gauss', ns)
Xt, yt = get_data_classif('3gauss2', nt)
- clf = ot.da.SinkhornL1l2Transport()
+ otda = ot.da.SinkhornL1l2Transport()
# test that the fit attributes are computed
- clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- assert hasattr(clf, "cost_")
- assert hasattr(clf, "coupling_")
- assert hasattr(clf, "log_")
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
# test dimensions of coupling
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
Xs_new, _ = get_data_classif('3gauss', ns + 1)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
- transp_Xt = clf.inverse_transform(Xt=Xt)
+ transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
- transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
- transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
+ transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
- # test semi supervised mode
- clf = ot.da.SinkhornL1l2Transport()
- clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(clf.cost_)
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornL1l2Transport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
- # test semi supervised mode
- clf = ot.da.SinkhornL1l2Transport()
- clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.cost_)
+ otda_semi = ot.da.SinkhornL1l2Transport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+ # check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
+    mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
+ assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ rtol=1e-9, atol=1e-9)
+
# check everything runs well with log=True
- clf = ot.da.SinkhornL1l2Transport(log=True)
- clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- assert len(clf.log_.keys()) != 0
+ otda = ot.da.SinkhornL1l2Transport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
def test_sinkhorn_transport_class():
@@ -158,65 +176,73 @@ def test_sinkhorn_transport_class():
Xs, ys = get_data_classif('3gauss', ns)
Xt, yt = get_data_classif('3gauss2', nt)
- clf = ot.da.SinkhornTransport()
+ otda = ot.da.SinkhornTransport()
# test that the fit attributes are computed
- clf.fit(Xs=Xs, Xt=Xt)
- assert hasattr(clf, "cost_")
- assert hasattr(clf, "coupling_")
- assert hasattr(clf, "log_")
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
# test dimensions of coupling
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
Xs_new, _ = get_data_classif('3gauss', ns + 1)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
- transp_Xt = clf.inverse_transform(Xt=Xt)
+ transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
- transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
- transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
- # test semi supervised mode
- clf = ot.da.SinkhornTransport()
- clf.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(clf.cost_)
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
- # test semi supervised mode
- clf = ot.da.SinkhornTransport()
- clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.cost_)
+ otda_semi = ot.da.SinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+ # check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
+ mass_semi = np.sum(
+ otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
+ assert mass_semi == 0, "semisupervised mode not working"
+
# check everything runs well with log=True
- clf = ot.da.SinkhornTransport(log=True)
- clf.fit(Xs=Xs, ys=ys, Xt=Xt)
- assert len(clf.log_.keys()) != 0
+ otda = ot.da.SinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
def test_emd_transport_class():
@@ -229,60 +255,72 @@ def test_emd_transport_class():
Xs, ys = get_data_classif('3gauss', ns)
Xt, yt = get_data_classif('3gauss2', nt)
- clf = ot.da.EMDTransport()
+ otda = ot.da.EMDTransport()
# test that the fit attributes are computed
- clf.fit(Xs=Xs, Xt=Xt)
- assert hasattr(clf, "cost_")
- assert hasattr(clf, "coupling_")
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
# test dimensions of coupling
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
Xs_new, _ = get_data_classif('3gauss', ns + 1)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# test inverse transform
- transp_Xt = clf.inverse_transform(Xt=Xt)
+ transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
- transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
assert_equal(transp_Xt_new.shape, Xt_new.shape)
# test fit_transform
- transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
assert_equal(transp_Xs.shape, Xs.shape)
- # test semi supervised mode
- clf = ot.da.EMDTransport()
- clf.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(clf.cost_)
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.EMDTransport()
+ otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
- # test semi supervised mode
- clf = ot.da.EMDTransport()
- clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
- assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(clf.cost_)
+ otda_semi = ot.da.EMDTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+ # check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check that the coupling forbids mass transport between labeled source
+ # and labeled target samples
+    mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
+
+    # EMD needs a looser tolerance than the Sinkhorn variants here,
+    # otherwise the test breaks
+ assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ rtol=1e-2, atol=1e-2)
+
def test_mapping_transport_class():
"""test_mapping_transport
@@ -300,47 +338,51 @@ def test_mapping_transport_class():
##########################################################################
# check computation and dimensions if bias == False
- clf = ot.da.MappingTransport(kernel="linear", bias=False)
- clf.fit(Xs=Xs, Xt=Xt)
- assert hasattr(clf, "coupling_")
- assert hasattr(clf, "mapping_")
- assert hasattr(clf, "log_")
+ otda = ot.da.MappingTransport(kernel="linear", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "mapping_")
+ assert hasattr(otda, "log_")
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# check computation and dimensions if bias == True
- clf = ot.da.MappingTransport(kernel="linear", bias=True)
- clf.fit(Xs=Xs, Xt=Xt)
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1])))
+ otda = ot.da.MappingTransport(kernel="linear", bias=True)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
@@ -350,52 +392,56 @@ def test_mapping_transport_class():
##########################################################################
# check computation and dimensions if bias == False
- clf = ot.da.MappingTransport(kernel="gaussian", bias=False)
- clf.fit(Xs=Xs, Xt=Xt)
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.mapping_.shape, ((Xs.shape[0], Xt.shape[1])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[0], Xt.shape[1])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# check computation and dimensions if bias == True
- clf = ot.da.MappingTransport(kernel="gaussian", bias=True)
- clf.fit(Xs=Xs, Xt=Xt)
- assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(clf.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1])))
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=True)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
- assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ assert_allclose(
+ np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
# test transform
- transp_Xs = clf.transform(Xs=Xs)
+ transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- transp_Xs_new = clf.transform(Xs_new)
+ transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# check everything runs well with log=True
- clf = ot.da.MappingTransport(kernel="gaussian", log=True)
- clf.fit(Xs=Xs, Xt=Xt)
- assert len(clf.log_.keys()) != 0
+ otda = ot.da.MappingTransport(kernel="gaussian", log=True)
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
def test_otda():
@@ -424,7 +470,8 @@ def test_otda():
da_entrop.interp()
da_entrop.predict(xs)
- np.testing.assert_allclose(a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(
+ a, np.sum(da_entrop.G, 1), rtol=1e-3, atol=1e-3)
np.testing.assert_allclose(b, np.sum(da_entrop.G, 0), rtol=1e-3, atol=1e-3)
# non-convex Group lasso regularization
@@ -458,12 +505,3 @@ def test_otda():
da_emd = ot.da.OTDA_mapping_kernel() # init class
da_emd.fit(xs, xt, numItermax=10) # fit distributions
da_emd.predict(xs) # interpolation of source samples
-
-
-# if __name__ == "__main__":
-
-# test_sinkhorn_transport_class()
-# test_emd_transport_class()
-# test_sinkhorn_l1l2_transport_class()
-# test_sinkhorn_lpl1_transport_class()
-# test_mapping_transport_class()
diff --git a/test/test_gromov.py b/test/test_gromov.py
new file mode 100644
index 0000000..e808292
--- /dev/null
+++ b/test/test_gromov.py
@@ -0,0 +1,37 @@
+"""Tests for module gromov """
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+
+ xt = xs[::-1].copy()
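+    # xt is xs with the sample order reversed: both clouds share the same
+    # pairwise distances up to a permutation, so the optimal GW coupling
+    # should essentially match the samples one-to-one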
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
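+    # entropic Gromov-Wasserstein: epsilon sets the regularization strength
+    # (smaller values typically give sharper couplings but slower solves)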
+ G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+    # check marginal constraints; the tolerance accounts for the finite
+    # convergence accuracy of the entropic Gromov-Wasserstein solver
+    np.testing.assert_allclose(p, G.sum(1), atol=1e-04)
+    np.testing.assert_allclose(q, G.sum(0), atol=1e-04)
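+
+    # total mass is conserved as well; the bound follows from the marginal
+    # checks above (n_samples * 1e-04)
+    np.testing.assert_allclose(G.sum(), 1., atol=1e-02)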
diff --git a/test/test_ot.py b/test/test_ot.py
index acd8718..ea6d9dc 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -4,12 +4,15 @@
#
# License: MIT License
+import warnings
+
import numpy as np
+
import ot
+from ot.datasets import get_1D_gauss as gauss
def test_doctest():
-
import doctest
# test lp solver
@@ -66,9 +69,6 @@ def test_emd_empty():
def test_emd2_multi():
-
- from ot.datasets import get_1D_gauss as gauss
-
n = 1000 # nb bins
# bin positions
@@ -100,3 +100,109 @@ def test_emd2_multi():
ot.toc('multi proc : {} s')
np.testing.assert_allclose(emd1, emdn)
+
+    # emd loss multi proc with log
+ ot.tic()
+ emdn = ot.emd2(a, b, M, log=True, return_matrix=True)
+ ot.toc('multi proc : {} s')
+
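+    # with log=True each entry of emdn is a (cost, log) pair, and
+    # return_matrix=True stores the coupling matrix under log['G']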
+    for i, (cost, log) in enumerate(emdn):
+        check_duality_gap(a, b[:, i], M, log['G'], log['u'], log['v'], cost)
+        emdn[i] = cost
+
+ emdn = np.array(emdn)
+ np.testing.assert_allclose(emd1, emdn)
+
+
+def test_warnings():
+ n = 100 # nb bins
+ m = 100 # nb bins
+
+ mean1 = 30
+ mean2 = 50
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+ y = np.arange(m, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=mean1, s=5) # m= mean, s= std
+
+ b = gauss(m, m=mean2, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
+
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ print('Computing {} EMD '.format(1))
+ ot.emd(a, b, M, numItermax=1)
+ assert "numItermax" in str(w[-1].message)
+ assert len(w) == 1
+        # increasing a[0] makes the total masses of a and b differ,
+        # so the problem becomes infeasible
+        a[0] = 100
+        print('Computing {} EMD '.format(2))
+        ot.emd(a, b, M)
+        assert "infeasible" in str(w[-1].message)
+        assert len(w) == 2
+        # a negative weight also yields an infeasible problem
+        a[0] = -1
+        print('Computing {} EMD '.format(3))
+        ot.emd(a, b, M)
+        assert "infeasible" in str(w[-1].message)
+        assert len(w) == 3
+
+
+def test_dual_variables():
+ n = 5000 # nb bins
+ m = 6000 # nb bins
+
+ mean1 = 1000
+ mean2 = 1100
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+ y = np.arange(m, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=mean1, s=5) # m= mean, s= std
+
+ b = gauss(m, m=mean2, s=10)
+
+ # loss matrix
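+    # (ot.dist defaults to squared Euclidean distances; taking the square
+    # root makes M the plain |x - y| ground cost)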
+ M = ot.dist(x.reshape((-1, 1)), y.reshape((-1, 1))) ** (1. / 2)
+
+ print('Computing {} EMD '.format(1))
+
+ # emd loss 1 proc
+ ot.tic()
+ G, log = ot.emd(a, b, M, log=True)
+ ot.toc('1 proc : {} s')
+
+ ot.tic()
+ G2 = ot.emd(b, a, np.ascontiguousarray(M.T))
+ ot.toc('1 proc : {} s')
+
+ cost1 = (G * M).sum()
+ # Check symmetry
+ np.testing.assert_array_almost_equal(cost1, (M * G2.T).sum())
+ # Check with closed-form solution for gaussians
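+    # (with the |x - y| cost and two well-separated 1D Gaussians, W1 is
+    # essentially the distance between the means)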
+ np.testing.assert_almost_equal(cost1, np.abs(mean1 - mean2))
+
+ # Check that both cost computations are equivalent
+ np.testing.assert_almost_equal(cost1, log['cost'])
+ check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
+
+
+def check_duality_gap(a, b, M, G, u, v, cost):
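+    """Check LP optimality via duality: the dual objective <a, u> + <b, v>
+    should equal the primal cost, and the reduced cost M[i, j] - u[i] - v[j]
+    should vanish on every arc (i, j) that carries mass."""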
+ cost_dual = np.vdot(a, u) + np.vdot(b, v)
+ # Check that dual and primal cost are equal
+ np.testing.assert_almost_equal(cost_dual, cost)
+
+ [ind1, ind2] = np.nonzero(G)
+
+ # Check that reduced cost is zero on transport arcs
+    reduced_costs = (M - u.reshape(-1, 1) - v.reshape(1, -1))[ind1, ind2]
+    np.testing.assert_array_almost_equal(reduced_costs, np.zeros(ind1.size))