"""Tests for module sliced""" # Author: Adrien Corenflos # # License: MIT License import numpy as np import pytest import ot from ot.sliced import get_random_projections def test_get_random_projections(): rng = np.random.RandomState(0) projections = get_random_projections(1000, 50, rng) np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) def test_sliced_same_dist(): n = 100 rng = np.random.RandomState(0) x = rng.randn(n, 2) u = ot.utils.unif(n) res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) np.testing.assert_almost_equal(res, 0.) def test_sliced_bad_shapes(): n = 100 rng = np.random.RandomState(0) x = rng.randn(n, 2) y = rng.randn(n, 4) u = ot.utils.unif(n) with pytest.raises(ValueError): _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) def test_sliced_log(): n = 100 rng = np.random.RandomState(0) x = rng.randn(n, 4) y = rng.randn(n, 4) u = ot.utils.unif(n) res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) assert len(log) == 2 projections = log["projections"] projected_emds = log["projected_emds"] assert len(projections) == len(projected_emds) == 10 for emd in projected_emds: assert emd > 0 def test_sliced_different_dists(): n = 100 rng = np.random.RandomState(0) x = rng.randn(n, 2) u = ot.utils.unif(n) y = rng.randn(n, 2) res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) assert res > 0. def test_1d_sliced_equals_emd(): n = 100 m = 120 rng = np.random.RandomState(0) x = rng.randn(n, 1) a = rng.uniform(0, 1, n) a /= a.sum() y = rng.randn(m, 1) u = ot.utils.unif(m) res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) np.testing.assert_almost_equal(res ** 2, expected)