author    Rémi Flamary <remi.flamary@gmail.com>    2021-11-04 15:19:57 +0100
committer GitHub <noreply@github.com>              2021-11-04 15:19:57 +0100
commit    0e431c203a66c6d48e6bb1efeda149460472a0f0 (patch)
tree      22a447a1dbb1505b18f9e426e1761cf6b328b6eb
parent    2fe69eb130827560ada704bc25998397c4357821 (diff)
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu
* pep 8 of course
* debug torch
* jax with gpu
* device put
* it works
* emd_1d and emd2_1d working
* emd_1d and emd2_1d done
* cleanup
* of course
* should work on gpu now
* tests done + pep8
-rw-r--r--   ot/backend.py            20
-rw-r--r--   ot/lp/solver_1d.py       14
-rw-r--r--   test/test_1d_solver.py   93
-rw-r--r--   test/test_ot.py          67
4 files changed, 146 insertions, 48 deletions
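
The pattern these tests build on: every backend now exposes a __type_list__ holding one scalar array per supported dtype/device combination, and from_numpy(..., type_as=tp) casts and places inputs to match a reference array tp. A minimal sketch of the resulting test loop, assuming a POT install with these changes (get_backend_list and wasserstein_1d are existing POT helpers; the data is illustrative):

    import numpy as np
    from ot.backend import get_backend_list
    from ot.lp import wasserstein_1d

    rng = np.random.RandomState(0)
    x = np.linspace(0, 5, 10)
    rho = np.abs(rng.randn(10))
    rho /= rho.sum()

    for nx in get_backend_list():    # numpy, jax, torch, ... whichever are installed
        for tp in nx.__type_list__:  # one reference array per dtype/device pair
            xb = nx.from_numpy(x, type_as=tp)      # cast to tp.dtype, on tp's device
            rhob = nx.from_numpy(rho, type_as=tp)
            res = wasserstein_1d(xb, xb, rhob, rhob, p=1)
            if str(nx) != 'numpy':   # the numpy backend may return a plain scalar
                assert res.dtype == xb.dtype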
diff --git a/ot/backend.py b/ot/backend.py
index d3df44c..55e10d3 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -102,6 +102,7 @@ class Backend():
__name__ = None
__type__ = None
+ __type_list__ = None
rng_ = None
@@ -663,6 +664,8 @@ class NumpyBackend(Backend):
__name__ = 'numpy'
__type__ = np.ndarray
+ __type_list__ = [np.array(1, dtype=np.float32),
+ np.array(1, dtype=np.float64)]
rng_ = np.random.RandomState()
@@ -888,12 +891,17 @@ class JaxBackend(Backend):
__name__ = 'jax'
__type__ = jax_type
+ __type_list__ = None
rng_ = None
def __init__(self):
self.rng_ = jax.random.PRNGKey(42)
+ for d in jax.devices():
+ self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
+ jax.device_put(jnp.array(1, dtype=np.float64), d)]
+
def to_numpy(self, a):
return np.array(a)
@@ -901,7 +909,7 @@ class JaxBackend(Backend):
if type_as is None:
return jnp.array(a)
else:
- return jnp.array(a).astype(type_as.dtype)
+ return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
def set_gradients(self, val, inputs, grads):
from jax.flatten_util import ravel_pytree
@@ -1130,6 +1138,7 @@ class TorchBackend(Backend):
__name__ = 'torch'
__type__ = torch_type
+ __type_list__ = None
rng_ = None
@@ -1138,6 +1147,13 @@ class TorchBackend(Backend):
self.rng_ = torch.Generator()
self.rng_.seed()
+ self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
+ torch.tensor(1, dtype=torch.float64)]
+
+ if torch.cuda.is_available():
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+
from torch.autograd import Function
# define a function that takes inputs val and grads
@@ -1160,6 +1176,8 @@ class TorchBackend(Backend):
return a.cpu().detach().numpy()
def from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
else:
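
A note on the constructors above: NumpyBackend and TorchBackend build __type_list__ eagerly (CUDA entries are appended only when torch.cuda.is_available()), while the JaxBackend constructor reassigns the list on each pass through jax.devices(), so only the last device's tensors survive (on a GPU host that is typically the GPU). A short sketch of what the torch side exposes, assuming torch is installed; whether from_numpy also matches the device depends on the type_as branch elided in the hunk above, so that check is left as a comment:

    import numpy as np
    from ot.backend import TorchBackend

    nx = TorchBackend()
    print([(t.dtype, t.device) for t in nx.__type_list__])
    # e.g. [(torch.float32, cpu), (torch.float64, cpu)], plus two cuda entries on a GPU box

    ref = nx.__type_list__[-1]                  # float64, on cuda when a GPU is visible
    b = nx.from_numpy(np.arange(4.0), type_as=ref)
    assert b.dtype == ref.dtype
    # per this PR's intent ("should work on gpu now"), b should also land on ref's device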
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 42554aa..8b4d0c3 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
# ensure that same mass
np.testing.assert_almost_equal(
- nx.sum(a, axis=0),
- nx.sum(b, axis=0),
+ nx.to_numpy(nx.sum(a, axis=0)),
+ nx.to_numpy(nx.sum(b, axis=0)),
err_msg='a and b vector must have the same sum'
)
b = b * nx.sum(a) / nx.sum(b)
@@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
perm_b = nx.argsort(x_b_1d)
G_sorted, indices, cost = emd_1d_sorted(
- nx.to_numpy(a[perm_a]),
- nx.to_numpy(b[perm_b]),
- nx.to_numpy(x_a_1d[perm_a]),
- nx.to_numpy(x_b_1d[perm_b]),
+ nx.to_numpy(a[perm_a]).astype(np.float64),
+ nx.to_numpy(b[perm_b]).astype(np.float64),
+ nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+ nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
metric=metric, p=p
)
@@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
elif str(nx) == "jax":
warnings.warn("JAX does not support sparse matrices, converting to dense")
if log:
- log = {'cost': cost}
+ log = {'cost': nx.from_numpy(cost, type_as=x_a)}
return G, log
return G
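
With the change above, emd_1d keeps running its core emd_1d_sorted solver on float64 NumPy arrays whatever the input backend, then converts the logged cost back with nx.from_numpy(cost, type_as=x_a). A minimal round-trip check, assuming torch is installed (it mirrors test_emd1d_type_devices below):

    import torch
    import ot

    n = 10
    x = torch.linspace(0., 5., n, dtype=torch.float32)
    a = torch.ones(n, dtype=torch.float32) / n

    G, log = ot.emd_1d(x, x, a, a, metric='sqeuclidean', log=True)
    assert G.dtype == torch.float32             # plan comes back in the input dtype
    assert log['cost'].dtype == torch.float32   # cost converted from the float64 solver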
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 2c470c2..77b1234 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -83,3 +83,96 @@ def test_wasserstein_1d(nx):
Xb = nx.from_numpy(X)
res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+ if not str(nx) == 'numpy':
+ assert res.dtype == xb.dtype
+
+
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+ assert emd.dtype == xb.dtype
+ if not str(nx) == 'numpy':
+ assert emd2.dtype == xb.dtype
diff --git a/test/test_ot.py b/test/test_ot.py
index 5bfde1d..dc3930a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,7 +12,6 @@ import pytest
import ot
from ot.datasets import make_1D_gauss as gauss
from ot.backend import torch
-from scipy.stats import wasserstein_distance
def test_emd_dimension_and_mass_mismatch():
@@ -77,6 +76,33 @@ def test_emd2_backends(nx):
np.allclose(val, nx.to_numpy(valb))
+def test_emd_emd2_types_devices(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ ab = nx.from_numpy(a, type_as=tp)
+ Mb = nx.from_numpy(M, type_as=tp)
+
+ Gb = ot.emd(ab, ab, Mb)
+
+ w = ot.emd2(ab, ab, Mb)
+
+ assert Gb.dtype == Mb.dtype
+ if not str(nx) == 'numpy':
+ assert w.dtype == Mb.dtype
+
+
def test_emd2_gradients():
n_samples = 100
n_features = 2
@@ -126,45 +152,6 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)
-def test_emd_1d_emd2_1d():
- # test emd1d gives similar results as emd
- n = 20
- m = 30
- rng = np.random.RandomState(0)
- u = rng.randn(n, 1)
- v = rng.randn(m, 1)
-
- M = ot.dist(u, v, metric='sqeuclidean')
-
- G, log = ot.emd([], [], M, log=True)
- wass = log["cost"]
- G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
- wass1d = log["cost"]
- wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
- wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
-
- # check loss is similar
- np.testing.assert_allclose(wass, wass1d)
- np.testing.assert_allclose(wass, wass1d_emd2)
-
- # check loss is similar to scipy's implementation for Euclidean metric
- wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
- np.testing.assert_allclose(wass_sp, wass1d_euc)
-
- # check constraints
- np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
- np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
-
- # check G is similar
- np.testing.assert_allclose(G, G_1d, atol=1e-15)
-
- # check AssertionError is raised if called on non 1d arrays
- u = np.random.randn(n, 2)
- v = np.random.randn(m, 2)
- with pytest.raises(AssertionError):
- ot.emd_1d(u, v, [], [])
-
-
def test_emd_empty():
# test emd and emd2 for simple identity
n = 100