From 9aa96c8247afd6e98d8bd470a6adb1be0f1c467e Mon Sep 17 00:00:00 2001 From: Clément Bonet <32179275+clbonet@users.noreply.github.com> Date: Tue, 18 Apr 2023 18:01:19 +0200 Subject: [MRG] Fix Bug binary_search_circle on GPU and Gradients (#457) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn * Bug cuda w_circle + gradient ssw * Bug cuda w_circle + gradient ssw * backend detach * Add PR in Releases.md --------- Co-authored-by: Rémi Flamary --- RELEASES.md | 1 + ot/backend.py | 33 +++++++++++++++++++++++++++++++++ ot/lp/solver_1d.py | 10 +++++----- ot/sliced.py | 2 +- test/test_backend.py | 12 ++++++++++++ test/test_sliced.py | 19 ++++++++++++++++++- 6 files changed, 70 insertions(+), 7 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 6e4188f..214cc2a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,7 @@ #### Closed issues - Fix circleci-redirector action and codecov (PR #460) +- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457) ## 0.9.0 diff --git a/ot/backend.py b/ot/backend.py index 0779243..74f8366 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -951,6 +951,14 @@ class Backend(): """ raise NotImplementedError() + def detach(self, *args): + r""" + Detach tensors in arguments from the current graph. + + See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1279,6 +1287,11 @@ class NumpyBackend(Backend): def transpose(self, a, axes=None): return np.transpose(a, axes) + def detach(self, *args): + if len(args) == 1: + return args[0] + return args + class JaxBackend(Backend): """ @@ -1626,6 +1639,11 @@ class JaxBackend(Backend): def transpose(self, a, axes=None): return jnp.transpose(a, axes) + def detach(self, *args): + if len(args) == 1: + return jax.lax.stop_gradient((args[0],))[0] + return [jax.lax.stop_gradient((a,))[0] for a in args] + class TorchBackend(Backend): """ @@ -2072,6 +2090,11 @@ class TorchBackend(Backend): axes = tuple(range(a.ndim)[::-1]) return a.permute(axes) + def detach(self, *args): + if len(args) == 1: + return args[0].detach() + return [a.detach() for a in args] + class CupyBackend(Backend): # pragma: no cover """ @@ -2443,6 +2466,11 @@ class CupyBackend(Backend): # pragma: no cover def transpose(self, a, axes=None): return cp.transpose(a, axes) + def detach(self, *args): + if len(args) == 1: + return args[0] + return args + class TensorflowBackend(Backend): @@ -2826,3 +2854,8 @@ class TensorflowBackend(Backend): def transpose(self, a, axes=None): return tf.transpose(a, perm=axes) + + def detach(self, *args): + if len(args) == 1: + return tf.stop_gradient(args[0]) + return [tf.stop_gradient(a) for a in args] diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index bcfc920..840801a 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -401,7 +401,7 @@ def roll_cols(M, shifts): n_rows, n_cols = M.shape - arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1)) + arange1 = nx.tile(nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1)) arange2 = (arange1 - shifts) % n_cols return nx.take_along_axis(M, arange2, 1) @@ -600,7 +600,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 using e.g. ot.utils.get_coordinate_circle(x) - The function runs on backend but tensorflow is not supported. + The function runs on backend but tensorflow and jax are not supported. Parameters ---------- @@ -730,7 +730,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 - w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p) if log: return w, {"optimal_theta": tc[:, 0]} @@ -743,7 +743,7 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ takes the value modulo 1. If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates using e.g. the atan2 function. - The function runs on backend but tensorflow is not supported. + The function runs on backend but tensorflow and jax are not supported. .. math:: W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t @@ -864,7 +864,7 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, using e.g. ot.utils.get_coordinate_circle(x) - The function runs on backend but tensorflow is not supported. + The function runs on backend but tensorflow and jax are not supported. Parameters ---------- diff --git a/ot/sliced.py b/ot/sliced.py index 077ff0b..fa2141e 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -271,7 +271,7 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` - The function runs on backend but tensorflow is not supported. + The function runs on backend but tensorflow and jax are not supported. Parameters ---------- diff --git a/test/test_backend.py b/test/test_backend.py index fd9a761..5351e52 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -296,6 +296,8 @@ def test_empty_backend(): nx.atan2(v, v) with pytest.raises(NotImplementedError): nx.transpose(M) + with pytest.raises(NotImplementedError): + nx.detach(M) def test_func_backends(nx): @@ -649,6 +651,16 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("transpose") + A = nx.detach(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("detach") + + A, B = nx.detach(Mb, Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("detach A") + lst_b.append(nx.to_numpy(B)) + lst_name.append("detach B") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_sliced.py b/test/test_sliced.py index f54c799..7b7437a 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -10,7 +10,7 @@ import pytest import ot from ot.sliced import get_random_projections -from ot.backend import tf +from ot.backend import tf, torch def test_get_random_projections(): @@ -408,6 +408,23 @@ def test_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) +def test_sliced_sphere_gradient(): + if torch: + import torch.nn.functional as F + + X0 = torch.randn((500, 3)) + X0 = F.normalize(X0, p=2, dim=-1) + X0.requires_grad_(True) + + X1 = torch.randn((500, 3)) + X1 = F.normalize(X1, p=2, dim=-1) + + sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2) + grad_x0 = torch.autograd.grad(sw, X0)[0] + + assert not torch.any(torch.isnan(grad_x0)) + + def test_sliced_sphere_unif_values_on_the_sphere(): n = 100 rng = np.random.RandomState(0) -- cgit v1.2.3