summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-02-23 08:31:01 +0100
committerGitHub <noreply@github.com>2023-02-23 08:31:01 +0100
commit80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch)
treee4c2e938896243842e290d8fcf78879a8f6960bf /test
parent97feeb32b6c069d7bb44cd995531c2b820d59771 (diff)
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
Diffstat (limited to 'test')
-rw-r--r--test/test_1d_solver.py127
-rw-r--r--test/test_backend.py46
-rw-r--r--test/test_sliced.py186
-rw-r--r--test/test_utils.py10
4 files changed, 369 insertions, 0 deletions
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 20f307a..21abd1d 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -218,3 +218,130 @@ def test_emd1d_device_tf():
nx.assert_same_dtype_device(xb, emd)
nx.assert_same_dtype_device(xb, emd2)
assert nx.dtype_device(emd)[1].startswith("GPU")
+
+
+def test_wasserstein_1d_circle():
+ # test binary_search_circle and wasserstein_circle give similar results as emd
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ w_u = rng.uniform(0., 1., n)
+ w_u = w_u / w_u.sum()
+
+ w_v = rng.uniform(0., 1., m)
+ w_v = w_v / w_v.sum()
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+
+ wass1 = ot.emd2(w_u, w_v, M1)
+
+ wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1)
+ w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
+
+ M2 = M1**2
+ wass2 = ot.emd2(w_u, w_v, M2)
+ wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)
+ w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass1, wass1_bsc)
+ np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2)
+ np.testing.assert_allclose(wass2, wass2_bsc)
+ np.testing.assert_allclose(wass2, w2_circle)
+
+
+@pytest.skip_backend("tf")
+def test_wasserstein1d_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+ rho_v = np.abs(rng.randn(n))
+ rho_v /= rho_v.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+ w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1)
+ w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2)
+
+ nx.assert_same_dtype_device(xb, w1)
+ nx.assert_same_dtype_device(xb, w2_bsc)
+
+
+def test_wasserstein_1d_unif_circle():
+ # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
+ n = 20
+ m = 50000
+
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ # w_u = rng.uniform(0., 1., n)
+ # w_u = w_u / w_u.sum()
+
+ w_u = ot.utils.unif(n)
+ w_v = ot.utils.unif(m)
+
+ M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+ wass2 = ot.emd2(w_u, w_v, M1**2)
+
+ wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15)
+ wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
+ np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+
+
+def test_wasserstein1d_unif_circle_devices(nx):
+ rng = np.random.RandomState(0)
+
+ n = 10
+ x = np.linspace(0, 1, n)
+ rho_u = np.abs(rng.randn(n))
+ rho_u /= rho_u.sum()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp)
+
+ w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub)
+
+ nx.assert_same_dtype_device(xb, w2)
+
+
+def test_binary_search_circle_log():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n,)
+ v = rng.rand(m,)
+
+ wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True)
+ optimal_thetas = log["optimal_theta"]
+
+ assert optimal_thetas.shape[0] == 1
+
+
+def test_wasserstein_circle_bad_shape():
+ n = 20
+ m = 30
+ rng = np.random.RandomState(0)
+ u = rng.rand(n, 2)
+ v = rng.rand(m, 1)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=2)
+
+ with pytest.raises(ValueError):
+ _ = ot.wasserstein_circle(u, v, p=1)
diff --git a/test/test_backend.py b/test/test_backend.py
index 3628f61..fd9a761 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -282,6 +282,20 @@ def test_empty_backend():
nx.array_equal(M, M)
with pytest.raises(NotImplementedError):
nx.is_floating_point(M)
+ with pytest.raises(NotImplementedError):
+ nx.tile(M, (10, 1))
+ with pytest.raises(NotImplementedError):
+ nx.floor(M)
+ with pytest.raises(NotImplementedError):
+ nx.prod(M)
+ with pytest.raises(NotImplementedError):
+ nx.sort2(M)
+ with pytest.raises(NotImplementedError):
+ nx.qr(M)
+ with pytest.raises(NotImplementedError):
+ nx.atan2(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.transpose(M)
def test_func_backends(nx):
@@ -603,6 +617,38 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("isfinite")
+ A = nx.tile(vb, (10, 1))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("tile")
+
+ A = nx.floor(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("floor")
+
+ A = nx.prod(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("prod")
+
+ A, B = nx.sort2(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("sort2 sort")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("sort2 argsort")
+
+ A, B = nx.qr(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("QR Q")
+ lst_b.append(nx.to_numpy(B))
+ lst_name.append("QR R")
+
+ A = nx.atan2(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("atan2")
+
+ A = nx.transpose(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("transpose")
+
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
assert not nx.array_equal(
diff --git a/test/test_sliced.py b/test/test_sliced.py
index eb13469..f54c799 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -266,3 +266,189 @@ def test_max_sliced_backend_device_tf():
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
+def test_projections_stiefel():
+ rng = np.random.RandomState(0)
+
+ n_projs = 500
+ x = np.random.randn(100, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs,
+ seed=rng, log=True)
+
+ P = log["projections"]
+ P_T = np.transpose(P, [0, 2, 1])
+ np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]))
+
+
+def test_sliced_sphere_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_sphere_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_sphere_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+
+ y = rng.randn(m, 2)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi)
+ u = ot.utils.unif(m)
+
+ res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2)
+ expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2)
+
+ res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1)
+ expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1)
+
+ np.testing.assert_almost_equal(res ** 2, expected)
+ np.testing.assert_almost_equal(res1, expected1, decimal=3)
+
+
+@pytest.skip_backend("tf")
+def test_sliced_sphere_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(2 * n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, yb = nx.from_numpy(x, y, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere(xb, yb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_sliced_sphere_unif_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng)
+
+
+def test_sliced_sphere_unif_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_unif_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere_unif(xb)
+
+ nx.assert_same_dtype_device(xb, valb)
diff --git a/test/test_utils.py b/test/test_utils.py
index 666c157..31b12ef 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -330,3 +330,13 @@ def test_OTResult():
for at in lst_attributes:
with pytest.raises(NotImplementedError):
getattr(res, at)
+
+
+def test_get_coordinate_circle():
+
+ u = np.random.rand(1, 100)
+ x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi))
+ x = np.concatenate([x1, y1]).T
+ x_p = ot.utils.get_coordinate_circle(x)
+
+ np.testing.assert_allclose(u[0], x_p)