summaryrefslogtreecommitdiff
path: root/test/test_1d_solver.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2021-11-04 15:19:57 +0100
committerGitHub <noreply@github.com>2021-11-04 15:19:57 +0100
commit0e431c203a66c6d48e6bb1efeda149460472a0f0 (patch)
tree22a447a1dbb1505b18f9e426e1761cf6b328b6eb /test/test_1d_solver.py
parent2fe69eb130827560ada704bc25998397c4357821 (diff)
[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
* new test gpu * pep 8 of couse * debug torch * jax with gpu * device put * device put * it works * emd1d and emd2_1d working * emd_1d and emd2_1d done * cleanup * of course * should work on gpu now * tests done+ pep8
Diffstat (limited to 'test/test_1d_solver.py')
-rw-r--r--test/test_1d_solver.py | 93
1 files changed, 93 insertions, 0 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 2c470c2..77b1234 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -83,3 +83,96 @@ def test_wasserstein_1d(nx):
Xb = nx.from_numpy(X)
res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+ if not str(nx) == 'numpy':
+ assert res.dtype == xb.dtype
+
+
+def test_emd_1d_emd2_1d():
+ # test emd1d gives similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.randn(n, 1)
+ v = rng.randn(m, 1)
+
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd([], [], M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+ np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+ # check constraints
+ np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+ np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+ # check G is similar
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+ # check AssertionError is raised if called on non 1d arrays
+ u = np.random.randn(n, 2)
+ v = np.random.randn(m, 2)
+ with pytest.raises(AssertionError):
+ ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 5, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+
+ print(tp.dtype)
+
+ xb = nx.from_numpy(x, type_as=tp)
+ rho_ub = nx.from_numpy(rho_u, type_as=tp)
+ rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+ emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+
+ emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+ assert emd.dtype == xb.dtype
+ if not str(nx) == 'numpy':
+ assert emd2.dtype == xb.dtype