summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-06-04 10:32:30 +0200
committertvayer <titouan.vayer@gmail.com>2019-06-04 10:32:35 +0200
commitad450b0a5bb63ee9731e88d4a8e7423b16f1abd8 (patch)
treecab0421292074e59cb4eeb2846e8cca5aa159d3a /ot/gromov.py
parent89a2e0aee4353a051d924de0457f8976c26fa5d7 (diff)
changes forgotten coments
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py26
1 files changed, 3 insertions, 23 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 53349b7..ca96b31 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -17,7 +17,7 @@ import numpy as np
from .bregman import sinkhorn
-from .utils import dist
+from .utils import dist, UndefinedParameter
from .optim import cg
@@ -1011,9 +1011,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
International Conference on Machine Learning (ICML). 2019.
"""
- class UndefinedParameter(Exception):
- pass
-
S = len(Cs)
d = Ys[0].shape[1] # dimension on the node features
if p is None:
@@ -1049,10 +1046,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T = [np.outer(p, q) for q in ps]
- # X is N,d
- # Ys is ns,d
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
- # Ms is N,ns
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
cpt = 0
err_feature = 1
@@ -1072,27 +1066,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
- # X must be N,d
- # Ys must be ns,d
Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
- # T must be ns,N
- # Cs must be ns,ns
- # p must be N,1
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- # Ys must be d,ns
- # Ts must be N,ns
- # p must be N,1
- # Ms is N,ns
- # C is N,N
- # Cs is ns,ns
- # p is N,1
- # ps is ns,1
-
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
@@ -1115,7 +1095,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
if log:
log_['T'] = T # from target to Ys
log_['p'] = p
- log_['Ms'] = Ms # Ms are N,ns
+ log_['Ms'] = Ms
if log:
return X, C, log_