Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--  test/test_gromov.py  637
1 file changed, 613 insertions(+), 24 deletions(-)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 9c85b92..80b6df4 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -3,7 +3,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#
# License: MIT License
@@ -11,18 +11,15 @@ import numpy as np
import ot
from ot.backend import NumpyBackend
from ot.backend import torch, tf
-
import pytest
def test_gromov(nx):
n_samples = 50 # nb samples
-
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
-
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1)
xt = xs[::-1].copy()
p = ot.unif(n_samples)
@@ -38,7 +35,7 @@ def test_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -51,13 +48,13 @@ def test_gromov(nx):
np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
- gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
- gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True)
gwb = nx.to_numpy(gwb)
- gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
gw_valb = nx.to_numpy(
- ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
)
G = log['T']
@@ -77,6 +74,49 @@ def test_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_gromov(nx):
+ n_samples = 30 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
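+ # C2 is C1 with rows and columns permuted, so an optimal plan is the
+ # permutation itself and the GW distance should be 0 (checked below)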
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
+ G, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', G0=G0, log=True, symmetric=False, verbose=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'square_loss', log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['gw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04)
+
+
def test_gromov_dtype_device(nx):
# setup
n_samples = 50 # nb samples
@@ -104,7 +144,7 @@ def test_gromov_dtype_device(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -130,7 +170,7 @@ def test_gromov_device_tf():
with tf.device("/CPU:0"):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -138,7 +178,7 @@ def test_gromov_device_tf():
# Check that everything happens on the GPU
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
assert nx.dtype_device(Gb)[1].startswith("GPU")
@@ -174,6 +214,7 @@ def test_gromov2_gradients():
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)
+ # Test with exact line-search
val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()
@@ -184,6 +225,60 @@ def test_gromov2_gradients():
assert C11.shape == C11.grad.shape
assert C12.shape == C12.grad.shape
+ # Test with armijo line-search
+ q1.grad = None
+ p1.grad = None
+ C11.grad = None
+ C12.grad = None
+ val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert q1.shape == q1.grad.shape
+ assert p1.shape == p1.grad.shape
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_gw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ Gb, logb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # call the low-level helpers with nx=None so the backend is inferred internally
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=0., reg=1., nx=None)
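+ # with M=0. and reg=1., cg minimizes the pure GW objective (no linear term)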
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@@ -199,19 +294,21 @@ def test_entropic_gromov(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
-
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
- C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
- G = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ G, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-2, verbose=True, log=True)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
- C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-2, verbose=True, log=False
))
# check constraints
@@ -222,9 +319,11 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
+ max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True)
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)
G = log['T']
@@ -241,6 +340,45 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_entropic_gromov(nx):
+ n_samples = 10 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
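+ # C2 is again a permutation of C1, so the entropic GW cost should be close to 0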
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-1, verbose=True, log=False)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None,
+ epsilon=1e-1, verbose=True, log=False
+ ))
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-1, log=False)
+ gwb = nx.to_numpy(gwb)
+
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_entropic_gromov_dtype_device(nx):
@@ -539,8 +677,8 @@ def test_fgw(nx):
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -555,8 +693,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -573,6 +711,82 @@ def test_fgw(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_asymmetric_fgw(nx):
+ n_samples = 50 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ # add features
+ F1 = np.random.uniform(low=0., high=10, size=(n_samples, 1))
+ F2 = F1[idx, :]
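+ # F2 is permuted consistently with the structures, so the FGW distance
+ # should vanish as well (checked on log['fgw_dist'] below)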
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ M = ot.dist(F1, F2)
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ # Tests with kl-loss:
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True)
+ Gb = nx.to_numpy(Gb)
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'kl_loss', alpha=0.5, G0=G0, log=True, symmetric=None, verbose=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'kl_loss', alpha=0.5, log=True, symmetric=False, G0=G0b, verbose=True)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log['fgw_dist'], 0., atol=1e-04)
+ np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04)
+
+
def test_fgw2_gradients():
n_samples = 20 # nb samples
@@ -617,6 +831,57 @@ def test_fgw2_gradients():
assert M1.shape == M1.grad.shape
+def test_fgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # call the low-level helpers with nx=None so the backend is inferred internally
+ constCb, hC1b, hC2b = ot.gromov.init_matrix(C1b, C2b, pb, qb, loss_fun='square_loss')
+
+ def f(G):
+ return ot.gromov.gwloss(constCb, hC1b, hC2b, G, None)
+
+ def df(G):
+ return ot.gromov.gwggrad(constCb, hC1b, hC2b, G, None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_gromov_linesearch(G, deltaG, cost_G, C1b, C2b, M=(1 - alpha) * Mb, reg=alpha, nx=None)
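+ # with M=(1 - alpha) * Mb and reg=alpha, cg minimizes the FGW objective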
+ # feed the precomputed local optimum Gb to cg
+ res, log = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=None)
+ # feed the precomputed local optimum Gb to cg
+ res_armijo, log_armijo = ot.optim.cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+ np.testing.assert_allclose(res_armijo, Gb, atol=1e-06)
+
+
def test_fgw_barycenter(nx):
np.random.seed(42)
@@ -1186,3 +1451,327 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
# > Compare results with/without backend
total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)
+
+
+def test_semirelaxed_gromov(nx):
+ np.random.seed(0)
+ # unbalanced proportions
+ list_n = [30, 15]
+ nt = 2
+ ns = np.sum(list_n)
+ # create a directed SBM (stochastic block model) with C2 as its connectivity matrix
+ C1 = np.zeros((ns, ns), dtype=np.float64)
+ C2 = np.array([[0.8, 0.05],
+ [0.05, 1.]], dtype=np.float64)
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
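+ # C2 is the ground-truth connectivity matrix, so the semirelaxed plan
+ # should recover the community proportions list_n / ns on its free marginal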
+ p = ot.unif(ns, type_as=C1)
+ q0 = ot.unif(C2.shape[0], type_as=C1)
+ G0 = p[:, None] * q0[None, :]
+ # asymmetric
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_gromov2_gradients():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_gromov_wasserstein2(C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+
+
+def test_srgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, 'square_loss', armijo=False, symmetric=True, G0=None, log=True)
+
+ # call the low-level helpers with nx=None so the backend is inferred internally
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
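+ # in the semirelaxed problem the second marginal qG = G^T 1 varies with G,
+ # so f and df add the marginal-dependent product to the constant matrix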
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, 0., 1., nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, 0., 1., f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)
+
+
+def test_semirelaxed_fgw(nx):
+ np.random.seed(0)
+ list_n = [16, 8]
+ nt = 2
+ ns = 24
+ # create a directed SBM (stochastic block model) with C2 as its connectivity matrix
+ C1 = np.zeros((ns, ns))
+ C2 = np.array([[0.7, 0.05],
+ [0.05, 0.9]])
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij
+ F1 = np.zeros((ns, 1))
+ F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1))
+ F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1))
+ F2 = np.zeros((2, 1))
+ F2[1, :] = 1.
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T)
+
+ p = ot.unif(ns)
+ q0 = ot.unif(C2.shape[0])
+ G0 = p[:, None] * q0[None, :]
+
+ # asymmetric
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+def test_semirelaxed_fgw2_gradients():
+ n_samples = 20 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
+
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5)
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+ M = ot.dist(xs, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ if torch:
+
+ devices = [torch.device("cpu")]
+ if torch.cuda.is_available():
+ devices.append(torch.device("cuda"))
+ for device in devices:
+ # semirelaxed solvers do not support gradients over masses yet.
+ p1 = torch.tensor(p, requires_grad=False, device=device)
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
+ M1 = torch.tensor(M, requires_grad=True, device=device)
+
+ val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1)
+
+ val.backward()
+
+ assert val.device == p1.device
+ assert p1.grad is None
+ assert C11.shape == C11.grad.shape
+ assert C12.shape == C12.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
+def test_srfgw_helper_backend(nx):
+ n_samples = 20 # nb samples
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0)
+ ys = np.random.randn(xs.shape[0], 2)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=1)
+ yt = np.random.randn(xt.shape[0], 2)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ alpha = 0.5
+ Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, 'square_loss', alpha=0.5, armijo=False, symmetric=True, G0=G0b, log=True)
+
+ # call the low-level helpers with nx=None so the backend is inferred internally
+ constCb, hC1b, hC2b, fC2tb = ot.gromov.init_matrix_semirelaxed(C1b, C2b, pb, loss_fun='square_loss')
+ ones_pb = nx.ones(pb.shape[0], type_as=pb)
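+ # as in test_srgw_helper_backend, the constant term must be completed with
+ # the marginal-dependent product since the second marginal is free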
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwloss(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_pb, nx.dot(qG, fC2tb))
+ return ot.gromov.gwggrad(constCb + marginal_product, hC1b, hC2b, G, nx=None)
+
+ def line_search(cost, G, deltaG, Mi, cost_G):
+ return ot.gromov.solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1b, C2b, ones_pb, M=(1 - alpha) * Mb, reg=alpha, nx=None)
+ # feed the precomputed local optimum Gb to semirelaxed_cg
+ res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
+ # check constraints
+ np.testing.assert_allclose(res, Gb, atol=1e-06)