-rw-r--r--   examples/plot_gromov.py              |  4
-rwxr-xr-x   examples/plot_gromov_barycenter.py   | 18
-rw-r--r--   ot/gromov.py                         | 44
3 files changed, 42 insertions(+), 24 deletions(-)
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index 0f839a3..dce66c4 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -22,8 +22,8 @@ import ot
"""
Sample two Gaussian distributions (2D and 3D)
=============================================
-The Gromov-Wasserstein distance allows to compute distances with samples that
-do not belong to the same metric space. For demonstration purpose, we sample
+The Gromov-Wasserstein distance allows computing distances with samples that
+do not belong to the same metric space. For demonstration purposes, we sample
two Gaussian distributions in 2- and 3-dimensional spaces.
"""
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
index c138031..52f4966 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/plot_gromov_barycenter.py
@@ -3,7 +3,7 @@
=====================================
Gromov-Wasserstein Barycenter example
=====================================
-This example is designed to show how to use the Gromov-Wassertsein distance
+This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.
"""
@@ -34,8 +34,9 @@ that will be given by the output of the algorithm
def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
"""
- Returns an interpolated point cloud following the dissimilarity matrix C using SMACOF
- multidimensional scaling (MDS) in specific dimensionned target space
+ Returns an interpolated point cloud following the dissimilarity matrix C
+ using SMACOF multidimensional scaling (MDS) in a target space of the
+ specified dimension
Parameters
----------
@@ -51,7 +52,8 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
Returns
-------
npos : ndarray, shape (R, dim)
- Embedded coordinates of the interpolated point cloud (defined with one isometry)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
"""
rng = np.random.RandomState(seed=3)
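
The rest of this helper's body is not part of the hunk; as a rough sketch of what a SMACOF embedding of a precomputed dissimilarity matrix looks like with scikit-learn (the function name smacof_embed is hypothetical and the example's actual implementation may differ):

    import numpy as np
    from sklearn import manifold

    def smacof_embed(C, dim, max_iter=3000, eps=1e-9):
        # embed the dissimilarity matrix C into a dim-dimensional space
        mds = manifold.MDS(n_components=dim, max_iter=max_iter, eps=eps,
                           dissimilarity='precomputed', random_state=0)
        return mds.fit(C).embedding_
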
@@ -88,10 +90,10 @@ def im2mat(I):
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
-square = spi.imread('../data/square.png').astype(np.float64)[:,:,2] / 256
-cross = spi.imread('../data/cross.png').astype(np.float64)[:,:,2] / 256
-triangle = spi.imread('../data/triangle.png').astype(np.float64)[:,:,2] / 256
-star = spi.imread('../data/star.png').astype(np.float64)[:,:,2] / 256
+square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
shapes = [square, cross, triangle, star]
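
Note that SciPy's imread (used as spi.imread here) has since been deprecated and removed; a hedged equivalent of one of these lines on a recent environment would swap in imageio, which likewise returns the raw uint8 pixels:

    import imageio
    import numpy as np

    # same blue-channel selection and scaling as above, without SciPy's imread
    square = imageio.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
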
diff --git a/ot/gromov.py b/ot/gromov.py
index 82e3fd3..7968e5e 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -122,7 +122,9 @@ def tensor_kl_loss(C1, C2, T):
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016.
+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
"""
@@ -157,7 +159,8 @@ def update_square_loss(p, lambdas, T, Cs):
----------
p : ndarray, shape (N,)
weights in the targeted barycenter
- lambdas : list of the S spaces' weights
+ lambdas : list of float
+ list of the S spaces' weights
T : list of S np.ndarray(ns,N)
the S Ts couplings calculated at each iteration
Cs : list of S ndarray, shape(ns,ns)
@@ -168,7 +171,8 @@ def update_square_loss(p, lambdas, T, Cs):
C : ndarray, shape (nt,nt)
updated C matrix
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
ppt = np.outer(p, p)
return np.divide(tmpsum, ppt)
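
In equation form, the reflowed line computes the square-loss barycenter update C <- (sum_s lambdas[s] * T[s].T @ Cs[s] @ T[s]) / (p p^T), with the division taken elementwise. A self-contained numerical sketch with toy shapes (all names and sizes illustrative):

    import numpy as np

    rng = np.random.RandomState(0)
    ns, N = 5, 4                               # input sizes and barycenter size
    Cs = [rng.rand(ns, ns) for _ in range(2)]  # two input similarity matrices
    T = [np.full((ns, N), 1.0 / (ns * N)) for _ in range(2)]  # couplings
    lambdas = [0.5, 0.5]
    p = np.full(N, 1.0 / N)

    tmpsum = sum(lambdas[s] * T[s].T.dot(Cs[s]).dot(T[s]) for s in range(len(T)))
    C = tmpsum / np.outer(p, p)                # square-loss barycenter update

The KL-loss variant in the next hunk applies np.exp to the same ratio.
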
@@ -194,13 +198,15 @@ def update_kl_loss(p, lambdas, T, Cs):
C : ndarray, shape (ns,ns)
updated C matrix
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
+ for s in range(len(T))])
ppt = np.outer(p, p)
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein coupling between the two measured similarity matrices
@@ -276,7 +282,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
T = sinkhorn(p, q, tens, epsilon)
if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all the 10th iterations
+ # we can speed up the process by checking the error only every
+ # 10th iteration
err = np.linalg.norm(T - Tprev)
if log:
@@ -296,7 +303,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
return T
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
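
gromov_wasserstein2 gets the same signature reflow; it returns the scalar discrepancy rather than the coupling. A short hedged sketch along the lines of the first example above (sizes and epsilon are illustrative):

    import numpy as np
    import ot
    from ot.gromov import gromov_wasserstein2

    n = 20
    rng = np.random.RandomState(1)
    xs, xt = rng.randn(n, 2), rng.randn(n, 3)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    C1, C2 = C1 / C1.max(), C2 / C2.max()

    gw_dist = gromov_wasserstein2(C1, C2, ot.unif(n), ot.unif(n),
                                  'square_loss', 5e-4)
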
@@ -363,7 +371,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
return gw_dist
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False):
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -390,7 +399,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
sample weights in the S spaces
p : ndarray, shape(N,)
weights in the targeted barycenter
- lambdas : list of the S spaces' weights
+ lambdas : list of float
+ list of the S spaces' weights
L : tensor-matrix multiplication function based on specific loss function
update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
with the S Ts couplings calculated at each iteration
@@ -404,6 +414,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
Print information along iterations
log : bool, optional
record log if True
+ init_C : ndarray, shape (N, N), optional
+ initial value of the C matrix provided by the user (random if None)
Returns
-------
@@ -416,10 +428,13 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
lambdas = np.asarray(lambdas, dtype=np.float64)
- # Initialization of C : random SPD matrix
- xalea = np.random.randn(N, 2)
- C = dist(xalea, xalea)
- C /= C.max()
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ else:
+ C = init_C
cpt = 0
err = 1
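
The new init_C argument is the functional change of this commit: it lets the caller seed the barycenter instead of relying on the random SPD initialization above. A minimal usage sketch (sizes, epsilon and the seeding matrix are illustrative):

    import numpy as np
    import ot
    from ot.gromov import gromov_barycenters

    N, ns = 10, 15
    rng = np.random.RandomState(42)

    # two input spaces described by normalized distance matrices
    Cs, ps = [], []
    for _ in range(2):
        x = rng.randn(ns, 2)
        C = ot.dist(x, x)
        Cs.append(C / C.max())
        ps.append(ot.unif(ns))

    # user-chosen initialization for the barycenter matrix
    xinit = rng.randn(N, 2)
    init_C = ot.dist(xinit, xinit)
    init_C /= init_C.max()

    Cb = gromov_barycenters(N, Cs, ps, ot.unif(N), [0.5, 0.5], 'square_loss',
                            5e-4, max_iter=100, tol=1e-5, init_C=init_C)
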
@@ -438,7 +453,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
C = update_kl_loss(p, lambdas, T, Cs)
if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all the 10th iterations
+ # we can speed up the process by checking the error only every
+ # 10th iteration
err = np.linalg.norm(C - Cprev)
error.append(err)