diff options
author | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2017-09-01 11:22:13 +0200 |
---|---|---|
committer | Nicolas Courty <Nico@MacBook-Pro-de-Nicolas.local> | 2017-09-01 11:22:13 +0200 |
commit | 64a5d3c4e49688c13d236baf9ed23420070024d6 (patch) | |
tree | ffe5db073c07e579b26ead6a8ebcb0ff78ce6a33 /ot | |
parent | ab6ed1df93cd78bb7f1a54282103d4d830e68bcb (diff) | |
parent | 986f46ddde3ce2f550cb56f66620df377326423d (diff) |
docstrings and naming
Diffstat (limited to 'ot')
-rw-r--r-- | ot/gromov.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index ad85fcd..197e3ea 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -208,7 +208,11 @@ def update_kl_loss(p, lambdas, T, Cs): return(np.exp(np.divide(tmpsum, ppt)))
+<<<<<<< HEAD def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+======= +def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """
Returns the gromov-wasserstein coupling between the two measured similarity matrices
@@ -248,7 +252,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1 loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
+<<<<<<< HEAD max_iter : int, optional
+======= + numItermax : int, optional
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
@@ -274,7 +282,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1 cpt = 0
err = 1
+<<<<<<< HEAD while (err > stopThr and cpt < max_iter):
+======= + while (err > stopThr and cpt < numItermax):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d Tprev = T
@@ -307,7 +319,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1 return T
+<<<<<<< HEAD def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+======= +def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
@@ -362,10 +378,17 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr= if log:
gw, logv = gromov_wasserstein(
+<<<<<<< HEAD C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
else:
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
epsilon, max_iter, stopThr, verbose, log)
+======= + C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
+ else:
+ gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
+ epsilon, numItermax, stopThr, verbose, log)
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d if loss_fun == 'square_loss':
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -379,7 +402,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr= return gw_dist
+<<<<<<< HEAD def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+======= +def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -442,12 +469,20 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, error = []
+<<<<<<< HEAD while(err > stopThr and cpt < max_iter):
+======= + while(err > stopThr and cpt < numItermax):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
+<<<<<<< HEAD max_iter, 1e-5, verbose, log) for s in range(S)]
+======= + numItermax, 1e-5, verbose, log) for s in range(S)]
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
|