summaryrefslogtreecommitdiff
path: root/ot/gromov.py
diff options
context:
space:
mode:
authorNemo Fournier <nemo.fournier@ens-lyon.org>2020-03-10 15:57:41 +0100
committerNemo Fournier <nemo.fournier@ens-lyon.org>2020-03-10 15:57:41 +0100
commit18fa98fb109c935dc8d87f9c93318d8cfd118738 (patch)
tree70b8ef65c4bac1c1c0eb2fd9ba0a99cdb3d5a0d8 /ot/gromov.py
parent20f9abd8633f4a905df97cc5478eae2e53c1aa96 (diff)
fixing trailing and before arithmetic operation whitespace issues
Diffstat (limited to 'ot/gromov.py')
-rw-r--r--ot/gromov.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index e329c70..43780a4 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -488,11 +488,11 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
log['fgw_dist'] = log['loss'][::-1][0]
return res, log
else:
- return cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -563,7 +563,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, (1-alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
if log:
log['fgw_dist'] = log['loss'][::-1][0]
log['T'] = res
@@ -987,13 +987,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
loss_fun : str
Loss function used for the solver either 'square_loss' or 'kl_loss'
max_iter : int, optional
- Max number of iterations
+ Max number of iterations
tol : float, optional
Stop threshol on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
- Record log if True.
+ Record log if True.
init_C : ndarray, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.