diff options
author | Antoine Collas <contact@antoinecollas.fr> | 2023-03-21 15:18:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-21 15:18:09 +0100 |
commit | b9ed7b1650475420cc5bbec6c31476cc098790d5 (patch) | |
tree | f6624d9509288466ac3e9c0cd475b4a984d72ce4 /test/test_partial.py | |
parent | c48cd76235569ada98af6b1bba589510a2818906 (diff) |
[MRG] Make partial_wasserstein, partial_wasserstein2 and entropic_partial_wasserstein work with backend (#449)
* add test of partial_wasserstein with torch tensors
* WIP: differentiable ot.partial.partial_wasserstein
* change test of torch partial
* make partial_wasserstein2 work with torch
* test backward through ot.partial.partial_wasserstein2
* add test of entropic_partial_wasserstein with torch tensors
* make entropic_partial_wasserstein work with torch tensors
* add test of backward through entropic_partial_wasserstein
* rm unused import
* test partial_wasserstein with all backends
* tests of partial fcts: check if torch is available
* partial: check if marginals are empty arrays
* add tests when marginals are empty arrays and/or m=None
* add PR to RELEASES.md
---------
Co-authored-by: Antoine Collas <22830806+antoinecollas@users.noreply.github.com>
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_partial.py')
-rwxr-xr-x | test/test_partial.py | 113 |
1 files changed, 92 insertions, 21 deletions
diff --git a/test/test_partial.py b/test/test_partial.py index ae4a1ab..86f9e62 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -8,6 +8,7 @@ import numpy as np import scipy as sp import ot +from ot.backend import to_numpy, torch import pytest @@ -82,7 +83,7 @@ def test_partial_wasserstein_lagrange(): w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True) -def test_partial_wasserstein(): +def test_partial_wasserstein(nx): n_samples = 20 # nb samples (gaussian) n_noise = 20 # nb of samples (noise) @@ -102,25 +103,20 @@ def test_partial_wasserstein(): m = 0.5 + p, q, M = nx.from_numpy(p, q, M) + w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True) - w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, - log=True, verbose=True) + w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True) # check constraints - np.testing.assert_equal( - w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_equal( - w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) # check transported mass - np.testing.assert_allclose( - np.sum(w0), m, atol=1e-04) - np.testing.assert_allclose( - np.sum(w), m, atol=1e-04) + np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04) + np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04) w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True) w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False) @@ -130,12 +126,87 @@ def test_partial_wasserstein(): np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1) # check constraints - np.testing.assert_equal( - G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein - np.testing.assert_equal( - G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein - np.testing.assert_allclose( - np.sum(G), m, atol=1e-04) + np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04) + + empty_array = nx.zeros(0, type_as=M) + w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None) + + # check constraints + np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q)) + + # check transported mass + np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04) + + w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None) + + # check constraints + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p)) + np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q)) + + # check transported mass + np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04) + + +def test_partial_wasserstein2_gradient(): + if torch: + n_samples = 40 + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + + M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64) + + p = torch.tensor(ot.unif(n_samples), dtype=torch.float64) + q = torch.tensor(ot.unif(n_samples), dtype=torch.float64) + + m = 0.5 + + w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True) + + w.backward() + + assert M.grad is not None + assert M.grad.shape == M.shape + + +def test_entropic_partial_wasserstein_gradient(): + if torch: + n_samples = 40 + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov) + + M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64) + + p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64) + q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64) + + m = 0.5 + reg = 1 + + _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True) + + log['partial_w_dist'].backward() + + assert M.grad is not None + assert p.grad is not None + assert q.grad is not None + assert M.grad.shape == M.shape + assert p.grad.shape == p.shape + assert q.grad.shape == q.shape def test_partial_gromov_wasserstein(): |