summary | refs | log | tree | commit | diff
path: root/ot
diff options
context:
space:
mode:
author: Clément Bonet <32179275+clbonet@users.noreply.github.com> 2023-04-18 18:01:19 +0200
committer: GitHub <noreply@github.com> 2023-04-18 18:01:19 +0200
commit: 9aa96c8247afd6e98d8bd470a6adb1be0f1c467e (patch)
tree: 3f213c8d844d6f24f88c83deebec55f45391e4f9 /ot
parent: 1078dcc3530a7f95fd77d19d115d46f39c2574bc (diff)
[MRG] Fix Bug binary_search_circle on GPU and Gradients (#457)
* W circle + SSW
* Tests + Example SSW_1
* Example Wasserstein Circle + Tests
* Wasserstein on the circle wrt Unif
* Example SSW unif
* pep8
* np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests
* np qr
* rm test python 3.11
* update names, tests, backend transpose
* Comment error batchs
* semidiscrete_wasserstein2_unif_circle example
* torch permute method instead of torch.permute for previous versions
* update comments and doc
* doc wasserstein circle model as [0,1[
* Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
* Bug cuda w_circle + gradient ssw
* Bug cuda w_circle + gradient ssw
* backend detach
* Add PR in Releases.md

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot')
-rw-r--r-- ot/backend.py 33
-rw-r--r-- ot/lp/solver_1d.py 10
-rw-r--r-- ot/sliced.py 2
3 files changed, 39 insertions, 6 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 0779243..74f8366 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -951,6 +951,14 @@ class Backend():
"""
raise NotImplementedError()
+ def detach(self, *args):
+ r"""
+ Detach tensors in arguments from the current graph.
+
+ See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -1279,6 +1287,11 @@ class NumpyBackend(Backend):
def transpose(self, a, axes=None):
return np.transpose(a, axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return args[0]
+ return args
+
class JaxBackend(Backend):
"""
@@ -1626,6 +1639,11 @@ class JaxBackend(Backend):
def transpose(self, a, axes=None):
return jnp.transpose(a, axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return jax.lax.stop_gradient((args[0],))[0]
+ return [jax.lax.stop_gradient((a,))[0] for a in args]
+
class TorchBackend(Backend):
"""
@@ -2072,6 +2090,11 @@ class TorchBackend(Backend):
axes = tuple(range(a.ndim)[::-1])
return a.permute(axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return args[0].detach()
+ return [a.detach() for a in args]
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2443,6 +2466,11 @@ class CupyBackend(Backend): # pragma: no cover
def transpose(self, a, axes=None):
return cp.transpose(a, axes)
+ def detach(self, *args):
+ if len(args) == 1:
+ return args[0]
+ return args
+
class TensorflowBackend(Backend):
@@ -2826,3 +2854,8 @@ class TensorflowBackend(Backend):
def transpose(self, a, axes=None):
return tf.transpose(a, perm=axes)
+
+ def detach(self, *args):
+ if len(args) == 1:
+ return tf.stop_gradient(args[0])
+ return [tf.stop_gradient(a) for a in args]
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index bcfc920..840801a 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -401,7 +401,7 @@ def roll_cols(M, shifts):
n_rows, n_cols = M.shape
- arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1))
+ arange1 = nx.tile(nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1))
arange2 = (arange1 - shifts) % n_cols
return nx.take_along_axis(M, arange2, 1)
@@ -600,7 +600,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1
using e.g. ot.utils.get_coordinate_circle(x)
- The function runs on backend but tensorflow is not supported.
+ The function runs on backend but tensorflow and jax are not supported.
Parameters
----------
@@ -730,7 +730,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1
tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2
- w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+ w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p)
if log:
return w, {"optimal_theta": tc[:, 0]}
@@ -743,7 +743,7 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ
takes the value modulo 1.
If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
using e.g. the atan2 function.
- The function runs on backend but tensorflow is not supported.
+ The function runs on backend but tensorflow and jax are not supported.
.. math::
W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
@@ -864,7 +864,7 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
using e.g. ot.utils.get_coordinate_circle(x)
- The function runs on backend but tensorflow is not supported.
+ The function runs on backend but tensorflow and jax are not supported.
Parameters
----------
diff --git a/ot/sliced.py b/ot/sliced.py
index 077ff0b..fa2141e 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -271,7 +271,7 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
- :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}`
- The function runs on backend but tensorflow is not supported.
+ The function runs on backend but tensorflow and jax are not supported.
Parameters
----------