diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2023-04-24 17:54:03 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-24 17:54:03 +0200 |
commit | 03ca4ef659a037e400975e3b2116b637a2d94265 (patch) | |
tree | 2fff6add4b430a9bb97cf594786777c7e48ea5a5 /ot/backend.py | |
parent | 25d72db09ed281c13b97aa8a68d82a4ed5ba7bf0 (diff) |
[MRG] make alpha parameter in FGW diferentiable (#463)
* make alpha diferentiable
* update release file
* debug tensorflow to_numpy
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/ot/backend.py b/ot/backend.py index 74f8366..0dd6fb8 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1694,10 +1694,12 @@ class TorchBackend(Backend): self.ValFunction = ValFunction def _to_numpy(self, a): + if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray): + return np.array(a) return a.cpu().detach().numpy() def _from_numpy(self, a, type_as=None): - if isinstance(a, float): + if isinstance(a, float) or isinstance(a, int): a = np.array(a) if type_as is None: return torch.from_numpy(a) @@ -2501,6 +2503,8 @@ class TensorflowBackend(Backend): ) def _to_numpy(self, a): + if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray): + return np.array(a) return a.numpy() def _from_numpy(self, a, type_as=None): |