From 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d Mon Sep 17 00:00:00 2001 From: Clément Bonet <32179275+clbonet@users.noreply.github.com> Date: Thu, 23 Feb 2023 08:31:01 +0100 Subject: [WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434) * W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn --- test/test_1d_solver.py | 127 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) (limited to 'test/test_1d_solver.py') diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 20f307a..21abd1d 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -218,3 +218,130 @@ def test_emd1d_device_tf(): nx.assert_same_dtype_device(xb, emd) nx.assert_same_dtype_device(xb, emd2) assert nx.dtype_device(emd)[1].startswith("GPU") + + +def test_wasserstein_1d_circle(): + # test binary_search_circle and wasserstein_circle give similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + + wass1 = ot.emd2(w_u, w_v, M1) + + wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) + w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) + + M2 = M1**2 + wass2 = ot.emd2(w_u, w_v, M2) + wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) + w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) + + # check loss is similar + np.testing.assert_allclose(wass1, wass1_bsc) + np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) + np.testing.assert_allclose(wass2, wass2_bsc) + np.testing.assert_allclose(wass2, w2_circle) + + +@pytest.skip_backend("tf") +def test_wasserstein1d_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + + w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) + w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) + + nx.assert_same_dtype_device(xb, w1) + nx.assert_same_dtype_device(xb, w2_bsc) + + +def test_wasserstein_1d_unif_circle(): + # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle + n = 20 + m = 50000 + + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + # w_u = rng.uniform(0., 1., n) + # w_u = w_u / w_u.sum() + + w_u = ot.utils.unif(n) + w_v = ot.utils.unif(m) + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + wass2 = ot.emd2(w_u, w_v, M1**2) + + wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) + wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) + + # check loss is similar + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + + +def test_wasserstein1d_unif_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) + + w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) + + nx.assert_same_dtype_device(xb, w2) + + +def test_binary_search_circle_log(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) + optimal_thetas = log["optimal_theta"] + + assert optimal_thetas.shape[0] == 1 + + +def test_wasserstein_circle_bad_shape(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + v = rng.rand(m, 1) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=2) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=1) -- cgit v1.2.3