author | Rémi Flamary <remi.flamary@gmail.com> | 2021-11-04 15:19:57 +0100
---|---|---
committer | GitHub <noreply@github.com> | 2021-11-04 15:19:57 +0100
commit | 0e431c203a66c6d48e6bb1efeda149460472a0f0 (patch) |
tree | 22a447a1dbb1505b18f9e426e1761cf6b328b6eb |
parent | 2fe69eb130827560ada704bc25998397c4357821 (diff) |
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu
* pep 8 of course
* debug torch
* jax with gpu
* device put
* device put
* it works
* emd1d and emd2_1d working
* emd_1d and emd2_1d done
* cleanup
* of course
* should work on gpu now
* tests done + pep8
-rw-r--r-- | ot/backend.py | 20
-rw-r--r-- | ot/lp/solver_1d.py | 14
-rw-r--r-- | test/test_1d_solver.py | 93
-rw-r--r-- | test/test_ot.py | 67
4 files changed, 146 insertions, 48 deletions
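The heart of this change is the new `__type_list__` attribute: each backend exposes a list of scalar arrays covering its supported dtypes (and, for Torch and JAX, devices), and the tests iterate over that list with `from_numpy(..., type_as=tp)` to cast and place inputs. A minimal sketch of that pattern, assuming `ot.backend.get_backend_list()` from POT's backend module (it is not shown in this diff):

```python
import numpy as np
import ot
from ot.backend import get_backend_list

rng = np.random.RandomState(0)
a = ot.utils.unif(5)                       # uniform weights (float64 NumPy array)
M = np.abs(rng.randn(5, 5))                # toy cost matrix

for nx in get_backend_list():              # NumpyBackend, plus Torch/JAX if installed
    for tp in nx.__type_list__:            # one scalar per supported dtype/device
        ab = nx.from_numpy(a, type_as=tp)  # cast and move to tp's dtype/device
        Mb = nx.from_numpy(M, type_as=tp)
        G = ot.emd(ab, ab, Mb)
        assert G.dtype == Mb.dtype         # solver must hand back the input dtype
```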
```diff
diff --git a/ot/backend.py b/ot/backend.py
index d3df44c..55e10d3 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -102,6 +102,7 @@ class Backend():
 
     __name__ = None
     __type__ = None
+    __type_list__ = None
 
     rng_ = None
 
@@ -663,6 +664,8 @@ class NumpyBackend(Backend):
 
     __name__ = 'numpy'
     __type__ = np.ndarray
+    __type_list__ = [np.array(1, dtype=np.float32),
+                     np.array(1, dtype=np.float64)]
 
     rng_ = np.random.RandomState()
 
@@ -888,12 +891,17 @@ class JaxBackend(Backend):
 
     __name__ = 'jax'
     __type__ = jax_type
+    __type_list__ = None
 
     rng_ = None
 
     def __init__(self):
         self.rng_ = jax.random.PRNGKey(42)
 
+        for d in jax.devices():
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+
     def to_numpy(self, a):
         return np.array(a)
 
@@ -901,7 +909,7 @@ class JaxBackend(Backend):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jnp.array(a).astype(type_as.dtype)
+            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())
 
     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree
@@ -1130,6 +1138,7 @@ class TorchBackend(Backend):
 
     __name__ = 'torch'
     __type__ = torch_type
+    __type_list__ = None
 
     rng_ = None
 
@@ -1138,6 +1147,13 @@ class TorchBackend(Backend):
         self.rng_ = torch.Generator()
         self.rng_.seed()
 
+        self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
+                              torch.tensor(1, dtype=torch.float64)]
+
+        if torch.cuda.is_available():
+            self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
+            self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+
         from torch.autograd import Function
 
         # define a function that takes inputs val and grads
@@ -1160,6 +1176,8 @@ class TorchBackend(Backend):
         return a.cpu().detach().numpy()
 
     def from_numpy(self, a, type_as=None):
+        if isinstance(a, float):
+            a = np.array(a)
         if type_as is None:
             return torch.from_numpy(a)
         else:
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 42554aa..8b4d0c3 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
 
     # ensure that same mass
     np.testing.assert_almost_equal(
-        nx.sum(a, axis=0),
-        nx.sum(b, axis=0),
+        nx.to_numpy(nx.sum(a, axis=0)),
+        nx.to_numpy(nx.sum(b, axis=0)),
         err_msg='a and b vector must have the same sum'
     )
     b = b * nx.sum(a) / nx.sum(b)
@@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     perm_b = nx.argsort(x_b_1d)
 
     G_sorted, indices, cost = emd_1d_sorted(
-        nx.to_numpy(a[perm_a]),
-        nx.to_numpy(b[perm_b]),
-        nx.to_numpy(x_a_1d[perm_a]),
-        nx.to_numpy(x_b_1d[perm_b]),
+        nx.to_numpy(a[perm_a]).astype(np.float64),
+        nx.to_numpy(b[perm_b]).astype(np.float64),
+        nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+        nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
         metric=metric, p=p
     )
@@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     elif str(nx) == "jax":
         warnings.warn("JAX does not support sparse matrices, converting to dense")
     if log:
-        log = {'cost': cost}
+        log = {'cost': nx.from_numpy(cost, type_as=x_a)}
         return G, log
     return G
```
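The `solver_1d.py` hunks exist because the compiled Cython kernel `emd_1d_sorted` only accepts float64 NumPy arrays: inputs are converted with `nx.to_numpy(...).astype(np.float64)` before the call, and the returned cost is wrapped back into the caller's backend and dtype via `nx.from_numpy(cost, type_as=x_a)`. A sketch of the resulting round-trip from the user's side, assuming PyTorch is installed:

```python
import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
x = np.linspace(0, 5, 10)
u = np.abs(rng.randn(10))
u /= u.sum()

xb = torch.tensor(x, dtype=torch.float32)   # float32 torch inputs...
ub = torch.tensor(u, dtype=torch.float32)

G, log = ot.emd_1d(xb, xb, ub, ub, log=True)
assert G.dtype == torch.float32             # ...give a float32 plan back
assert log['cost'].dtype == torch.float32   # logged cost is now a tensor too
```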
```diff
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 2c470c2..77b1234 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -83,3 +83,96 @@ def test_wasserstein_1d(nx):
     Xb = nx.from_numpy(X)
     res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
     np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d_type_devices(nx):
+
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+
+        print(tp.dtype)
+
+        xb = nx.from_numpy(x, type_as=tp)
+        rho_ub = nx.from_numpy(rho_u, type_as=tp)
+        rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+        res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+        if not str(nx) == 'numpy':
+            assert res.dtype == xb.dtype
+
+
+def test_emd_1d_emd2_1d():
+    # test emd1d gives similar results as emd
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.randn(n, 1)
+    v = rng.randn(m, 1)
+
+    M = ot.dist(u, v, metric='sqeuclidean')
+
+    G, log = ot.emd([], [], M, log=True)
+    wass = log["cost"]
+    G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+    wass1d = log["cost"]
+    wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+    wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass, wass1d)
+    np.testing.assert_allclose(wass, wass1d_emd2)
+
+    # check loss is similar to scipy's implementation for Euclidean metric
+    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+    np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+    # check constraints
+    np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+    np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+    # check G is similar
+    np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+    # check AssertionError is raised if called on non 1d arrays
+    u = np.random.randn(n, 2)
+    v = np.random.randn(m, 2)
+    with pytest.raises(AssertionError):
+        ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+
+        print(tp.dtype)
+
+        xb = nx.from_numpy(x, type_as=tp)
+        rho_ub = nx.from_numpy(rho_u, type_as=tp)
+        rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+        emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+
+        emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+        assert emd.dtype == xb.dtype
+        if not str(nx) == 'numpy':
+            assert emd2.dtype == xb.dtype
```
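Several of these tests take an `nx` argument rather than constructing a backend themselves; in POT's test suite that name resolves to a pytest fixture parametrized over the available backends. A hypothetical sketch of such a fixture (POT defines the real one in test/conftest.py; the exact contents may differ):

```python
import pytest
from ot.backend import get_backend_list

backend_list = get_backend_list()


@pytest.fixture(params=backend_list)
def nx(request):
    # each test requesting `nx` runs once per available backend instance
    yield request.param
```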
```diff
diff --git a/test/test_ot.py b/test/test_ot.py
index 5bfde1d..dc3930a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,7 +12,6 @@ import pytest
 import ot
 from ot.datasets import make_1D_gauss as gauss
 from ot.backend import torch
-from scipy.stats import wasserstein_distance
 
 
 def test_emd_dimension_and_mass_mismatch():
@@ -77,6 +76,33 @@ def test_emd2_backends(nx):
     np.allclose(val, nx.to_numpy(valb))
 
 
+def test_emd_emd2_types_devices(nx):
+    n_samples = 100
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples, n_features)
+    a = ot.utils.unif(n_samples)
+
+    M = ot.dist(x, y)
+
+    for tp in nx.__type_list__:
+
+        print(tp.dtype)
+
+        ab = nx.from_numpy(a, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        Gb = ot.emd(ab, ab, Mb)
+
+        w = ot.emd2(ab, ab, Mb)
+
+        assert Gb.dtype == Mb.dtype
+        if not str(nx) == 'numpy':
+            assert w.dtype == Mb.dtype
+
+
 def test_emd2_gradients():
     n_samples = 100
     n_features = 2
@@ -126,45 +152,6 @@ def test_emd_emd2():
     np.testing.assert_allclose(w, 0)
 
 
-def test_emd_1d_emd2_1d():
-    # test emd1d gives similar results as emd
-    n = 20
-    m = 30
-    rng = np.random.RandomState(0)
-    u = rng.randn(n, 1)
-    v = rng.randn(m, 1)
-
-    M = ot.dist(u, v, metric='sqeuclidean')
-
-    G, log = ot.emd([], [], M, log=True)
-    wass = log["cost"]
-    G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
-    wass1d = log["cost"]
-    wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
-    wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
-
-    # check loss is similar
-    np.testing.assert_allclose(wass, wass1d)
-    np.testing.assert_allclose(wass, wass1d_emd2)
-
-    # check loss is similar to scipy's implementation for Euclidean metric
-    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
-    np.testing.assert_allclose(wass_sp, wass1d_euc)
-
-    # check constraints
-    np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
-    np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
-
-    # check G is similar
-    np.testing.assert_allclose(G, G_1d, atol=1e-15)
-
-    # check AssertionError is raised if called on non 1d arrays
-    u = np.random.randn(n, 2)
-    v = np.random.randn(m, 2)
-    with pytest.raises(AssertionError):
-        ot.emd_1d(u, v, [], [])
-
-
 def test_emd_empty():
     # test emd and emd2 for simple identity
     n = 100
```
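Taken together, the dtype and device assertions in these tests mean a computation can stay on the GPU end to end. A usage sketch, assuming PyTorch (with CUDA when available) and a POT version where `ot.dist`, `ot.emd`, and `ot.emd2` accept torch tensors, as these tests exercise:

```python
import torch
import ot

device = 'cuda' if torch.cuda.is_available() else 'cpu'
n = 100

x = torch.randn(n, 2, dtype=torch.float32, device=device)
y = torch.randn(n, 2, dtype=torch.float32, device=device)
a = torch.full((n,), 1.0 / n, dtype=torch.float32, device=device)

M = ot.dist(x, y)        # cost matrix keeps the input dtype and device
G = ot.emd(a, a, M)      # transport plan: same dtype/device as M
w = ot.emd2(a, a, M)     # loss: a tensor of matching dtype
assert G.dtype == w.dtype == torch.float32
```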