From 84c272394d41d159d07174306b324590b3ffe40c Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Wed, 13 Sep 2017 01:03:21 +0200 Subject: Corrections on Gromov --- ot/gromov.py | 44 ++++++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 14 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 82e3fd3..7968e5e 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -122,7 +122,9 @@ def tensor_kl_loss(C1, C2, T): References ---------- - .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. """ @@ -157,7 +159,8 @@ def update_square_loss(p, lambdas, T, Cs): ---------- p : ndarray, shape (N,) weights in the targeted barycenter - lambdas : list of the S spaces' weights + lambdas : list of float + list of the S spaces' weights T : list of S np.ndarray(ns,N) the S Ts couplings calculated at each iteration Cs : list of S ndarray, shape(ns,ns) @@ -168,7 +171,8 @@ def update_square_loss(p, lambdas, T, Cs): C : ndarray, shape (nt,nt) updated C matrix """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))]) + tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) + for s in range(len(T))]) ppt = np.outer(p, p) return np.divide(tmpsum, ppt) @@ -194,13 +198,15 @@ def update_kl_loss(p, lambdas, T, Cs): C : ndarray, shape (ns,ns) updated C matrix """ - tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))]) + tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) + for s in range(len(T))]) ppt = np.outer(p, p) return np.exp(np.divide(tmpsum, ppt)) -def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): +def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, + max_iter=1000, tol=1e-9, verbose=False, log=False): """ Returns the gromov-wasserstein coupling between the two measured similarity matrices @@ -276,7 +282,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, T = sinkhorn(p, q, tens, epsilon) if cpt % 10 == 0: - # we can speed up the process by checking for the error only all the 10th iterations + # we can speed up the process by checking for the error only all + # the 10th iterations err = np.linalg.norm(T - Tprev) if log: @@ -296,7 +303,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, return T -def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): +def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, + max_iter=1000, tol=1e-9, verbose=False, log=False): """ Returns the gromov-wasserstein discrepancy between the two measured similarity matrices @@ -363,7 +371,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9 return gw_dist -def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): +def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, + max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): """ Returns the gromov-wasserstein barycenters of S measured similarity matrices @@ -390,7 +399,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, sample weights in the S spaces p : ndarray, shape(N,) weights in the targeted barycenter - lambdas : list of the S spaces' weights + lambdas : list of float + list of the S spaces' weights L : tensor-matrix multiplication function based on specific loss function update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel with the S Ts couplings calculated at each iteration @@ -404,6 +414,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, Print information along iterations log : bool, optional record log if True + init_C : bool, ndarray, shape(N,N) + random initial value for the C matrix provided by user Returns ------- @@ -416,10 +428,13 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)] lambdas = np.asarray(lambdas, dtype=np.float64) - # Initialization of C : random SPD matrix - xalea = np.random.randn(N, 2) - C = dist(xalea, xalea) - C /= C.max() + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + xalea = np.random.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + else: + C = init_C cpt = 0 err = 1 @@ -438,7 +453,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, C = update_kl_loss(p, lambdas, T, Cs) if cpt % 10 == 0: - # we can speed up the process by checking for the error only all the 10th iterations + # we can speed up the process by checking for the error only all + # the 10th iterations err = np.linalg.norm(C - Cprev) error.append(err) -- cgit v1.2.3