Diffstat (limited to 'ot')
-rw-r--r--  ot/backend.py       20
-rw-r--r--  ot/lp/solver_1d.py  14
2 files changed, 26 insertions, 8 deletions
diff --git a/ot/backend.py b/ot/backend.py
index d3df44c..55e10d3 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -102,6 +102,7 @@ class Backend():
__name__ = None
__type__ = None
+ __type_list__ = None
rng_ = None
@@ -663,6 +664,8 @@ class NumpyBackend(Backend):
__name__ = 'numpy'
__type__ = np.ndarray
+ __type_list__ = [np.array(1, dtype=np.float32),
+ np.array(1, dtype=np.float64)]
rng_ = np.random.RandomState()
@@ -888,12 +891,17 @@ class JaxBackend(Backend):
__name__ = 'jax'
__type__ = jax_type
+ __type_list__ = None
rng_ = None
def __init__(self):
self.rng_ = jax.random.PRNGKey(42)
+ for d in jax.devices():
+ self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
+ jax.device_put(jnp.array(1, dtype=np.float64), d)]
+
def to_numpy(self, a):
return np.array(a)
@@ -901,7 +909,7 @@ class JaxBackend(Backend):
if type_as is None:
return jnp.array(a)
else:
- return jnp.array(a).astype(type_as.dtype)
+ return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
def set_gradients(self, val, inputs, grads):
from jax.flatten_util import ravel_pytree
@@ -1130,6 +1138,7 @@ class TorchBackend(Backend):
__name__ = 'torch'
__type__ = torch_type
+ __type_list__ = None
rng_ = None
@@ -1138,6 +1147,13 @@ class TorchBackend(Backend):
self.rng_ = torch.Generator()
self.rng_.seed()
+ self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
+ torch.tensor(1, dtype=torch.float64)]
+
+ if torch.cuda.is_available():
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+
from torch.autograd import Function
# define a function that takes inputs val and grads
@@ -1160,6 +1176,8 @@ class TorchBackend(Backend):
return a.cpu().detach().numpy()
def from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
else:
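
Note on the backend changes (illustration only, not part of the patch): the new __type_list__ attribute gives each backend one reference array per supported dtype, and per device for the torch and jax backends, while from_numpy(..., type_as=...) now also matches the reference array's device. A minimal sketch of how this can be used, assuming the get_backend_list() helper exposed by ot.backend:

    import numpy as np
    from ot.backend import get_backend_list

    for nx in get_backend_list():
        if nx.__type_list__ is None:
            continue
        for tp in nx.__type_list__:
            # convert with the same dtype (and, for torch/jax, device) as tp
            a = nx.from_numpy(np.array([0.5, 0.5]), type_as=tp)
            assert a.dtype == tp.dtype
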
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 42554aa..8b4d0c3 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
# ensure that same mass
np.testing.assert_almost_equal(
- nx.sum(a, axis=0),
- nx.sum(b, axis=0),
+ nx.to_numpy(nx.sum(a, axis=0)),
+ nx.to_numpy(nx.sum(b, axis=0)),
err_msg='a and b vector must have the same sum'
)
b = b * nx.sum(a) / nx.sum(b)
@@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
perm_b = nx.argsort(x_b_1d)
G_sorted, indices, cost = emd_1d_sorted(
- nx.to_numpy(a[perm_a]),
- nx.to_numpy(b[perm_b]),
- nx.to_numpy(x_a_1d[perm_a]),
- nx.to_numpy(x_b_1d[perm_b]),
+ nx.to_numpy(a[perm_a]).astype(np.float64),
+ nx.to_numpy(b[perm_b]).astype(np.float64),
+ nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+ nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
metric=metric, p=p
)
@@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
elif str(nx) == "jax":
warnings.warn("JAX does not support sparse matrices, converting to dense")
if log:
- log = {'cost': cost}
+ log = {'cost': nx.from_numpy(cost, type_as=x_a)}
return G, log
return G
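
For the solver_1d.py changes, the practical effect is that the arrays handed to the Cython kernel emd_1d_sorted are always float64, and emd_1d with log=True returns the cost in the caller's backend, dtype and device instead of a plain NumPy value. An illustrative call (a sketch assuming a PyTorch install, not taken from the patch):

    import torch
    import ot

    x_a = torch.tensor([0.0, 1.0, 2.0], dtype=torch.float32)
    x_b = torch.tensor([0.5, 1.5, 2.5], dtype=torch.float32)

    G, log = ot.emd_1d(x_a, x_b, log=True)
    # log['cost'] now follows x_a (torch tensor, float32) rather than
    # staying a raw NumPy scalar
    print(type(G), type(log['cost']))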