Diffstat (limited to 'test/test_partial.py')
-rwxr-xr-x  test/test_partial.py  124
1 file changed, 99 insertions(+), 25 deletions(-)
diff --git a/test/test_partial.py b/test/test_partial.py
index 97c611b..86f9e62 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -8,6 +8,7 @@
import numpy as np
import scipy as sp
import ot
+from ot.backend import to_numpy, torch
import pytest
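The `nx` argument added to the test signature below is a pytest fixture that parametrizes each test over POT's available numerical backends. A minimal sketch of such a fixture, assuming the usual conftest.py mechanism (the fixture body here is illustrative and not part of this diff):

# conftest.py -- illustrative sketch of the backend fixture.
# ot.backend.get_backend_list() returns the registered backends
# (NumPy always; PyTorch/JAX/... when installed), so each test that
# takes `nx` runs once per backend.
import pytest
from ot.backend import get_backend_list

@pytest.fixture(params=get_backend_list())
def nx(request):
    return request.param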
@@ -79,8 +80,10 @@ def test_partial_wasserstein_lagrange():
w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 100, log=True)
-def test_partial_wasserstein():
+
+def test_partial_wasserstein(nx):
n_samples = 20 # nb samples (gaussian)
n_noise = 20 # nb of samples (noise)
@@ -100,25 +103,20 @@ def test_partial_wasserstein():
m = 0.5
+ p, q, M = nx.from_numpy(p, q, M)
+
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
- w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
- log=True, verbose=True)
+ w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m, log=True, verbose=True)
# check constraints
- np.testing.assert_equal(
- w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w0.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- w.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
# check transported mass
- np.testing.assert_allclose(
- np.sum(w0), m, atol=1e-04)
- np.testing.assert_allclose(
- np.sum(w), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), m, atol=1e-04)
+ np.testing.assert_allclose(np.sum(to_numpy(w)), m, atol=1e-04)
w0, log0 = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
w0_val = ot.partial.partial_wasserstein2(p, q, M, m=m, log=False)
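The constraint checks above encode the feasibility conditions of a partial transport plan: row sums bounded by p, column sums bounded by q, and total transported mass equal to m. The same conditions as a standalone NumPy helper (a hypothetical name, for illustration only, not part of POT's API):

# check_partial_plan is an illustrative helper, not part of this diff.
import numpy as np

def check_partial_plan(w, p, q, m, tol=1e-5):
    assert np.all(w.sum(axis=1) <= p + tol)  # never ship more than p from a source
    assert np.all(w.sum(axis=0) <= q + tol)  # never deliver more than q to a target
    np.testing.assert_allclose(w.sum(), m, atol=1e-4)  # exactly m mass moves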
@@ -128,15 +126,91 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
# check constraints
- np.testing.assert_equal(
- G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
- np.testing.assert_equal(
- G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
- np.testing.assert_allclose(
- np.sum(G), m, atol=1e-04)
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(G, axis=0) - q) <= 1e-5, [True] * len(q))
+ np.testing.assert_allclose(np.sum(to_numpy(G)), m, atol=1e-04)
+
+ empty_array = nx.zeros(0, type_as=M)
+ w = ot.partial.partial_wasserstein(empty_array, empty_array, M=M, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w)), 1, atol=1e-04)
+
+ w0 = ot.partial.entropic_partial_wasserstein(empty_array, empty_array, M=M, reg=10, m=None)
+
+ # check constraints
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=1) - p) <= 1e-5, [True] * len(p))
+ np.testing.assert_equal(to_numpy(nx.sum(w0, axis=0) - q) <= 1e-5, [True] * len(q))
+
+ # check transported mass
+ np.testing.assert_allclose(np.sum(to_numpy(w0)), 1, atol=1e-04)
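Both empty-array cases assert a total transported mass of 1 rather than m. The reason, as far as I can tell from ot/partial.py (stated as an assumption, not a documented guarantee): empty marginals fall back to uniform weights, and m=None defaults to the smaller of the two total masses. In code:

# Assumed default-mass logic being exercised above (my reading of
# ot.partial, not a documented guarantee):
import numpy as np
import ot

p = ot.unif(40)                        # uniform fallback weights, mass 1
q = ot.unif(40)
m_default = min(np.sum(p), np.sum(q))  # == 1.0, the value asserted above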
+
+
+def test_partial_wasserstein2_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), dtype=torch.float64)
+
+ m = 0.5
+
+ w, log = ot.partial.partial_wasserstein2(p, q, M, m=m, log=True)
+
+ w.backward()
+
+ assert M.grad is not None
+ assert M.grad.shape == M.shape
+
+
+def test_entropic_partial_wasserstein_gradient():
+ if torch:
+ n_samples = 40
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+
+ M = torch.tensor(ot.dist(xs, xt), requires_grad=True, dtype=torch.float64)
+
+ p = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+ q = torch.tensor(ot.unif(n_samples), requires_grad=True, dtype=torch.float64)
+
+ m = 0.5
+ reg = 1
+
+ _, log = ot.partial.entropic_partial_wasserstein(p, q, M, m=m, reg=reg, log=True)
+
+ log['partial_w_dist'].backward()
+
+ assert M.grad is not None
+ assert p.grad is not None
+ assert q.grad is not None
+ assert M.grad.shape == M.shape
+ assert p.grad.shape == p.shape
+ assert q.grad.shape == q.shape
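The two gradient tests above only assert that backward() populates gradients of the right shapes. A typical downstream use of those gradients, sketched under the same torch setup (the step size and update are illustrative, not part of this diff):

# Illustrative gradient step through partial_wasserstein2: move the
# source samples to reduce the partial distance to the targets.
import torch
import ot

n = 40
xs = torch.randn(n, 2, dtype=torch.float64, requires_grad=True)
xt = torch.randn(n, 2, dtype=torch.float64)
p = torch.tensor(ot.unif(n), dtype=torch.float64)
q = torch.tensor(ot.unif(n), dtype=torch.float64)

loss = ot.partial.partial_wasserstein2(p, q, ot.dist(xs, xt), m=0.5)
loss.backward()
with torch.no_grad():
    xs -= 0.1 * xs.grad  # one descent step on the source positions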
def test_partial_gromov_wasserstein():
+ rng = np.random.RandomState(seed=42)
n_samples = 20 # nb samples
n_noise = 10 # nb of samples (noise)
@@ -149,11 +223,11 @@ def test_partial_gromov_wasserstein():
mu_t = np.array([0, 0, 0])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
- xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng)
+ xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0)
P = sp.linalg.sqrtm(cov_t)
- xt = np.random.randn(n_samples, 3).dot(P) + mu_t
- xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+ xt = rng.randn(n_samples, 3).dot(P) + mu_t
+ xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0)
xt2 = xs[::-1].copy()
C1 = ot.dist(xs, xs)
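The switch from module-level np.random calls to a seeded RandomState in this last hunk makes the sampled noise reproducible and insulates the test from other code reseeding the global generator. A minimal illustration (not part of this diff):

# A local RandomState yields identical draws on every run, regardless
# of what the global np.random state currently is.
import numpy as np

rng = np.random.RandomState(seed=42)
assert np.allclose(rng.rand(3),
                   np.random.RandomState(seed=42).rand(3))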