summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py
index eecf9dd..d661c74 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -959,6 +959,14 @@ class Backend():
"""
raise NotImplementedError()
+ def matmul(self, a, b):
+ r"""
+ Matrix product of two arrays.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -1293,6 +1301,9 @@ class NumpyBackend(Backend):
return args[0]
return args
+ def matmul(self, a, b):
+ return np.matmul(a, b)
+
class JaxBackend(Backend):
"""
@@ -1645,6 +1656,9 @@ class JaxBackend(Backend):
return jax.lax.stop_gradient((args[0],))[0]
return [jax.lax.stop_gradient((a,))[0] for a in args]
+ def matmul(self, a, b):
+ return jnp.matmul(a, b)
+
class TorchBackend(Backend):
"""
@@ -2098,6 +2112,9 @@ class TorchBackend(Backend):
return args[0].detach()
return [a.detach() for a in args]
+ def matmul(self, a, b):
+ return torch.matmul(a, b)
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2474,6 +2491,9 @@ class CupyBackend(Backend): # pragma: no cover
return args[0]
return args
+ def matmul(self, a, b):
+ return cp.matmul(a, b)
+
class TensorflowBackend(Backend):
@@ -2865,3 +2885,6 @@ class TensorflowBackend(Backend):
if len(args) == 1:
return tf.stop_gradient(args[0])
return [tf.stop_gradient(a) for a in args]
+
+ def matmul(self, a, b):
+ return tnp.matmul(a, b)