diff options
author | Clément Bonet <32179275+clbonet@users.noreply.github.com> | 2023-05-05 10:53:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-05 10:53:48 +0200 |
commit | 7e0ea27ad9cad31cfc2181430d837c0a77a61568 (patch) | |
tree | 0a41128a975500bfef52a4c21b5af634adecc71a /test | |
parent | 83dc498b496087aea293df1445442d8728435211 (diff) |
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul
* typo error
* einsum projections ssw
* Test broadcast matmul
* einsum projections ssw
* Test broadcast matmul
* projections SSW with einsum
* reduce number of samples in test wasserstein_circle_unif
* Update releases.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r-- | test/test_1d_solver.py | 6 | ||||
-rw-r--r-- | test/test_backend.py | 15 | ||||
-rw-r--r-- | test/test_sliced.py | 32 |
3 files changed, 46 insertions, 7 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 21abd1d..075a415 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -279,7 +279,7 @@ def test_wasserstein1d_circle_devices(nx): def test_wasserstein_1d_unif_circle(): # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle n = 20 - m = 50000 + m = 1000 rng = np.random.RandomState(0) u = rng.rand(n,) @@ -298,8 +298,8 @@ def test_wasserstein_1d_unif_circle(): wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) # check loss is similar - np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) - np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2) def test_wasserstein1d_unif_circle_devices(nx): diff --git a/test/test_backend.py b/test/test_backend.py index 5351e52..fedc62f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -298,6 +298,8 @@ def test_empty_backend(): nx.transpose(M) with pytest.raises(NotImplementedError): nx.detach(M) + with pytest.raises(NotImplementedError): + nx.matmul(M, M.T) def test_func_backends(nx): @@ -308,6 +310,9 @@ def test_func_backends(nx): v = rnd.randn(3) val = np.array([1.0]) + M1 = rnd.randn(1, 2, 10, 10) + M2 = rnd.randn(3, 1, 10, 10) + # Sparse tensors test sp_row = np.array([0, 3, 1, 0, 3]) sp_col = np.array([0, 3, 1, 2, 2]) @@ -326,6 +331,9 @@ def test_func_backends(nx): SquareMb = nx.from_numpy(SquareM) vb = nx.from_numpy(v) + M1b = nx.from_numpy(M1) + M2b = nx.from_numpy(M2) + val = nx.from_numpy(val) sp_rowb = nx.from_numpy(sp_row) @@ -661,6 +669,13 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(B)) lst_name.append("detach B") + A = nx.matmul(Mb, Mb.T) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matmul") + A = nx.matmul(M1b, M2b) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matmul broadcast") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_sliced.py b/test/test_sliced.py index 7b7437a..6d5a27b 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -295,6 +295,26 @@ def test_sliced_sphere_same_dist(): np.testing.assert_almost_equal(res, 0.) +def test_sliced_sphere_same_proj(): + n_projections = 10 + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + seed = 42 + + cost1, log1 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) + cost2, log2 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) + + assert np.allclose(log1['projections'], log2['projections']) + assert np.isclose(cost1, cost2) + + def test_sliced_sphere_bad_shapes(): n = 100 rng = np.random.RandomState(0) @@ -398,28 +418,32 @@ def test_sliced_sphere_backend_type_devices(nx): y = rng.randn(2 * n, 3) y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + sw_np, log = ot.sliced_wasserstein_sphere(x, y, log=True) + P = log["projections"] + for tp in nx.__type_list__: print(nx.dtype_device(tp)) xb, yb = nx.from_numpy(x, y, type_as=tp) - valb = ot.sliced_wasserstein_sphere(xb, yb) + valb = ot.sliced_wasserstein_sphere(xb, yb, projections=nx.from_numpy(P, type_as=tp)) nx.assert_same_dtype_device(xb, valb) + np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) def test_sliced_sphere_gradient(): if torch: import torch.nn.functional as F - X0 = torch.randn((500, 3)) + X0 = torch.randn((20, 3)) X0 = F.normalize(X0, p=2, dim=-1) X0.requires_grad_(True) - X1 = torch.randn((500, 3)) + X1 = torch.randn((20, 3)) X1 = F.normalize(X1, p=2, dim=-1) - sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2) + sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=100, p=2) grad_x0 = torch.autograd.grad(sw, X0)[0] assert not torch.any(torch.isnan(grad_x0)) |