diff options
author | Clément Bonet <32179275+clbonet@users.noreply.github.com> | 2023-04-18 18:01:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-18 18:01:19 +0200 |
commit | 9aa96c8247afd6e98d8bd470a6adb1be0f1c467e (patch) | |
tree | 3f213c8d844d6f24f88c83deebec55f45391e4f9 /test/test_sliced.py | |
parent | 1078dcc3530a7f95fd77d19d115d46f39c2574bc (diff) |
[MRG] Fix Bug binary_search_circle on GPU and Gradients (#457)
* W circle + SSW
* Tests + Example SSW_1
* Example Wasserstein Circle + Tests
* Wasserstein on the circle wrt Unif
* Example SSW unif
* pep8
* np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests
* np qr
* rm test python 3.11
* update names, tests, backend transpose
* Comment error batchs
* semidiscrete_wasserstein2_unif_circle example
* torch permute method instead of torch.permute for previous versions
* update comments and doc
* doc wasserstein circle model as [0,1[
* Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
* Bug cuda w_circle + gradient ssw
* Bug cuda w_circle + gradient ssw
* backend detach
* Add PR in Releases.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r-- | test/test_sliced.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py index f54c799..7b7437a 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -10,7 +10,7 @@ import pytest import ot from ot.sliced import get_random_projections -from ot.backend import tf +from ot.backend import tf, torch def test_get_random_projections(): @@ -408,6 +408,23 @@ def test_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) +def test_sliced_sphere_gradient(): + if torch: + import torch.nn.functional as F + + X0 = torch.randn((500, 3)) + X0 = F.normalize(X0, p=2, dim=-1) + X0.requires_grad_(True) + + X1 = torch.randn((500, 3)) + X1 = F.normalize(X1, p=2, dim=-1) + + sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2) + grad_x0 = torch.autograd.grad(sw, X0)[0] + + assert not torch.any(torch.isnan(grad_x0)) + + def test_sliced_sphere_unif_values_on_the_sphere(): n = 100 rng = np.random.RandomState(0) |