From a335324d008e8982be61d7ace937815a2bfa98f9 Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Tue, 2 Nov 2021 13:42:02 +0100 Subject: [MRG] Backend for gromov (#294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: Rémi Flamary --- test/test_gromov.py | 297 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 207 insertions(+), 90 deletions(-) (limited to 'test/test_gromov.py') diff --git a/test/test_gromov.py b/test/test_gromov.py index 0242d72..509c54d 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -8,11 +8,12 @@ import numpy as np import ot +from ot.backend import NumpyBackend import pytest -def test_gromov(): +def test_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -31,37 +32,50 @@ def test_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf 
convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) - np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) + np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1) - np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False + np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06) + np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_entropic_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_entropic_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -80,30 +94,44 @@ def test_entropic_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True + )) # check 
constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True) + gwb, logb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) + np.testing.assert_allclose(gw, gwb, atol=1e-06) np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_pointwise_gromov(): +def test_pointwise_gromov(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -122,33 +150,52 @@ def test_pointwise_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q[np.newaxis, :], G.sum(0), atol=1e-04) # 
cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.0 - assert log['gw_dist_std'] == 0.0 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08) G, log = ot.gromov.pointwise_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + G = NumpyBackend().todense(G) + Gb, logb = ot.gromov.pointwise_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(nx.todense(Gb)) - assert log['gw_dist_estimated'] == 0.10342276348494964 - assert log['gw_dist_std'] == 0.0015952535464736394 + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8) -def test_sampled_gromov(): +@pytest.skip_backend("jax", reason="test very slow with jax backend") +def test_sampled_gromov(nx): n_samples = 50 # nb samples - mu_s = np.array([0, 0]) - cov_s = np.array([[1, 0], [0, 1]]) + mu_s = np.array([0, 0], dtype=np.float64) + cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) @@ -163,23 +210,35 @@ def test_sampled_gromov(): C1 /= C1.max() C2 /= C2.max() + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + def loss(x, y): return np.abs(x - y) + def lossb(x, y): + return nx.abs(x - y) + G, log = ot.gromov.sampled_gromov_wasserstein( C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb, logb = ot.gromov.sampled_gromov_wasserstein( + C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) 
np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov - assert log['gw_dist_estimated'] == 0.05679474884977278 - assert log['gw_dist_std'] == 0.0005986592106971995 + np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08) + np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08) -def test_gromov_barycenter(): +def test_gromov_barycenter(nx): ns = 10 nt = 20 @@ -188,26 +247,42 @@ def test_gromov_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 3 - Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'square_loss', # 5e-4, - max_iter=100, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + p = ot.unif(n_samples) - Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', # 5e-4, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + ) + 
Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) @pytest.mark.filterwarnings("ignore:divide") -def test_gromov_entropic_barycenter(): +def test_gromov_entropic_barycenter(nx): ns = 10 nt = 20 @@ -216,26 +291,41 @@ def test_gromov_entropic_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1 = ot.unif(ns) + p2 = ot.unif(nt) n_samples = 2 - Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'square_loss', 1e-3, - max_iter=50, tol=1e-3, - verbose=True) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - - Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], - [ot.unif(ns), ot.unif(nt) - ], ot.unif(n_samples), [.5, .5], - 'kl_loss', 1e-3, - max_iter=100, tol=1e-3) - np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) - - -def test_fgw(): + p = ot.unif(n_samples) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Cb = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', 
1e-3, max_iter=100, tol=1e-3, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + +def test_fgw(nx): n_samples = 50 # nb samples mu_s = np.array([0, 0]) @@ -260,33 +350,46 @@ def test_fgw(): M = ot.dist(ys, yt) M /= M.max() + Mb = nx.from_numpy(M) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + pb = nx.from_numpy(p) + qb = nx.from_numpy(q) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + Gb = nx.to_numpy(Gb) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence fgw + p, Gb.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence fgw + q, Gb.sum(0), atol=1e-04) # cf convergence fgw Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples) np.testing.assert_allclose( - G, np.flipud(Id), atol=1e-04) # cf convergence gromov + Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) G = log['T'] + Gb = nx.to_numpy(logb['T']) - np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + np.testing.assert_allclose(fgw, fgwb, atol=1e-08) + np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1) # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose( - p, G.sum(1), atol=1e-04) # cf convergence gromov + p, Gb.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( - q, G.sum(0), atol=1e-04) # cf convergence gromov + q, Gb.sum(0), atol=1e-04) # cf convergence gromov -def test_fgw_barycenter(): +def test_fgw_barycenter(nx): 
np.random.seed(42) ns = 50 @@ -300,30 +403,44 @@ def test_fgw_barycenter(): C1 = ot.dist(Xs) C2 = ot.dist(Xt) - + p1, p2 = ot.unif(ns), ot.unif(nt) n_samples = 3 - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + p = ot.unif(n_samples) + + ysb = nx.from_numpy(ys) + ytb = nx.from_numpy(yt) + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + p1b = nx.from_numpy(p1) + p2b = nx.from_numpy(p2) + pb = nx.from_numpy(p) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, + fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345 + ) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) - - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5, - fixed_structure=True, init_C=init_C, fixed_features=False, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Cb = nx.from_numpy(init_C) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5], + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) init_X = np.random.randn(n_samples, ys.shape[1]) - - X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), 
ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3, log=True) - np.testing.assert_allclose(C.shape, (n_samples, n_samples)) - np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + init_Xb = nx.from_numpy(init_X) + + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) -- cgit v1.2.3