summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 74f8366..0dd6fb8 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1694,10 +1694,12 @@ class TorchBackend(Backend):
self.ValFunction = ValFunction
def _to_numpy(self, a):
+ if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+ return np.array(a)
return a.cpu().detach().numpy()
def _from_numpy(self, a, type_as=None):
- if isinstance(a, float):
+ if isinstance(a, float) or isinstance(a, int):
a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
@@ -2501,6 +2503,8 @@ class TensorflowBackend(Backend):
)
def _to_numpy(self, a):
+ if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+ return np.array(a)
return a.numpy()
def _from_numpy(self, a, type_as=None):