summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md4
-rwxr-xr-xdata/carre.pngbin0 -> 168 bytes
-rwxr-xr-xdata/coeur.pngbin0 -> 225 bytes
-rwxr-xr-xdata/rond.pngbin0 -> 230 bytes
-rwxr-xr-xdata/triangle.pngbin0 -> 254 bytes
-rw-r--r--examples/plot_gromov.py89
-rwxr-xr-xexamples/plot_gromov_barycenter.py241
-rw-r--r--ot/__init__.py6
-rw-r--r--ot/gromov.py474
-rw-r--r--test/test_gromov.py38
10 files changed, 850 insertions, 2 deletions
diff --git a/README.md b/README.md
index 33eea6e..22b20a4 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ It provides the following solvers:
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Joint OT matrix and mapping estimation [8].
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
-
+* Gromov-Wasserstein distances [12]
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
@@ -184,3 +184,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816.
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.
diff --git a/data/carre.png b/data/carre.png
new file mode 100755
index 0000000..45ff0ef
--- /dev/null
+++ b/data/carre.png
Binary files differ
diff --git a/data/coeur.png b/data/coeur.png
new file mode 100755
index 0000000..3f511a6
--- /dev/null
+++ b/data/coeur.png
Binary files differ
diff --git a/data/rond.png b/data/rond.png
new file mode 100755
index 0000000..1c1a068
--- /dev/null
+++ b/data/rond.png
Binary files differ
diff --git a/data/triangle.png b/data/triangle.png
new file mode 100755
index 0000000..ca36d09
--- /dev/null
+++ b/data/triangle.png
Binary files differ
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
new file mode 100644
index 0000000..9bbdbde
--- /dev/null
+++ b/examples/plot_gromov.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Gromov-Wasserstein example
+==========================
+This example is designed to show how to use the Gromov-Wassertsein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+
+import ot
+
+
+"""
+Sample two Gaussian distributions (2D and 3D)
+=============================================
+The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
+For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
+"""
+
+n = 30 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4, 4])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n, 3).dot(P) + mu_t
+
+
+"""
+Plotting the distributions
+==========================
+"""
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+"""
+Compute distance kernels, normalize them and then display
+=========================================================
+"""
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+C1 /= C1.max()
+C2 /= C2.max()
+
+pl.figure()
+pl.subplot(121)
+pl.imshow(C1)
+pl.subplot(122)
+pl.imshow(C2)
+pl.show()
+
+"""
+Compute Gromov-Wasserstein plans and distance
+=============================================
+"""
+
+p = ot.unif(n)
+q = ot.unif(n)
+
+gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
+
+pl.figure()
+pl.imshow(gw, cmap='jet')
+pl.colorbar()
+pl.show()
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
new file mode 100755
index 0000000..da52768
--- /dev/null
+++ b/examples/plot_gromov_barycenter.py
@@ -0,0 +1,241 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+This example is designed to show how to use the Gromov-Wassertsein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+
+import scipy.ndimage as spi
+import matplotlib.pylab as pl
+from sklearn import manifold
+from sklearn.decomposition import PCA
+
+import ot
+
+"""
+
+Smacof MDS
+==========
+This function allows to find an embedding of points given a dissimilarity matrix
+that will be given by the output of the algorithm
+"""
+
+
+def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+ Returns an interpolated point cloud following the dissimilarity matrix C using SMACOF
+ multidimensional scaling (MDS) in specific dimensionned target space
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+
+ eps : relative tolerance w.r.t stress to declare converge
+
+
+ Returns
+ -------
+ npos : R**dim ndarray
+ Embedded coordinates of the interpolated point cloud (defined with one isometry)
+
+
+ """
+
+ seed = np.random.RandomState(seed=3)
+
+ mds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity='precomputed',
+ n_init=1)
+ pos = mds.fit(C).embedding_
+
+ nmds = manifold.MDS(
+ 2,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity="precomputed",
+ random_state=seed,
+ n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
+
+"""
+Data preparation
+================
+The four distributions are constructed from 4 simple images
+"""
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+carre = spi.imread('../data/carre.png').astype(np.float64) / 256
+rond = spi.imread('../data/rond.png').astype(np.float64) / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
+fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
+
+shapes = [carre, rond, triangle, fleche]
+
+S = 4
+xs = [[] for i in range(S)]
+
+
+for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
+
+
+"""
+Barycenter computation
+======================
+The four distributions are constructed from 4 simple images
+"""
+ns = [len(xs[s]) for s in range(S)]
+N = 30
+
+"""Compute all distances matrices for the four shapes"""
+Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+Cs = [cs / cs.max() for cs in Cs]
+
+ps = [ot.unif(ns[s]) for s in range(S)]
+p = ot.unif(N)
+
+
+lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+Ct01 = [0 for i in range(2)]
+for i in range(2):
+ Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
+ ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
+
+Ct02 = [0 for i in range(2)]
+for i in range(2):
+ Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
+ ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
+
+Ct13 = [0 for i in range(2)]
+for i in range(2):
+ Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
+ ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
+
+Ct23 = [0 for i in range(2)]
+for i in range(2):
+ Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
+ ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
+
+"""
+Visualization
+=============
+"""
+
+"""The PCA helps in getting consistency between the rotations"""
+
+clf = PCA(n_components=2)
+npos = [0, 0, 0, 0]
+npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+npost01 = [0, 0]
+npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+npost02 = [0, 0]
+npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+npost13 = [0, 0]
+npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+npost23 = [0, 0]
+npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+fig = pl.figure(figsize=(10, 10))
+
+ax1 = pl.subplot2grid((4, 4), (0, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ax2 = pl.subplot2grid((4, 4), (0, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ax3 = pl.subplot2grid((4, 4), (0, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ax4 = pl.subplot2grid((4, 4), (0, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ax5 = pl.subplot2grid((4, 4), (1, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ax6 = pl.subplot2grid((4, 4), (1, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ax7 = pl.subplot2grid((4, 4), (2, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ax8 = pl.subplot2grid((4, 4), (2, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ax9 = pl.subplot2grid((4, 4), (3, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ax10 = pl.subplot2grid((4, 4), (3, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ax11 = pl.subplot2grid((4, 4), (3, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ax12 = pl.subplot2grid((4, 4), (3, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
diff --git a/ot/__init__.py b/ot/__init__.py
index 6d4c4c6..a295e1b 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -5,6 +5,7 @@
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -17,11 +18,13 @@ from . import utils
from . import datasets
from . import plot
from . import da
+from . import gromov
# OT functions
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
+from .gromov import gromov_wasserstein, gromov_wasserstein2
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -30,4 +33,5 @@ __version__ = "0.3.1"
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
- 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
+ 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
+ 'gromov_wasserstein','gromov_wasserstein2']
diff --git a/ot/gromov.py b/ot/gromov.py
new file mode 100644
index 0000000..421ed3f
--- /dev/null
+++ b/ot/gromov.py
@@ -0,0 +1,474 @@
+
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein transport method
+===================================
+
+
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+
+from .bregman import sinkhorn
+from .utils import dist
+
+
+def square_loss(a, b):
+ """
+ Returns the value of L(a,b)=(1/2)*|a-b|^2
+ """
+
+ return 0.5 * (a - b)**2
+
+
+def kl_loss(a, b):
+ """
+ Returns the value of L(a,b)=a*log(a/b)-a+b
+ """
+
+ return a * np.log(a / b) - a + b
+
+
+def tensor_square_loss(C1, C2, T):
+ """
+ Returns the value of \mathcal{L}(C1,C2) \otimes T with the square loss
+ function as the loss function of Gromow-Wasserstein discrepancy.
+
+ Where :
+
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ T : A coupling between those two spaces
+
+ The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ f1(a)=(a^2)/2
+ f2(b)=(b^2)/2
+ h1(a)=a
+ h2(b)=b
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric costfr matrix in the target space
+ T : np.ndarray(ns,nt)
+ Coupling between source and target spaces
+
+
+ Returns
+ -------
+ tens : (ns*nt) ndarray
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+
+
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+ T = np.asarray(T, dtype=np.float64)
+
+ def f1(a):
+ return (a**2) / 2
+
+ def f2(b):
+ return (b**2) / 2
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return b
+
+ tens = -np.dot(h1(C1), T).dot(h2(C2).T)
+ tens -= tens.min()
+
+ return np.array(tens)
+
+
+def tensor_kl_loss(C1, C2, T):
+ """
+ Returns the value of \mathcal{L}(C1,C2) \otimes T with the square loss
+ function as the loss function of Gromow-Wasserstein discrepancy.
+
+ Where :
+
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ T : A coupling between those two spaces
+
+ The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
+ L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
+ f1(a)=a*log(a)-a
+ f2(b)=b
+ h1(a)=a
+ h2(b)=log(b)
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric costfr matrix in the target space
+ T : np.ndarray(ns,nt)
+ Coupling between source and target spaces
+
+
+ Returns
+ -------
+ tens : (ns*nt) ndarray
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+
+ References
+ ----------
+
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+ T = np.asarray(T, dtype=np.float64)
+
+ def f1(a):
+ return a * np.log(a + 1e-15) - a
+
+ def f2(b):
+ return b
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return np.log(b + 1e-15)
+
+ tens = -np.dot(h1(C1), T).dot(h2(C2).T)
+ tens -= tens.min()
+
+ return np.array(tens)
+
+
+def update_square_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 Loss kernel with the S Ts couplings calculated at each iteration
+
+
+ Parameters
+ ----------
+ p : np.ndarray(N,)
+ weights in the targeted barycenter
+ lambdas : list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : Cs : list of S np.ndarray(ns,ns)
+ Metric cost matrices
+
+ Returns
+ ----------
+ C updated
+
+
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return(np.divide(tmpsum, ppt))
+
+
+def update_kl_loss(p, lambdas, T, Cs):
+ """
+ Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration
+
+
+ Parameters
+ ----------
+ p : np.ndarray(N,)
+ weights in the targeted barycenter
+ lambdas : list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : Cs : list of S np.ndarray(ns,ns)
+ Metric cost matrices
+
+ Returns
+ ----------
+ C updated
+
+
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return(np.exp(np.divide(tmpsum, ppt)))
+
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+ """
+ Returns the gromov-wasserstein coupling between the two measured similarity matrices
+
+ (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+
+ s.t. \GW 1 = p
+
+ \GW^T 1= q
+
+ \GW\geq 0
+
+ Where :
+
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ p : distribution in the source space
+ q : distribution in the target space
+ L : loss function to account for the misfit between the similarity matrices
+ H : entropy
+
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric costfr matrix in the target space
+ p : np.ndarray(ns,)
+ distribution in the source space
+ q : np.ndarray(nt)
+ distribution in the target space
+ loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ forcing : np.ndarray(N,2)
+ list of forced couplings (where N is the number of forcing)
+
+ Returns
+ -------
+ T : coupling between the two spaces that minimizes :
+ \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+
+ """
+
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+
+ T = np.outer(p, q) # Initialization
+
+ cpt = 0
+ err = 1
+
+ while (err > stopThr and cpt < numItermax):
+
+ Tprev = T
+
+ if loss_fun == 'square_loss':
+ tens = tensor_square_loss(C1, C2, T)
+
+ elif loss_fun == 'kl_loss':
+ tens = tensor_kl_loss(C1, C2, T)
+
+ T = sinkhorn(p, q, tens, epsilon)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all the 10th iterations
+ err = np.linalg.norm(T - Tprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt = cpt + 1
+
+ if log:
+ return T, log
+ else:
+ return T
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+ """
+ Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
+
+ (C1,p) and (C2,q)
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \GW_Dist = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+
+
+ Where :
+
+ C1 : Metric cost matrix in the source space
+ C2 : Metric cost matrix in the target space
+ p : distribution in the source space
+ q : distribution in the target space
+ L : loss function to account for the misfit between the similarity matrices
+ H : entropy
+
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric costfr matrix in the target space
+ p : np.ndarray(ns,)
+ distribution in the source space
+ q : np.ndarray(nt)
+ distribution in the target space
+ loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ forcing : np.ndarray(N,2)
+ list of forced couplings (where N is the number of forcing)
+
+ Returns
+ -------
+ T : coupling between the two spaces that minimizes :
+ \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+
+ """
+
+ if log:
+ gw, logv = gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
+ else:
+ gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
+ epsilon, numItermax, stopThr, verbose, log)
+
+ if loss_fun == 'square_loss':
+ gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
+
+ elif loss_fun == 'kl_loss':
+ gw_dist = np.sum(gw * tensor_kl_loss(C1, C2, gw))
+
+ if log:
+ return gw_dist, logv
+ else:
+ return gw_dist
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+ """
+ Returns the gromov-wasserstein barycenters of S measured similarity matrices
+
+ (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem:
+
+ .. math::
+ C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+
+
+ Where :
+
+ Cs : metric cost matrix
+ ps : distribution
+
+ Parameters
+ ----------
+ N : Integer
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray(ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray(ns,)
+ sample weights in the S spaces
+ p : np.ndarray(N,)
+ weights in the targeted barycenter
+ lambdas : list of the S spaces' weights
+ L : tensor-matrix multiplication function based on specific loss function
+ update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
+ with the S Ts couplings calculated at each iteration
+ epsilon : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ C : Similarity matrix in the barycenter space (permutated arbitrarily)
+
+ """
+
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while(err > stopThr and cpt < numItermax):
+
+ Cprev = C
+
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
+ numItermax, 1e-5, verbose, log) for s in range(S)]
+
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all the 10th iterations
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt = cpt + 1
+
+ return C
diff --git a/test/test_gromov.py b/test/test_gromov.py
new file mode 100644
index 0000000..75eeaab
--- /dev/null
+++ b/test/test_gromov.py
@@ -0,0 +1,38 @@
+"""Tests for module gromov """
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import ot
+
+
+def test_gromov():
+ n = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+
+ xt = [xs[n - (i + 1)] for i in range(n)]
+ xt = np.array(xt)
+
+ p = ot.unif(n)
+ q = ot.unif(n)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov