summaryrefslogtreecommitdiff
path: root/test/test_sliced.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r--test/test_sliced.py200
1 files changed, 200 insertions, 0 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 08ab4fb..f54c799 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
assert res > 0.
+def test_sliced_same_proj():
+ n_projections = 10
+ seed = 12
+ rng = np.random.RandomState(0)
+ X = rng.randn(8, 2)
+ Y = rng.randn(8, 2)
+ cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
+ P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
+ cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
+
+ assert np.allclose(log1['projections'], log2['projections'])
+ assert np.isclose(cost1, cost2)
+
+
def test_sliced_backend(nx):
n = 100
@@ -252,3 +266,189 @@ def test_max_sliced_backend_device_tf():
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
+
+
+def test_projections_stiefel():
+ rng = np.random.RandomState(0)
+
+ n_projs = 500
+ x = np.random.randn(100, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs,
+ seed=rng, log=True)
+
+ P = log["projections"]
+ P_T = np.transpose(P, [0, 2, 1])
+ np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)]))
+
+
+def test_sliced_sphere_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_sphere_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(n, 4)
+
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_sphere_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ y = rng.randn(n, 4)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_sphere_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+
+ y = rng.randn(m, 2)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+ y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi)
+ u = ot.utils.unif(m)
+
+ res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2)
+ expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2)
+
+ res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1)
+ expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1)
+
+ np.testing.assert_almost_equal(res ** 2, expected)
+ np.testing.assert_almost_equal(res1, expected1, decimal=3)
+
+
+@pytest.skip_backend("tf")
+def test_sliced_sphere_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ y = rng.randn(2 * n, 3)
+ y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb, yb = nx.from_numpy(x, y, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere(xb, yb)
+
+ nx.assert_same_dtype_device(xb, valb)
+
+
+def test_sliced_sphere_unif_values_on_the_sphere():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng)
+
+
+def test_sliced_sphere_unif_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert projections.shape[0] == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_sphere_unif_backend_type_devices(nx):
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 3)
+ x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ xb = nx.from_numpy(x, type_as=tp)
+
+ valb = ot.sliced_wasserstein_sphere_unif(xb)
+
+ nx.assert_same_dtype_device(xb, valb)