diff options
Diffstat (limited to 'test/test_partial.py')
-rwxr-xr-x | test/test_partial.py | 113 |
1 files changed, 92 insertions, 21 deletions
diff --git a/test/test_partial.py b/test/test_partial.py index ae4a1ab..86f9e62 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -8,6 +8,7 @@ import numpy as np import scipy as sp import ot +from ot.backend import to_numpy, torch import pytest @@ -82,7 +83,7 @@ def test_partial_wasserstein_lagrange(): w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True) -def test_partial_wasserstein(): +def test_partial_wasserstein(nx): n_samples = 20 # nb samples (gaussian) n_noise = 20 # nb of samples (noise) @@ -102,25 +103,20 @@ def test_partial_wasserstein(): m = 0.5 + p, q, M = nx.from_numpy(p, q, M) + w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True) - w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, - log=True, verbose=True) + w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True) # check constraints - np.testing.assert_equal( - w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_equal( - w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) # check transported mass - np.testing.assert_allclose( - np.sum(w0), m, atol=1e-04) - np.testing.assert_allclose( - np.sum(w), m, atol=1e-04) + np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04) + np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04) w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True) w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False) @@ -130,12 +126,87 @@ def test_partial_wasserstein(): np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1) # check constraints - np.testing.assert_equal( - G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_allclose( - np.sum(G), m, atol=1e-04) + np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04) + + empty_array = nx.zeros(0, type_as=M) + w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None) + + # check constraints + np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q)) + + # check transported mass + np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04) + + w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None) + + # check constraints + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + + # check transported mass + np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04) + + +def test_partial_wasserstein2_gradient(): + if torch: + n_samples = 40 + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + + M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64) + + p = torch.tensor(ot.unif(n_samples), dtype=torch.float64) + q = torch.tensor(ot.unif(n_samples), dtype=torch.float64) + + m = 0.5 + + w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True) + + w.backward() + + assert M.grad is not None + assert M.grad.shape == M.shape + + +def test_entropic_partial_wasserstein_gradient(): + if torch: + n_samples = 40 + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + + M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64) + + p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64) + q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64) + + m = 0.5 + reg = 1 + + _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True) + + log['partial_w_dist'].backward() + + assert M.grad is not None + assert p.grad is not None + assert q.grad is not None + assert M.grad.shape == M.shape + assert p.grad.shape == p.shape + assert q.grad.shape == q.shape def test_partial_gromov_wasserstein(): |