From e0ba31ce39a7d9e65e50ea970a574b3db54e4207 Mon Sep 17 00:00:00 2001 From: Tanguy Date: Fri, 17 Sep 2021 18:36:33 +0200 Subject: [MRG] Implementation of two news algorithms: SaGroW and PoGroW. (#275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein. * Correct some lines in SaGroW and PoGroW to follow pep8 guide. * Change nb_samples name. Use rdm state. Change symmetric check. * Change names of len(p) and len(q) in SaGroW and PoGroW. * Re-add some deleted lines in the comments of gromov.py Co-authored-by: Rémi Flamary --- test/test_gromov.py | 88 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 82 insertions(+), 6 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index 56414a8..19d61b1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -33,7 +33,7 @@ def test_gromov(): G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -54,7 +54,7 @@ def test_gromov(): np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -83,7 +83,7 @@ def test_entropic_gromov(): G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -96,13 +96,89 @@ def test_entropic_gromov(): np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov +def test_pointwise_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.0 + assert log['gw_dist_std'] == 0.0 + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + + assert log['gw_dist_estimated'] == 0.10342276348494964 + assert log['gw_dist_std'] == 0.0015952535464736394 + + +def test_sampled_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.sampled_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.05679474884977278 + assert log['gw_dist_std'] == 0.0005986592106971995 + + def test_gromov_barycenter(): ns = 50 nt = 60 @@ -186,7 +262,7 @@ def test_fgw(): G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( @@ -203,7 +279,7 @@ def test_fgw(): np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( -- cgit v1.2.3