summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 14:11:48 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 14:11:48 +0200
commit6484c9ea301fc15ae53b4afe134941909f581ffe (patch)
tree90a5e1487524696139722990e2f2d737d3206aef /ot/gromov.py
parent11c2c26ff897e5763e714546e7021cffa8d673a7 (diff)
Tests + contributions
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index ad68a1c..297b194 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -926,6 +926,10 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
+
+ class UndefinedParameter(Exception):
+ pass
+
S = len(Cs)
d = Ys[0].shape[1] #dimension on the node features
if p is None:
@@ -938,7 +942,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if fixed_structure:
if init_C is None:
- C=Cs[0]
+ raise UndefinedParameter('If C is fixed it must be initialized')
else:
C=init_C
else:
@@ -950,7 +954,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if fixed_features:
if init_X is None:
- X=Ys[0]
+ raise UndefinedParameter('If X is fixed it must be initialized')
else :
X= init_X
else:
@@ -1004,13 +1008,13 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
# Cs is ns,ns
# p is N,1
# ps is ns,1
-
+
T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
log_['Ts_iter'].append(T)
- err_feature = np.linalg.norm(X - Xprev.reshape(d,N))
+ err_feature = np.linalg.norm(X - Xprev.reshape(N,d))
err_structure = np.linalg.norm(C - Cprev)
if log: