From ad450b0a5bb63ee9731e88d4a8e7423b16f1abd8 Mon Sep 17 00:00:00 2001 From: tvayer Date: Tue, 4 Jun 2019 10:32:30 +0200 Subject: changes forgotten coments --- ot/gromov.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 53349b7..ca96b31 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -17,7 +17,7 @@ import numpy as np from .bregman import sinkhorn -from .utils import dist +from .utils import dist, UndefinedParameter from .optim import cg @@ -1011,9 +1011,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ International Conference on Machine Learning (ICML). 2019. """ - class UndefinedParameter(Exception): - pass - S = len(Cs) d = Ys[0].shape[1] # dimension on the node features if p is None: @@ -1049,10 +1046,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T = [np.outer(p, q) for q in ps] - # X is N,d - # Ys is ns,d - Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] - # Ms is N,ns + Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns cpt = 0 err_feature = 1 @@ -1072,27 +1066,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Ys_temp = [y.T for y in Ys] X = update_feature_matrix(lambdas, Ys_temp, T, p).T - # X must be N,d - # Ys must be ns,d Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] if not fixed_structure: if loss_fun == 'square_loss': - # T must be ns,N - # Cs must be ns,ns - # p must be N,1 T_temp = [t.T for t in T] C = update_sructure_matrix(p, lambdas, T_temp, Cs) - # Ys must be d,ns - # Ts must be N,ns - # p must be N,1 - # Ms is N,ns - # C is N,N - # Cs is ns,ns - # p is N,1 - # ps is ns,1 - T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns @@ -1115,7 +1095,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ if log: log_['T'] = T # from target to Ys log_['p'] = p - log_['Ms'] = Ms # Ms are N,ns + log_['Ms'] = Ms if log: return X, C, log_ -- cgit v1.2.3