summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/plot_barycenter_1D.py8
-rw-r--r--examples/plot_gromov.py25
-rw-r--r--ot/__init__.py4
-rw-r--r--ot/da.py6
-rw-r--r--ot/gpu/__init__.py2
-rw-r--r--ot/gpu/da.py3
-rw-r--r--ot/gromov.py235
-rw-r--r--ot/utils.py2
8 files changed, 152 insertions, 133 deletions
diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py
index 620936b..ecf640c 100644
--- a/examples/plot_barycenter_1D.py
+++ b/examples/plot_barycenter_1D.py
@@ -25,7 +25,7 @@ import ot
from mpl_toolkits.mplot3d import Axes3D # noqa
from matplotlib.collections import PolyCollection
-##############################################################################
+#
# Generate data
# -------------
@@ -48,7 +48,7 @@ n_distributions = A.shape[1]
M = ot.utils.dist0(n)
M /= M.max()
-##############################################################################
+#
# Plot data
# ---------
@@ -60,7 +60,7 @@ for i in range(n_distributions):
pl.title('Distributions')
pl.tight_layout()
-##############################################################################
+#
# Barycenter computation
# ----------------------
@@ -90,7 +90,7 @@ pl.legend()
pl.title('Barycenters')
pl.tight_layout()
-##############################################################################
+#
# Barycentric interpolation
# -------------------------
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index 5f2d826..9188da9 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -20,7 +20,7 @@ from mpl_toolkits.mplot3d import Axes3D # noqa
import ot
-##############################################################################
+#
# Sample two Gaussian distributions (2D and 3D)
# ---------------------------------------------
#
@@ -43,7 +43,7 @@ P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
-##############################################################################
+#
# Plotting the distributions
# --------------------------
@@ -56,7 +56,7 @@ ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()
-##############################################################################
+#
# Compute distance kernels, normalize them and then display
# ---------------------------------------------------------
@@ -74,33 +74,32 @@ pl.subplot(122)
pl.imshow(C2)
pl.show()
-##############################################################################
+#
# Compute Gromov-Wasserstein plans and distance
# ---------------------------------------------
-#%%
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-gw0,log0 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True,log=True)
+gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
-gw,log= ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4,log=True,verbose=True)
+gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
-pl.figure(1,(10,5))
+pl.figure(1, (10, 5))
-pl.subplot(1,2,1)
+pl.subplot(1, 2, 1)
pl.imshow(gw0, cmap='jet')
-pl.colorbar()
pl.title('Gromov Wasserstein')
-pl.subplot(1,2,2)
-pl.imshow(gw0, cmap='jet')
-pl.colorbar()
+pl.subplot(1, 2, 2)
+pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()
diff --git a/ot/__init__.py b/ot/__init__.py
index cee7379..1500e59 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -16,7 +16,6 @@ from . import bregman
from . import optim
from . import utils
from . import datasets
-from . import plot
from . import da
from . import gromov
@@ -24,7 +23,6 @@ from . import gromov
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
-from .gromov import gromov_wasserstein, gromov_wasserstein2
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -32,5 +30,5 @@ from .utils import dist, unif, tic, toc, toq
__version__ = "0.4.0"
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
- 'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
+ 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
diff --git a/ot/da.py b/ot/da.py
index 1d3d0ba..ee73ec8 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -933,6 +933,7 @@ def distribution_estimation_uniform(X):
class BaseTransport(BaseEstimator):
+
"""Base class for OTDA objects
Notes
@@ -1180,6 +1181,7 @@ class BaseTransport(BaseEstimator):
class SinkhornTransport(BaseTransport):
+
"""Domain Adapatation OT method based on Sinkhorn Algorithm
Parameters
@@ -1289,6 +1291,7 @@ class SinkhornTransport(BaseTransport):
class EMDTransport(BaseTransport):
+
"""Domain Adapatation OT method based on Earth Mover's Distance
Parameters
@@ -1377,6 +1380,7 @@ class EMDTransport(BaseTransport):
class SinkhornLpl1Transport(BaseTransport):
+
"""Domain Adapatation OT method based on sinkhorn algorithm +
LpL1 class regularization.
@@ -1486,6 +1490,7 @@ class SinkhornLpl1Transport(BaseTransport):
class SinkhornL1l2Transport(BaseTransport):
+
"""Domain Adapatation OT method based on sinkhorn algorithm +
l1l2 class regularization.
@@ -1608,6 +1613,7 @@ class SinkhornL1l2Transport(BaseTransport):
class MappingTransport(BaseEstimator):
+
"""MappingTransport: DA methods that aims at jointly estimating a optimal
transport coupling and the associated mapping
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index c8f9433..a2fdd3d 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -5,7 +5,7 @@ from . import da
from .bregman import sinkhorn
# Author: Remi Flamary <remi.flamary@unice.fr>
-# Leo Gautheron <https://github.com/aje>
+# Leo Gautheron <https://github.com/aje>
#
# License: MIT License
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 05c580f..71a485a 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -188,6 +188,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
class OTDA_GPU(OTDA):
+
def normalizeM(self, norm):
if norm == "median":
self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
@@ -204,6 +205,7 @@ class OTDA_GPU(OTDA):
class OTDA_sinkhorn(OTDA_GPU):
+
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
cudamat.init()
xs = np.asarray(xs, dtype=np.float64)
@@ -228,6 +230,7 @@ class OTDA_sinkhorn(OTDA_GPU):
class OTDA_lpl1(OTDA_GPU):
+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
**kwargs):
cudamat.init()
diff --git a/ot/gromov.py b/ot/gromov.py
index e4dd112..b1e9ee0 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -20,12 +20,12 @@ from .utils import dist
from .optim import cg
-def init_matrix(C1,C2,T,p,q,loss_fun='square_loss'):
+def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
""" Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
function as the loss function of Gromow-Wasserstein discrepancy.
-
+
The matrices are computed as described in Proposition 1 in [12]
Where :
@@ -56,18 +56,18 @@ def init_matrix(C1,C2,T,p,q,loss_fun='square_loss'):
T : ndarray, shape (ns, nt)
Coupling between source and target spaces
p : ndarray, shape (ns,)
-
+
Returns
-------
-
+
constC : ndarray, shape (ns, nt)
Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
h2(C) matrix in Eq. (6)
-
+
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
@@ -76,39 +76,45 @@ def init_matrix(C1,C2,T,p,q,loss_fun='square_loss'):
"""
-
if loss_fun == 'square_loss':
def f1(a):
- return (a**2)/2
+ return (a**2) / 2
+
def f2(b):
- return (b**2)/2
+ return (b**2) / 2
+
def h1(a):
- return a
+ return a
+
def h2(b):
return b
elif loss_fun == 'kl_loss':
def f1(a):
- return a * np.log(a + 1e-15) - a
+ return a * np.log(a + 1e-15) - a
+
def f2(b):
- return b
+ return b
+
def h1(a):
- return a
+ return a
+
def h2(b):
return np.log(b + 1e-15)
- constC1=np.dot(np.dot(f1(C1),p.reshape(-1,1)),
- np.ones(len(q)).reshape(1,-1))
- constC2=np.dot(np.ones(len(p)).reshape(-1,1),
- np.dot(q.reshape(1,-1),f2(C2).T))
- constC=constC1+constC2
+ constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)),
+ np.ones(len(q)).reshape(1, -1))
+ constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
+ np.dot(q.reshape(1, -1), f2(C2).T))
+ constC = constC1 + constC2
hC1 = h1(C1)
hC2 = h2(C2)
- return constC,hC1,hC2
+ return constC, hC1, hC2
+
+
+def tensor_product(constC, hC1, hC2, T):
+ """ Return the tensor for Gromov-Wasserstein fast computation
-def tensor_product(constC,hC1,hC2,T):
- """ Return the tensor for Gromov-Wasserstein fast computation
-
The tensor is computed as described in Proposition 1 Eq. (6) in [12].
Parameters
@@ -116,14 +122,14 @@ def tensor_product(constC,hC1,hC2,T):
constC : ndarray, shape (ns, nt)
Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
h2(C) matrix in Eq. (6)
-
+
Returns
-------
-
+
tens : ndarray, shape (ns, nt)
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
@@ -133,15 +139,16 @@ def tensor_product(constC,hC1,hC2,T):
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- """
- A=-np.dot(hC1, T).dot(hC2.T)
- tens = constC+A
- #tens -= tens.min()
+ """
+ A = -np.dot(hC1, T).dot(hC2.T)
+ tens = constC + A
+ # tens -= tens.min()
return tens
-def gwloss(constC,hC1,hC2,T):
+
+def gwloss(constC, hC1, hC2, T):
""" Return the Loss for Gromov-Wasserstein
-
+
The loss is computed as described in Proposition 1 Eq. (6) in [12].
Parameters
@@ -149,15 +156,15 @@ def gwloss(constC,hC1,hC2,T):
constC : ndarray, shape (ns, nt)
Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
h2(C) matrix in Eq. (6)
T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ Current value of transport matrix T
Returns
-------
-
+
loss : float
Gromov Wasserstein loss
@@ -166,16 +173,17 @@ def gwloss(constC,hC1,hC2,T):
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
-
+
"""
- tens=tensor_product(constC,hC1,hC2,T)
-
- return np.sum(tens*T)
+ tens = tensor_product(constC, hC1, hC2, T)
+
+ return np.sum(tens * T)
+
+
+def gwggrad(constC, hC1, hC2, T):
+ """ Return the gradient for Gromov-Wasserstein
-def gwggrad(constC,hC1,hC2,T):
- """ Return the gradient for Gromov-Wasserstein
-
The gradient is computed as described in Proposition 2 in [12].
Parameters
@@ -183,15 +191,15 @@ def gwggrad(constC,hC1,hC2,T):
constC : ndarray, shape (ns, nt)
Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
h2(C) matrix in Eq. (6)
T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ Current value of transport matrix T
Returns
-------
-
+
grad : ndarray, shape (ns, nt)
Gromov Wasserstein gradient
@@ -200,10 +208,10 @@ def gwggrad(constC,hC1,hC2,T):
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
-
- """
- return 2*tensor_product(constC,hC1,hC2,T) # [12] Prop. 2 misses a 2 factor
+ """
+ return 2 * tensor_product(constC, hC1, hC2,
+ T) # [12] Prop. 2 misses a 2 factor
def update_square_loss(p, lambdas, T, Cs):
@@ -261,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1,C2,p,q,loss_fun,log=False,**kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -307,38 +315,40 @@ def gromov_wasserstein(C1,C2,p,q,loss_fun,log=False,**kwargs):
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
log : dict
convergence information and loss
-
+
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
- metric approach to object matching. Foundations of computational
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.
-
+
"""
T = np.eye(len(p), len(q))
- constC,hC1,hC2=init_matrix(C1,C2,T,p,q,loss_fun)
-
- G0=p[:,None]*q[None,:]
-
+ constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
def f(G):
- return gwloss(constC,hC1,hC2,G)
+ return gwloss(constC, hC1, hC2, G)
+
def df(G):
- return gwggrad(constC,hC1,hC2,G)
-
+ return gwggrad(constC, hC1, hC2, G)
+
if log:
- res,log=cg(p,q,0,1,f,df,G0,log=True,**kwargs)
- log['gw_dist']=gwloss(constC,hC1,hC2,res)
- return res,log
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ log['gw_dist'] = gwloss(constC, hC1, hC2, res)
+ return res, log
else:
- return cg(p,q,0,1,f,df,G0,**kwargs)
+ return cg(p, q, 0, 1, f, df, G0, **kwargs)
+
-def gromov_wasserstein2(C1,C2,p,q,loss_fun,log=False,**kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -383,40 +393,41 @@ def gromov_wasserstein2(C1,C2,p,q,loss_fun,log=False,**kwargs):
Gromov-Wasserstein distance
log : dict
convergence information and Coupling marix
-
+
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
- metric approach to object matching. Foundations of computational
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.
-
+
"""
T = np.eye(len(p), len(q))
- constC,hC1,hC2=init_matrix(C1,C2,T,p,q,loss_fun)
-
- G0=p[:,None]*q[None,:]
-
+ constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
def f(G):
- return gwloss(constC,hC1,hC2,G)
+ return gwloss(constC, hC1, hC2, G)
+
def df(G):
- return gwggrad(constC,hC1,hC2,G)
- res,log=cg(p,q,0,1,f,df,G0,log=True,**kwargs)
- log['gw_dist']=gwloss(constC,hC1,hC2,res)
- log['T']=res
+ return gwggrad(constC, hC1, hC2, G)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ log['gw_dist'] = gwloss(constC, hC1, hC2, res)
+ log['T'] = res
if log:
- return log['gw_dist'],log
+ return log['gw_dist'], log
else:
return log['gw_dist']
def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -469,34 +480,34 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
T : ndarray, shape (ns, nt)
coupling between the two spaces that minimizes :
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
-
+
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
+ International Conference on Machine Learning (ICML). 2016.
+
"""
C1 = np.asarray(C1, dtype=np.float64)
C2 = np.asarray(C2, dtype=np.float64)
T = np.outer(p, q) # Initialization
-
- constC,hC1,hC2=init_matrix(C1,C2,T,p,q,loss_fun)
+
+ constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
cpt = 0
err = 1
-
+
if log:
- log={'err':[]}
+ log = {'err': []}
while (err > tol and cpt < max_iter):
Tprev = T
# compute the gradient
- tens=gwggrad(constC,hC1,hC2,T)
+ tens = gwggrad(constC, hC1, hC2, T)
T = sinkhorn(p, q, tens, epsilon)
@@ -517,13 +528,14 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
cpt += 1
if log:
- log['gw_dist']=gwloss(constC,hC1,hC2,T)
+ log['gw_dist'] = gwloss(constC, hC1, hC2, T)
return T, log
else:
return T
+
def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices
@@ -569,20 +581,19 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
-------
gw_dist : float
Gromov-Wasserstein distance
-
+
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
+ International Conference on Machine Learning (ICML). 2016.
+ """
gw, logv = entropic_gromov_wasserstein(
- C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
-
- log['T']=gw
+ C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
+
+ log['T'] = gw
if log:
return logv['gw_dist'], logv
@@ -591,7 +602,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -640,13 +651,13 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
-------
C : ndarray, shape (N, N)
Similarity matrix in the barycenter space (permutated arbitrarily)
-
+
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
+ International Conference on Machine Learning (ICML). 2016.
+
"""
S = len(Cs)
@@ -671,7 +682,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-5, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -698,14 +709,14 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
return C
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
(Cs)_{s=1}^{s=S}
- The function solves the following optimization problem with block
+ The function solves the following optimization problem with block
coordinate descent:
.. math::
@@ -747,13 +758,13 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
-------
C : ndarray, shape (N, N)
Similarity matrix in the barycenter space (permutated arbitrarily)
-
+
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
+ International Conference on Machine Learning (ICML). 2016.
+
"""
S = len(Cs)
@@ -777,7 +788,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
while(err > tol and cpt < max_iter):
Cprev = C
- T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -802,4 +813,4 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
cpt += 1
- return C \ No newline at end of file
+ return C
diff --git a/ot/utils.py b/ot/utils.py
index 31a002b..9eab3fc 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -223,6 +223,7 @@ def check_params(**kwargs):
class deprecated(object):
+
"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
@@ -320,6 +321,7 @@ def _is_deprecated(func):
class BaseEstimator(object):
+
"""Base class for most objects in POT
adapted from sklearn BaseEstimator class