summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-05-05 10:53:48 +0200
committerGitHub <noreply@github.com>2023-05-05 10:53:48 +0200
commit7e0ea27ad9cad31cfc2181430d837c0a77a61568 (patch)
tree0a41128a975500bfef52a4c21b5af634adecc71a /test
parent83dc498b496087aea293df1445442d8728435211 (diff)
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul * typo error * einsum projections ssw * Test broadcast matmul * einsum projections ssw * Test broadcast matmul * projections SSW with einsum * reduce number of samples in test wasserstein_circle_unif * Update releases.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_1d_solver.py6
-rw-r--r--test/test_backend.py15
-rw-r--r--test/test_sliced.py32
3 files changed, 46 insertions, 7 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 21abd1d..075a415 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -279,7 +279,7 @@ def test_wasserstein1d_circle_devices(nx):
def test_wasserstein_1d_unif_circle():
# test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
n = 20
- m = 50000
+ m = 1000
rng = np.random.RandomState(0)
u = rng.rand(n,)
@@ -298,8 +298,8 @@ def test_wasserstein_1d_unif_circle():
wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
# check loss is similar
- np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
- np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2)
def test_wasserstein1d_unif_circle_devices(nx):
diff --git a/test/test_backend.py b/test/test_backend.py
index 5351e52..fedc62f 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -298,6 +298,8 @@ def test_empty_backend():
nx.transpose(M)
with pytest.raises(NotImplementedError):
nx.detach(M)
+ with pytest.raises(NotImplementedError):
+ nx.matmul(M, M.T)
def test_func_backends(nx):
@@ -308,6 +310,9 @@ def test_func_backends(nx):
v = rnd.randn(3)
val = np.array([1.0])
+ M1 = rnd.randn(1, 2, 10, 10)
+ M2 = rnd.randn(3, 1, 10, 10)
+
# Sparse tensors test
sp_row = np.array([0, 3, 1, 0, 3])
sp_col = np.array([0, 3, 1, 2, 2])
@@ -326,6 +331,9 @@ def test_func_backends(nx):
SquareMb = nx.from_numpy(SquareM)
vb = nx.from_numpy(v)
+ M1b = nx.from_numpy(M1)
+ M2b = nx.from_numpy(M2)
+
val = nx.from_numpy(val)
sp_rowb = nx.from_numpy(sp_row)
@@ -661,6 +669,13 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(B))
lst_name.append("detach B")
+ A = nx.matmul(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matmul")
+ A = nx.matmul(M1b, M2b)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matmul broadcast")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 7b7437a..6d5a27b 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -295,6 +295,26 @@ def test_sliced_sphere_same_dist():
np.testing.assert_almost_equal(res, 0.)
+def test_sliced_sphere_same_proj():
+ n_projections = 10
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ seed = 42
+
+ cost1, log1 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True)
+ cost2, log2 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_sphere_bad_shapes():
n = 100
rng = np.random.RandomState(0)
@@ -398,28 +418,32 @@ def test_sliced_sphere_backend_type_devices(nx):
y = rng.randn(2 * n, 3)
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ sw_np, log = ot.sliced_wasserstein_sphere(x, y, log=True)
+ P = log["projections"]
+
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
xb, yb = nx.from_numpy(x, y, type_as=tp)
- valb = ot.sliced_wasserstein_sphere(xb, yb)
+ valb = ot.sliced_wasserstein_sphere(xb, yb, projections=nx.from_numpy(P, type_as=tp))
nx.assert_same_dtype_device(xb, valb)
+ np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb))
def test_sliced_sphere_gradient():
if torch:
import torch.nn.functional as F
- X0 = torch.randn((500, 3))
+ X0 = torch.randn((20, 3))
X0 = F.normalize(X0, p=2, dim=-1)
X0.requires_grad_(True)
- X1 = torch.randn((500, 3))
+ X1 = torch.randn((20, 3))
X1 = F.normalize(X1, p=2, dim=-1)
- sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2)
+ sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=100, p=2)
grad_x0 = torch.autograd.grad(sw, X0)[0]
assert not torch.any(torch.isnan(grad_x0))