diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/backend.py | 20 | ||||
-rw-r--r-- | ot/lp/solver_1d.py | 14 |
2 files changed, 26 insertions, 8 deletions
diff --git a/ot/backend.py b/ot/backend.py index d3df44c..55e10d3 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -102,6 +102,7 @@ class Backend(): __name__ = None __type__ = None + __type_list__ = None rng_ = None @@ -663,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + __type_list__ = [np.array(1, dtype=np.float32), + np.array(1, dtype=np.float64)] rng_ = np.random.RandomState() @@ -888,12 +891,17 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + __type_list__ = None rng_ = None def __init__(self): self.rng_ = jax.random.PRNGKey(42) + for d in jax.devices(): + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), + jax.device_put(jnp.array(1, dtype=np.float64), d)] + def to_numpy(self, a): return np.array(a) @@ -901,7 +909,7 @@ class JaxBackend(Backend): if type_as is None: return jnp.array(a) else: - return jnp.array(a).astype(type_as.dtype) + return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1130,6 +1138,7 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + __type_list__ = None rng_ = None @@ -1138,6 +1147,13 @@ class TorchBackend(Backend): self.rng_ = torch.Generator() self.rng_.seed() + self.__type_list__ = [torch.tensor(1, dtype=torch.float32), + torch.tensor(1, dtype=torch.float64)] + + if torch.cuda.is_available(): + self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) + self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1160,6 +1176,8 @@ class TorchBackend(Backend): return a.cpu().detach().numpy() def from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return torch.from_numpy(a) else: diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 42554aa..8b4d0c3 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, # ensure that same mass np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), + nx.to_numpy(nx.sum(a, axis=0)), + nx.to_numpy(nx.sum(b, axis=0)), err_msg='a and b vector must have the same sum' ) b = b * nx.sum(a) / nx.sum(b) @@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, perm_b = nx.argsort(x_b_1d) G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), + nx.to_numpy(a[perm_a]).astype(np.float64), + nx.to_numpy(b[perm_b]).astype(np.float64), + nx.to_numpy(x_a_1d[perm_a]).astype(np.float64), + nx.to_numpy(x_b_1d[perm_b]).astype(np.float64), metric=metric, p=p ) @@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, elif str(nx) == "jax": warnings.warn("JAX does not support sparse matrices, converting to dense") if log: - log = {'cost': cost} + log = {'cost': nx.from_numpy(cost, type_as=x_a)} return G, log return G |