From 4a585de94109102c89bcd7ad43e35772e1027cd2 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Fri, 16 Feb 2018 11:58:59 +0100
Subject: update examples

---
 ot/gromov.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 108 insertions(+), 1 deletion(-)

(limited to 'ot/gromov.py')

diff --git a/ot/gromov.py b/ot/gromov.py
index 8d08397..e4dd112 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -590,7 +590,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
     return logv['gw_dist']
 
 
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
+def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
                        max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
     """
     Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -696,3 +696,110 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
         cpt += 1
 
     return C
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
+                       max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+    """
+    Returns the gromov-wasserstein barycenters of S measured similarity matrices
+
+    (Cs)_{s=1}^{s=S}
+
+    The function solves the following optimization problem with block
+    coordinate descent:
+
+    .. math::
+        C = \mathop{\arg \min}_{C \in \mathbb{R}^{N \times N}} \sum_s \lambda_s \mathrm{GW}(C, C_s, p, p_s)
+
+    Where :
+
+        Cs : metric cost matrix
+        ps : distribution
+
+    Parameters
+    ----------
+    N : int
+        Size of the targeted barycenter
+    Cs : list of S np.ndarray(ns,ns)
+        Metric cost matrices
+    ps : list of S np.ndarray(ns,)
+        Sample weights in the S spaces
+    p : ndarray, shape (N,)
+        Weights in the targeted barycenter
+    lambdas : list of float
+        List of the S spaces' weights
+    loss_fun : str
+        Loss function used for the distances, either 'square_loss' or 'kl_loss'
+    max_iter : int, optional
+        Max number of iterations
+    tol : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        Record log if True
+    init_C : ndarray, shape (N,N), optional
+        Initial value for the C matrix (a random SPD matrix is used if not provided)
+
+    Returns
+    -------
+    C : ndarray, shape (N, N)
+        Similarity matrix in the barycenter space (permuted arbitrarily)
+
+    References
+    ----------
+    .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+ + """ + + S = len(Cs) + + Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] + lambdas = np.asarray(lambdas, dtype=np.float64) + + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + xalea = np.random.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + else: + C = init_C + + cpt = 0 + err = 1 + + error = [] + + while(err > tol and cpt < max_iter): + Cprev = C + + T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, + numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)] + if loss_fun == 'square_loss': + C = update_square_loss(p, lambdas, T, Cs) + + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T, Cs) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = np.linalg.norm(C - Cprev) + error.append(err) + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + return C \ No newline at end of file -- cgit v1.2.3