summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/plot_gromov.py15
-rwxr-xr-xexamples/plot_gromov_barycenter.py33
-rw-r--r--ot/gromov.py35
-rw-r--r--test/test_gromov.py14
4 files changed, 97 insertions, 0 deletions
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
index 92312ae..99aaf81 100644
--- a/examples/plot_gromov.py
+++ b/examples/plot_gromov.py
@@ -26,7 +26,11 @@ The Gromov-Wasserstein distance allows to compute distances with samples that do
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
"""
+<<<<<<< HEAD
n_samples = 30 # nb samples
+=======
+n = 30 # nb samples
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -35,9 +39,15 @@ mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+<<<<<<< HEAD
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+=======
+xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n, 3).dot(P) + mu_t
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
"""
@@ -75,8 +85,13 @@ Compute Gromov-Wasserstein plans and distance
=============================================
"""
+<<<<<<< HEAD
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+=======
+p = ot.unif(n)
+q = ot.unif(n)
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
index f0657e1..46ec4bc 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/plot_gromov_barycenter.py
@@ -91,12 +91,21 @@ def im2mat(I):
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+<<<<<<< HEAD
square = spi.imread('../data/carre.png').astype(np.float64) / 256
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
shapes = [square, circle, triangle, arrow]
+=======
+carre = spi.imread('../data/carre.png').astype(np.float64) / 256
+rond = spi.imread('../data/rond.png').astype(np.float64) / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
+fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
+
+shapes = [carre, rond, triangle, fleche]
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
S = 4
xs = [[] for i in range(S)]
@@ -118,36 +127,60 @@ Barycenter computation
The four distributions are constructed from 4 simple images
"""
ns = [len(xs[s]) for s in range(S)]
+<<<<<<< HEAD
n_samples = 30
+=======
+N = 30
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
"""Compute all distances matrices for the four shapes"""
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
Cs = [cs / cs.max() for cs in Cs]
ps = [ot.unif(ns[s]) for s in range(S)]
+<<<<<<< HEAD
p = ot.unif(n_samples)
+=======
+p = ot.unif(N)
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
Ct01 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], [
+=======
+ Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
Ct02 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], [
+=======
+ Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
Ct13 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], [
+=======
+ Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
Ct23 = [0 for i in range(2)]
for i in range(2):
+<<<<<<< HEAD
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], [
+=======
+ Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
"""
diff --git a/ot/gromov.py b/ot/gromov.py
index ad85fcd..197e3ea 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -208,7 +208,11 @@ def update_kl_loss(p, lambdas, T, Cs):
return(np.exp(np.divide(tmpsum, ppt)))
+<<<<<<< HEAD
def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+=======
+def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
"""
Returns the gromov-wasserstein coupling between the two measured similarity matrices
@@ -248,7 +252,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
+<<<<<<< HEAD
max_iter : int, optional
+=======
+ numItermax : int, optional
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
@@ -274,7 +282,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
cpt = 0
err = 1
+<<<<<<< HEAD
while (err > stopThr and cpt < max_iter):
+=======
+ while (err > stopThr and cpt < numItermax):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
Tprev = T
@@ -307,7 +319,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
return T
+<<<<<<< HEAD
def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+=======
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
"""
Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
@@ -362,10 +378,17 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
if log:
gw, logv = gromov_wasserstein(
+<<<<<<< HEAD
C1, C2, p, q, loss_fun, epsilon, max_iter, stopThr, verbose, log)
else:
gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
epsilon, max_iter, stopThr, verbose, log)
+=======
+ C1, C2, p, q, loss_fun, epsilon, numItermax, stopThr, verbose, log)
+ else:
+ gw = gromov_wasserstein(C1, C2, p, q, loss_fun,
+ epsilon, numItermax, stopThr, verbose, log)
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
if loss_fun == 'square_loss':
gw_dist = np.sum(gw * tensor_square_loss(C1, C2, gw))
@@ -379,7 +402,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
return gw_dist
+<<<<<<< HEAD
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, stopThr=1e-9, verbose=False, log=False):
+=======
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000, stopThr=1e-9, verbose=False, log=False):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
@@ -442,12 +469,20 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
error = []
+<<<<<<< HEAD
while(err > stopThr and cpt < max_iter):
+=======
+ while(err > stopThr and cpt < numItermax):
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
Cprev = C
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
+<<<<<<< HEAD
max_iter, 1e-5, verbose, log) for s in range(S)]
+=======
+ numItermax, 1e-5, verbose, log) for s in range(S)]
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index c26d898..a6c89f2 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -10,11 +10,16 @@ import ot
def test_gromov():
+<<<<<<< HEAD
n_samples = 50 # nb samples
+=======
+ n = 50 # nb samples
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
+<<<<<<< HEAD
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
xt = [xs[n_samples - (i + 1)] for i in range(n_samples)]
@@ -22,6 +27,15 @@ def test_gromov():
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+=======
+ xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+
+ xt = [xs[n - (i + 1)] for i in range(n)]
+ xt = np.array(xt)
+
+ p = ot.unif(n)
+ q = ot.unif(n)
+>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)