summaryrefslogtreecommitdiff
path: root/test/test_gromov.py
diff options
context:
space:
mode:
authorTanguy <tanguy.kerdoncuff@laposte.net>2021-09-17 18:36:33 +0200
committerGitHub <noreply@github.com>2021-09-17 18:36:33 +0200
commite0ba31ce39a7d9e65e50ea970a574b3db54e4207 (patch)
tree36c95fc33bd07be476c44f8b5ea65896cf1f0c9f /test/test_gromov.py
parent96bf1a46e74d6985419e14222afb0b9241a7bb36 (diff)
[MRG] Implementation of two news algorithms: SaGroW and PoGroW. (#275)
* Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein. * Correct some lines in SaGroW and PoGroW to follow pep8 guide. * Change nb_samples name. Use rdm state. Change symmetric check. * Change names of len(p) and len(q) in SaGroW and PoGroW. * Re-add some deleted lines in the comments of gromov.py Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_gromov.py')
-rw-r--r--test/test_gromov.py88
1 files changed, 82 insertions, 6 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 56414a8..19d61b1 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -33,7 +33,7 @@ def test_gromov():
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
@@ -54,7 +54,7 @@ def test_gromov():
np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
@@ -83,7 +83,7 @@ def test_entropic_gromov():
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
@@ -96,13 +96,89 @@ def test_entropic_gromov():
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+def test_pointwise_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov
+
+ assert log['gw_dist_estimated'] == 0.0
+ assert log['gw_dist_std'] == 0.0
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+
+ assert log['gw_dist_estimated'] == 0.10342276348494964
+ assert log['gw_dist_std'] == 0.0015952535464736394
+
+
+def test_sampled_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ G, log = ot.gromov.sampled_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ assert log['gw_dist_estimated'] == 0.05679474884977278
+ assert log['gw_dist_std'] == 0.0005986592106971995
+
+
def test_gromov_barycenter():
ns = 50
nt = 60
@@ -186,7 +262,7 @@ def test_fgw():
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence fgw
np.testing.assert_allclose(
@@ -203,7 +279,7 @@ def test_fgw():
np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(