summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/plot_gromov.py29
-rwxr-xr-xexamples/plot_gromov_barycenter.py8
-rw-r--r--ot/__init__.py3
-rw-r--r--ot/gromov.py109
4 files changed, 129 insertions, 20 deletions
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index d42c21a..5f2d826 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -81,23 +81,26 @@ pl.show()
#%%
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-ot.tic()
-gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4,verbose=True)
-ot.toc()
-ot.tic()
-gw2,log2= ot.gromov.gromov_wasserstein0(C1, C2, p, q, 'square_loss', epsilon=5e-4,log=True,verbose=True)
-ot.toc()
-gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw0,log0 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True,log=True)
-ot.tic()
-gw0,log0=ot.gromov.gw_lp(C1, C2, p, q, 'square_loss',log=True,verbose=True)
-ot.toc()
+gw,log= ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4,log=True,verbose=True)
-print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
+print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
+print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
-pl.figure()
-pl.imshow(gw2, cmap='jet')
+
+pl.figure(1,(10,5))
+
+pl.subplot(1,2,1)
+pl.imshow(gw0, cmap='jet')
pl.colorbar()
+pl.title('Gromov Wasserstein')
+
+pl.subplot(1,2,2)
+pl.imshow(gw0, cmap='jet')
+pl.colorbar()
+pl.title('Entropic Gromov Wasserstein')
+
pl.show()
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
index 180b0cf..fde822b 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/plot_gromov_barycenter.py
@@ -132,28 +132,28 @@ Ct01 = [0 for i in range(2)]
for i in range(2):
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
[ps[0], ps[1]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss', #5e-4,
max_iter=100, tol=1e-3)
Ct02 = [0 for i in range(2)]
for i in range(2):
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
[ps[0], ps[2]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss',# 5e-4,
max_iter=100, tol=1e-3)
Ct13 = [0 for i in range(2)]
for i in range(2):
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
[ps[1], ps[3]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss',# 5e-4,
max_iter=100, tol=1e-3)
Ct23 = [0 for i in range(2)]
for i in range(2):
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
[ps[2], ps[3]
- ], p, lambdast[i], 'square_loss', 5e-4,
+ ], p, lambdast[i], 'square_loss', #5e-4,
max_iter=100, tol=1e-3)
diff --git a/ot/__init__.py b/ot/__init__.py
index a5df43d..cee7379 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -33,5 +33,4 @@ __version__ = "0.4.0"
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
- 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
- 'gromov_wasserstein','gromov_wasserstein2']
+ 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
diff --git a/ot/gromov.py b/ot/gromov.py
index 8d08397..e4dd112 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -590,7 +590,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
return logv['gw_dist']
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
+def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -696,3 +696,110 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
cpt += 1
return C
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ """
+ Returns the gromov-wasserstein barycenters of S measured similarity matrices
+
+ (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem with block
+ coordinate descent:
+
+ .. math::
+ C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+
+
+ Where :
+
+ Cs : metric cost matrix
+ ps : distribution
+
+ Parameters
+ ----------
+ N : Integer
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray(ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray(ns,)
+ sample weights in the S spaces
+ p : ndarray, shape(N,)
+ weights in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ loss_fun : tensor-matrix multiplication function based on specific loss function
+ update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
+ with the S Ts couplings calculated at each iteration
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ init_C : bool, ndarray, shape(N,N)
+ random initial value for the C matrix provided by user
+
+ Returns
+ -------
+ C : ndarray, shape (N, N)
+ Similarity matrix in the barycenter space (permutated arbitrarily)
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while(err > tol and cpt < max_iter):
+ Cprev = C
+
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ return C \ No newline at end of file