diff options
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r-- | test/test_sliced.py | 32 |
1 files changed, 28 insertions, 4 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py index 7b7437a..6d5a27b 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -295,6 +295,26 @@ def test_sliced_sphere_same_dist(): np.testing.assert_almost_equal(res, 0.) +def test_sliced_sphere_same_proj(): + n_projections = 10 + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + seed = 42 + + cost1, log1 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) + cost2, log2 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True) + + assert np.allclose(log1['projections'], log2['projections']) + assert np.isclose(cost1, cost2) + + def test_sliced_sphere_bad_shapes(): n = 100 rng = np.random.RandomState(0) @@ -398,28 +418,32 @@ def test_sliced_sphere_backend_type_devices(nx): y = rng.randn(2 * n, 3) y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + sw_np, log = ot.sliced_wasserstein_sphere(x, y, log=True) + P = log["projections"] + for tp in nx.__type_list__: print(nx.dtype_device(tp)) xb, yb = nx.from_numpy(x, y, type_as=tp) - valb = ot.sliced_wasserstein_sphere(xb, yb) + valb = ot.sliced_wasserstein_sphere(xb, yb, projections=nx.from_numpy(P, type_as=tp)) nx.assert_same_dtype_device(xb, valb) + np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) def test_sliced_sphere_gradient(): if torch: import torch.nn.functional as F - X0 = torch.randn((500, 3)) + X0 = torch.randn((20, 3)) X0 = F.normalize(X0, p=2, dim=-1) X0.requires_grad_(True) - X1 = torch.randn((500, 3)) + X1 = torch.randn((20, 3)) X1 = F.normalize(X1, p=2, dim=-1) - sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2) + sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=100, p=2) grad_x0 = torch.autograd.grad(sw, X0)[0] assert not torch.any(torch.isnan(grad_x0)) |