diff options
Diffstat (limited to 'ot/gromov.py')
-rw-r--r-- | ot/gromov.py | 43 |
1 files changed, 5 insertions, 38 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index 197e3ea..9dbf463 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -208,11 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs): return(np.exp(np.divide(tmpsum, ppt)))
-<<<<<<< HEAD def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
-======= -def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """
Returns the gromov-wasserstein coupling between the two measured similarity matrices
@@ -252,11 +248,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
-<<<<<<< HEAD +<<<<<<< HEAD
max_iter : int, optional
-======= +=======
numItermax : int, optional
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d +>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
@@ -282,11 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr cpt = 0
err = 1
-<<<<<<< HEAD while (err > stopThr and cpt < max_iter):
-======= - while (err > stopThr and cpt < numItermax):
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d Tprev = T
@@ -319,11 +311,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr return T
-<<<<<<< HEAD def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
-======= -def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
@@ -358,7 +346,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
- numItermax : int, optional
+ max_iter : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
@@ -378,17 +366,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh if log:
gw, logv = gromov_wasserstein(
-<<<<<<< HEAD C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
else:
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
epsilon, max_iter, stopThr, verbose, log)
-======= - C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
- else:
- gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
- epsilon, numItermax, stopThr, verbose, log)
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d if loss_fun == 'square_loss':
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -402,11 +383,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh return gw_dist
-<<<<<<< HEAD def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
-======= -def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -439,7 +416,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000 with the S Ts couplings calculated at each iteration
epsilon : float
Regularization term >0
- numItermax : int, optional
+ max_iter : int, optional
Max number of iterations
stopThr : float, optional
Stop threshol on error (>0)
@@ -469,21 +446,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000 error = []
-<<<<<<< HEAD while(err > stopThr and cpt < max_iter):
-======= - while(err > stopThr and cpt < numItermax):
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d -
Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
-<<<<<<< HEAD max_iter, 1e-5, verbose, log) for s in range(S)]
-======= - numItermax, 1e-5, verbose, log) for s in range(S)]
->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d -
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
|