From 46fc12a298c49b715ac953cff391b18b54dab0f0 Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Fri, 1 Sep 2017 11:43:51 +0200 Subject: solving conflicts :/ --- ot/gromov.py | 43 +++++-------------------------------------- 1 file changed, 5 insertions(+), 38 deletions(-) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 197e3ea..9dbf463 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -208,11 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs): return(np.exp(np.divide(tmpsum, ppt))) -<<<<<<< HEAD def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False): -======= -def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False): ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """ Returns the gromov-wasserstein coupling between the two measured similarity matrices @@ -252,11 +248,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float Regularization term >0 -<<<<<<< HEAD +<<<<<<< HEAD max_iter : int, optional -======= +======= numItermax : int, optional ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d +>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d Max number of iterations stopThr : float, optional Stop threshold on error (>0) @@ -282,11 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr cpt = 0 err = 1 -<<<<<<< HEAD while (err > stopThr and cpt < max_iter): -======= - while (err > stopThr and cpt < numItermax): ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d Tprev = T @@ -319,11 +311,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr return T -<<<<<<< HEAD def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False): -======= -def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False): ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """ Returns the gromov-wasserstein discrepancy between the two measured similarity matrices @@ -358,7 +346,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss' epsilon : float Regularization term >0 - numItermax : int, optional + max_iter : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) @@ -378,17 +366,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh if log: gw, logv = gromov_wasserstein( -<<<<<<< HEAD C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log) else: gw = gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log) -======= - C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log) - else: - gw = gromov_wasserstein(C1, C2, p, q, loss_fun, - epsilon, numItermax, stopThr, verbose, log) ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d if loss_fun == 'square_loss': gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw)) @@ -402,11 +383,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh return gw_dist -<<<<<<< HEAD def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False): -======= -def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False): ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d """ Returns the gromov-wasserstein barycenters of S measured similarity matrices @@ -439,7 +416,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000 with the S Ts couplings calculated at each iteration epsilon : float Regularization term >0 - numItermax : int, optional + max_iter : int, optional Max number of iterations stopThr : float, optional Stop threshol on error (>0) @@ -469,21 +446,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000 error = [] -<<<<<<< HEAD while(err > stopThr and cpt < max_iter): -======= - while(err > stopThr and cpt < numItermax): ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d - Cprev = C T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, -<<<<<<< HEAD max_iter, 1e-5, verbose, log) for s in range(S)] -======= - numItermax, 1e-5, verbose, log) for s in range(S)] ->>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d - if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) -- cgit v1.2.3