From 6484c9ea301fc15ae53b4afe134941909f581ffe Mon Sep 17 00:00:00 2001 From: tvayer Date: Wed, 29 May 2019 14:11:48 +0200 Subject: Tests + contributions --- ot/gromov.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index ad68a1c..297b194 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -926,6 +926,10 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ + + class UndefinedParameter(Exception): + pass + S = len(Cs) d = Ys[0].shape[1] #dimension on the node features if p is None: @@ -938,7 +942,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if fixed_structure: if init_C is None: - C=Cs[0] + raise UndefinedParameter('If C is fixed it must be initialized') else: C=init_C else: @@ -950,7 +954,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature if fixed_features: if init_X is None: - X=Ys[0] + raise UndefinedParameter('If X is fixed it must be initialized') else : X= init_X else: @@ -1004,13 +1008,13 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature # Cs is ns,ns # p is N,1 # ps is ns,1 - + T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns log_['Ts_iter'].append(T) - err_feature = np.linalg.norm(X - Xprev.reshape(d,N)) + err_feature = np.linalg.norm(X - Xprev.reshape(N,d)) err_structure = np.linalg.norm(C - Cprev) if log: -- cgit v1.2.3