summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:20:34 +0200
committerNicolas Courty <Nico@MacBook-Pro-de-Nicolas.local>2017-09-01 11:20:34 +0200
commitab6ed1df93cd78bb7f1a54282103d4d830e68bcb (patch)
treebbcff976e21c2e89a4656c542506cd0f728309bb /ot/gromov.py
parent4ec5b339ef527d4d49a022ddf57b38dff037548c (diff)
docstrings and naming
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index 421ed3f..ad85fcd 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -208,7 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return(np.exp(np.divide(tmpsum, ppt)))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein coupling between the two measured similarity matrices
@@ -248,7 +248,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
- numItermax : int, optional
+ max_iter : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
@@ -274,7 +274,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
cpt = 0
err = 1
- while (err > stopThr and cpt < numItermax):
+ while (err > stopThr and cpt < max_iter):
Tprev = T
@@ -307,7 +307,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
return T
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
@@ -362,10 +362,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
if log:
gw, logv = gromov_wasserstein(
- C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
+ C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
else:
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
- epsilon, numItermax, stopThr, verbose, log)
+ epsilon, max_iter, stopThr, verbose, log)
if loss_fun == 'square_loss':
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -379,7 +379,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
return gw_dist
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -442,12 +442,12 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000
error = []
- while(err > stopThr and cpt < numItermax):
+ while(err > stopThr and cpt < max_iter):
Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- numItermax, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-5, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)