summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-04-18 18:01:19 +0200
committerGitHub <noreply@github.com>2023-04-18 18:01:19 +0200
commit9aa96c8247afd6e98d8bd470a6adb1be0f1c467e (patch)
tree3f213c8d844d6f24f88c83deebec55f45391e4f9 /test
parent1078dcc3530a7f95fd77d19d115d46f39c2574bc (diff)
[MRG] Fix Bug binary_search_circle on GPU and Gradients (#457)
* W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn * Bug cuda w_circle + gradient ssw * Bug cuda w_circle + gradient ssw * backend detach * Add PR in Releases.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_backend.py12
-rw-r--r--test/test_sliced.py19
2 files changed, 30 insertions, 1 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index fd9a761..5351e52 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -296,6 +296,8 @@ def test_empty_backend():
nx.atan2(v, v)
with pytest.raises(NotImplementedError):
nx.transpose(M)
+ with pytest.raises(NotImplementedError):
+ nx.detach(M)
def test_func_backends(nx):
@@ -649,6 +651,16 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("transpose")
+ A = nx.detach(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("detach")
+
+ A, B = nx.detach(Mb, Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("detach A")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("detach B")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
diff --git a/test/test_sliced.py b/test/test_sliced.py
index f54c799..7b7437a 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -10,7 +10,7 @@ import pytest
import ot
from ot.sliced import get_random_projections
-from ot.backend import tf
+from ot.backend import tf, torch
def test_get_random_projections():
@@ -408,6 +408,23 @@ def test_sliced_sphere_backend_type_devices(nx):
nx.assert_same_dtype_device(xb, valb)
+def test_sliced_sphere_gradient():
+ if torch:
+ import torch.nn.functional as F
+
+ X0 = torch.randn((500, 3))
+ X0 = F.normalize(X0, p=2, dim=-1)
+ X0.requires_grad_(True)
+
+ X1 = torch.randn((500, 3))
+ X1 = F.normalize(X1, p=2, dim=-1)
+
+ sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2)
+ grad_x0 = torch.autograd.grad(sw, X0)[0]
+
+ assert not torch.any(torch.isnan(grad_x0))
+
+
def test_sliced_sphere_unif_values_on_the_sphere():
n = 100
rng = np.random.RandomState(0)