summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-16 11:58:59 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-16 11:58:59 +0100
commit4a585de94109102c89bcd7ad43e35772e1027cd2 (patch)
tree3764b88f2a65404fcbc072ad3c6565cb921e89c0 /ot/gromov.py
parentab2fc5486005732d60714b76eecf59d1104346f6 (diff)
update examples
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py109
1 files changed, 108 insertions, 1 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 8d08397..e4dd112 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -590,7 +590,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
return logv['gw_dist']
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
+def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -696,3 +696,110 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
cpt += 1
return C
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ """
+ Returns the gromov-wasserstein barycenters of S measured similarity matrices
+
+ (Cs)_{s=1}^{s=S}
+
+ The function solves the following optimization problem with block
+ coordinate descent:
+
+ .. math::
+ C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+
+
+ Where :
+
+ Cs : metric cost matrix
+ ps : distribution
+
+ Parameters
+ ----------
+ N : Integer
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray(ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray(ns,)
+ sample weights in the S spaces
+ p : ndarray, shape(N,)
+ weights in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ loss_fun : tensor-matrix multiplication function based on specific loss function
+ update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
+ with the S Ts couplings calculated at each iteration
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ init_C : bool, ndarray, shape(N,N)
+ random initial value for the C matrix provided by user
+
+ Returns
+ -------
+ C : ndarray, shape (N, N)
+ Similarity matrix in the barycenter space (permutated arbitrarily)
+
+ References
+ ----------
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ S = len(Cs)
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while(err > tol and cpt < max_iter):
+ Cprev = C
+
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = np.linalg.norm(C - Cprev)
+ error.append(err)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ return C \ No newline at end of file