summaryrefslogtreecommitdiff
path: root/test/test_sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r--test/test_sliced.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index f54c799..7b7437a 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -10,7 +10,7 @@ import pytest
import ot
from ot.sliced import get_random_projections
-from ot.backend import tf
+from ot.backend import tf, torch
def test_get_random_projections():
@@ -408,6 +408,23 @@ def test_sliced_sphere_backend_type_devices(nx):
nx.assert_same_dtype_device(xb, valb)
+def test_sliced_sphere_gradient():
+ if torch:
+ import torch.nn.functional as F
+
+ X0 = torch.randn((500, 3))
+ X0 = F.normalize(X0, p=2, dim=-1)
+ X0.requires_grad_(True)
+
+ X1 = torch.randn((500, 3))
+ X1 = F.normalize(X1, p=2, dim=-1)
+
+ sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2)
+ grad_x0 = torch.autograd.grad(sw, X0)[0]
+
+ assert not torch.any(torch.isnan(grad_x0))
+
+
def test_sliced_sphere_unif_values_on_the_sphere():
n = 100
rng = np.random.RandomState(0)