summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-28 16:50:00 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-28 16:50:00 +0200
commitb1b514f5d9de009e63bd407dfd9c0a0cf6128876 (patch)
tree1a0b9e972d09af049d10bfbab30f10f3b657487f /ot/gromov.py
parent549b95b5736b42f3fe74daf9805303a08b1ae01d (diff)
bary fgw
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 7491664..31bd657 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -883,8 +883,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
return C
-def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,p=None,loss_fun='square_loss',
- max_iter=100, tol=1e-9,verbose=False,log=True,init_C=None,init_X=None):
+def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,
+ p=None,loss_fun='square_loss',max_iter=100, tol=1e-9,
+ verbose=False,log=True,init_C=None,init_X=None):
"""
Compute the fgw barycenter as presented eq (5) in [3].
@@ -957,7 +958,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
X=np.zeros((N,d))
else:
X = init_X
-
+
T=[np.outer(p,q) for q in ps]
# X is N,d
@@ -981,7 +982,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if not fixed_features:
Ys_temp=[y.T for y in Ys]
- X=update_feature_matrix(lambdas,Ys_temp,T,p)
+ X=update_feature_matrix(lambdas,Ys_temp,T,p).T
# X must be N,d
# Ys must be ns,d
@@ -1024,11 +1025,11 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
- log_['T']=T # ce sont les matrices du barycentre de la target vers les Ys
+ log_['T']=T # from target to Ys
log_['p']=p
- log_['Ms']=Ms #Ms sont de tailles N,ns
+ log_['Ms']=Ms #Ms are N,ns
- return X.T,C,log_
+ return X,C,log_
def update_sructure_matrix(p, lambdas, T, Cs):