From 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d Mon Sep 17 00:00:00 2001 From: Clément Bonet <32179275+clbonet@users.noreply.github.com> Date: Thu, 23 Feb 2023 08:31:01 +0100 Subject: [WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434) * W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn --- test/test_1d_solver.py | 127 +++++++++++++++++++++++++++++++++ test/test_backend.py | 46 ++++++++++++ test/test_sliced.py | 186 +++++++++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 10 +++ 4 files changed, 369 insertions(+) (limited to 'test') diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 20f307a..21abd1d 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -218,3 +218,130 @@ def test_emd1d_device_tf(): nx.assert_same_dtype_device(xb, emd) nx.assert_same_dtype_device(xb, emd2) assert nx.dtype_device(emd)[1].startswith("GPU") + + +def test_wasserstein_1d_circle(): + # test binary_search_circle and wasserstein_circle give similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + + wass1 = ot.emd2(w_u, w_v, M1) + + wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) + w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) + + M2 = M1**2 + wass2 = ot.emd2(w_u, w_v, M2) + wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) + w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) + + # check loss is similar + np.testing.assert_allclose(wass1, wass1_bsc) + np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) + np.testing.assert_allclose(wass2, wass2_bsc) + np.testing.assert_allclose(wass2, w2_circle) + + +@pytest.skip_backend("tf") +def test_wasserstein1d_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + + w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) + w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) + + nx.assert_same_dtype_device(xb, w1) + nx.assert_same_dtype_device(xb, w2_bsc) + + +def test_wasserstein_1d_unif_circle(): + # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle + n = 20 + m = 50000 + + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + # w_u = rng.uniform(0., 1., n) + # w_u = w_u / w_u.sum() + + w_u = ot.utils.unif(n) + w_v = ot.utils.unif(m) + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + wass2 = ot.emd2(w_u, w_v, M1**2) + + wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) + wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) + + # check loss is similar + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + + +def test_wasserstein1d_unif_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) + + w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) + + nx.assert_same_dtype_device(xb, w2) + + +def test_binary_search_circle_log(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) + optimal_thetas = log["optimal_theta"] + + assert optimal_thetas.shape[0] == 1 + + +def test_wasserstein_circle_bad_shape(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + v = rng.rand(m, 1) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=2) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=1) diff --git a/test/test_backend.py b/test/test_backend.py index 3628f61..fd9a761 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -282,6 +282,20 @@ def test_empty_backend(): nx.array_equal(M, M) with pytest.raises(NotImplementedError): nx.is_floating_point(M) + with pytest.raises(NotImplementedError): + nx.tile(M, (10, 1)) + with pytest.raises(NotImplementedError): + nx.floor(M) + with pytest.raises(NotImplementedError): + nx.prod(M) + with pytest.raises(NotImplementedError): + nx.sort2(M) + with pytest.raises(NotImplementedError): + nx.qr(M) + with pytest.raises(NotImplementedError): + nx.atan2(v, v) + with pytest.raises(NotImplementedError): + nx.transpose(M) def test_func_backends(nx): @@ -603,6 +617,38 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("isfinite") + A = nx.tile(vb, (10, 1)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("tile") + + A = nx.floor(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("floor") + + A = nx.prod(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("prod") + + A, B = nx.sort2(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("sort2 sort") + lst_b.append(nx.to_numpy(B)) + lst_name.append("sort2 argsort") + + A, B = nx.qr(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("QR Q") + lst_b.append(nx.to_numpy(B)) + lst_name.append("QR R") + + A = nx.atan2(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("atan2") + + A = nx.transpose(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("transpose") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_sliced.py b/test/test_sliced.py index eb13469..f54c799 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -266,3 +266,189 @@ def test_max_sliced_backend_device_tf(): valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") + + +def test_projections_stiefel(): + rng = np.random.RandomState(0) + + n_projs = 500 + x = np.random.randn(100, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs, + seed=rng, log=True) + + P = log["projections"] + P_T = np.transpose(P, [0, 2, 1]) + np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)])) + + +def test_sliced_sphere_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_sphere_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_sphere_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + a = rng.uniform(0, 1, n) + a /= a.sum() + + y = rng.randn(m, 2) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi) + u = ot.utils.unif(m) + + res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2) + expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2) + + res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1) + expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1) + + np.testing.assert_almost_equal(res ** 2, expected) + np.testing.assert_almost_equal(res1, expected1, decimal=3) + + +@pytest.skip_backend("tf") +def test_sliced_sphere_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(2 * n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, yb = nx.from_numpy(x, y, type_as=tp) + + valb = ot.sliced_wasserstein_sphere(xb, yb) + + nx.assert_same_dtype_device(xb, valb) + + +def test_sliced_sphere_unif_values_on_the_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng) + + +def test_sliced_sphere_unif_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_unif_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + + valb = ot.sliced_wasserstein_sphere_unif(xb) + + nx.assert_same_dtype_device(xb, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 666c157..31b12ef 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -330,3 +330,13 @@ def test_OTResult(): for at in lst_attributes: with pytest.raises(NotImplementedError): getattr(res, at) + + +def test_get_coordinate_circle(): + + u = np.random.rand(1, 100) + x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi)) + x = np.concatenate([x1, y1]).T + x_p = ot.utils.get_coordinate_circle(x) + + np.testing.assert_allclose(u[0], x_p) -- cgit v1.2.3