diff options
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/ot/backend.py b/ot/backend.py index 74f8366..0dd6fb8 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1694,10 +1694,12 @@ class TorchBackend(Backend): self.ValFunction = ValFunction def _to_numpy(self, a): + if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray): + return np.array(a) return a.cpu().detach().numpy() def _from_numpy(self, a, type_as=None): - if isinstance(a, float): + if isinstance(a, float) or isinstance(a, int): a = np.array(a) if type_as is None: return torch.from_numpy(a) @@ -2501,6 +2503,8 @@ class TensorflowBackend(Backend): ) def _to_numpy(self, a): + if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray): + return np.array(a) return a.numpy() def _from_numpy(self, a, type_as=None): |