summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2023-04-24 17:54:03 +0200
committerGitHub <noreply@github.com>2023-04-24 17:54:03 +0200
commit03ca4ef659a037e400975e3b2116b637a2d94265 (patch)
tree2fff6add4b430a9bb97cf594786777c7e48ea5a5 /ot/backend.py
parent25d72db09ed281c13b97aa8a68d82a4ed5ba7bf0 (diff)
[MRG] make alpha parameter in FGW diferentiable (#463)
* make alpha diferentiable * update release file * debug tensorflow to_numpy
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 74f8366..0dd6fb8 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1694,10 +1694,12 @@ class TorchBackend(Backend):
self.ValFunction = ValFunction
def _to_numpy(self, a):
+ if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+ return np.array(a)
return a.cpu().detach().numpy()
def _from_numpy(self, a, type_as=None):
- if isinstance(a, float):
+ if isinstance(a, float) or isinstance(a, int):
a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
@@ -2501,6 +2503,8 @@ class TensorflowBackend(Backend):
)
def _to_numpy(self, a):
+ if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+ return np.array(a)
return a.numpy()
def _from_numpy(self, a, type_as=None):