diff options
author | Clément Bonet <32179275+clbonet@users.noreply.github.com> | 2023-05-05 10:53:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-05 10:53:48 +0200 |
commit | 7e0ea27ad9cad31cfc2181430d837c0a77a61568 (patch) | |
tree | 0a41128a975500bfef52a4c21b5af634adecc71a /ot/backend.py | |
parent | 83dc498b496087aea293df1445442d8728435211 (diff) |
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul
* typo error
* einsum projections ssw
* Test broadcast matmul
* einsum projections ssw
* Test broadcast matmul
* projections SSW with einsum
* reduce number of samples in test wasserstein_circle_unif
* Update releases.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py index eecf9dd..d661c74 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -959,6 +959,14 @@ class Backend(): """ raise NotImplementedError() + def matmul(self, a, b): + r""" + Matrix product of two arrays. + + See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1293,6 +1301,9 @@ class NumpyBackend(Backend): return args[0] return args + def matmul(self, a, b): + return np.matmul(a, b) + class JaxBackend(Backend): """ @@ -1645,6 +1656,9 @@ class JaxBackend(Backend): return jax.lax.stop_gradient((args[0],))[0] return [jax.lax.stop_gradient((a,))[0] for a in args] + def matmul(self, a, b): + return jnp.matmul(a, b) + class TorchBackend(Backend): """ @@ -2098,6 +2112,9 @@ class TorchBackend(Backend): return args[0].detach() return [a.detach() for a in args] + def matmul(self, a, b): + return torch.matmul(a, b) + class CupyBackend(Backend): # pragma: no cover """ @@ -2474,6 +2491,9 @@ class CupyBackend(Backend): # pragma: no cover return args[0] return args + def matmul(self, a, b): + return cp.matmul(a, b) + class TensorflowBackend(Backend): @@ -2865,3 +2885,6 @@ class TensorflowBackend(Backend): if len(args) == 1: return tf.stop_gradient(args[0]) return [tf.stop_gradient(a) for a in args] + + def matmul(self, a, b): + return tnp.matmul(a, b) |