summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-02 13:42:02 +0100
committerGitHub <noreply@github.com>2021-11-02 13:42:02 +0100
commita335324d008e8982be61d7ace937815a2bfa98f9 (patch)
tree83c7f637597f10f6f3d20b15532e53fc65b51f22 /test/test_gromov.py
parent0cb2b2efe901ed74c614046d250518769f870313 (diff)
[MRG] Backend for gromov (#294)
* bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py297
1 files changed, 207 insertions, 90 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 0242d72..509c54d 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -8,11 +8,12 @@
import numpy as np
import ot
+from ot.backend import NumpyBackend
import pytest
-def test_gromov():
+def test_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -31,37 +32,50 @@ def test_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
- np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04)
+ np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
- np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+ np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
+ np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_entropic_gromov():
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -80,30 +94,44 @@ def test_entropic_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ ))
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+ gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_pointwise_gromov():
+def test_pointwise_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -122,33 +150,52 @@ def test_pointwise_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
def loss(x, y):
return np.abs(x - y)
+ def lossb(x, y):
+ return nx.abs(x - y)
+
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- assert log['gw_dist_estimated'] == 0.0
- assert log['gw_dist_std'] == 0.0
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
- assert log['gw_dist_estimated'] == 0.10342276348494964
- assert log['gw_dist_std'] == 0.0015952535464736394
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
-def test_sampled_gromov():
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_sampled_gromov(nx):
n_samples = 50 # nb samples
- mu_s = np.array([0, 0])
- cov_s = np.array([[1, 0], [0, 1]])
+ mu_s = np.array([0, 0], dtype=np.float64)
+ cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
@@ -163,23 +210,35 @@ def test_sampled_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
def loss(x, y):
return np.abs(x - y)
+ def lossb(x, y):
+ return nx.abs(x - y)
+
G, log = ot.gromov.sampled_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb, logb = ot.gromov.sampled_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(Gb)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- assert log['gw_dist_estimated'] == 0.05679474884977278
- assert log['gw_dist_std'] == 0.0005986592106971995
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
-def test_gromov_barycenter():
+def test_gromov_barycenter(nx):
ns = 10
nt = 20
@@ -188,26 +247,42 @@ def test_gromov_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
n_samples = 3
- Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', # 5e-4,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
- Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', # 5e-4,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
@pytest.mark.filterwarnings("ignore:divide")
-def test_gromov_entropic_barycenter():
+def test_gromov_entropic_barycenter(nx):
ns = 10
nt = 20
@@ -216,26 +291,41 @@ def test_gromov_entropic_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
n_samples = 2
- Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', 1e-3,
- max_iter=50, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
-
- Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 1e-3,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
-
-
-def test_fgw():
+ p = ot.unif(n_samples)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+def test_fgw(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -260,33 +350,46 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
+ Mb = nx.from_numpy(M)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ Gb = nx.to_numpy(Gb)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence fgw
+ p, Gb.sum(1), atol=1e-04) # cf convergence fgw
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence fgw
+ q, Gb.sum(0), atol=1e-04) # cf convergence fgw
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+ Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-08)
+ np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_fgw_barycenter():
+def test_fgw_barycenter(nx):
np.random.seed(42)
ns = 50
@@ -300,30 +403,44 @@ def test_fgw_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1, p2 = ot.unif(ns), ot.unif(nt)
n_samples = 3
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ p = ot.unif(n_samples)
+
+ ysb = nx.from_numpy(ys)
+ ytb = nx.from_numpy(yt)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
+ fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
+ )
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Cb = nx.from_numpy(init_C)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5],
+ alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
-
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3, log=True)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Xb = nx.from_numpy(init_X)
+
+ Xb, Cb, logb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_Xb,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))