diff options
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r-- | test/test_sliced.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py index f54c799..7b7437a 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -10,7 +10,7 @@ import pytest import ot from ot.sliced import get_random_projections -from ot.backend import tf +from ot.backend import tf, torch def test_get_random_projections(): @@ -408,6 +408,23 @@ def test_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) +def test_sliced_sphere_gradient(): + if torch: + import torch.nn.functional as F + + X0 = torch.randn((500, 3)) + X0 = F.normalize(X0, p=2, dim=-1) + X0.requires_grad_(True) + + X1 = torch.randn((500, 3)) + X1 = F.normalize(X1, p=2, dim=-1) + + sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2) + grad_x0 = torch.autograd.grad(sw, X0)[0] + + assert not torch.any(torch.isnan(grad_x0)) + + def test_sliced_sphere_unif_values_on_the_sphere(): n = 100 rng = np.random.RandomState(0) |