diff options
author | AdrienCorenflos <adrien.corenflos@gmail.com> | 2020-10-22 09:28:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-22 10:28:53 +0200 |
commit | 78b44af2434f494c8f9e4c8c91003fbc0e1d4415 (patch) | |
tree | 013002f0a65918cee5eb95648965d4361f0c3dc2 /test/test_sliced.py | |
parent | 7adc1b1aa73c55dc07983ff08dcb23fd71e9e8b6 (diff) |
[MRG] Sliced wasserstein (#203)
* example for log treatment in bregman.py
* Improve doc
* Revert "example for log treatment in bregman.py"
This reverts commit 9f51c14e
* Add comments by Flamary
* Delete repetitive description
* Added raw string to avoid pbs with backslashes
* Implements sliced wasserstein
* Changed formatting of string for py3.5 support
* Docstest, expected 0.0 and not 0.
* Adressed comments by @rflamary
* No 3d plot here
* add sliced to the docs
* Incorporate comments by @rflamary
* add link to pdf
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r-- | test/test_sliced.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py new file mode 100644 index 0000000..a07d975 --- /dev/null +++ b/test/test_sliced.py @@ -0,0 +1,85 @@ +"""Tests for module sliced""" + +# Author: Adrien Corenflos <adrien.corenflos@aalto.fi> +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.sliced import get_random_projections + + +def test_get_random_projections(): + rng = np.random.RandomState(0) + projections = get_random_projections(1000, 50, rng) + np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + + +def test_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + + +def test_sliced_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert len(projections) == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 1) + a = rng.uniform(0, 1, n) + a /= a.sum() + y = rng.randn(m, 1) + u = ot.utils.unif(m) + res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) + expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) + np.testing.assert_almost_equal(res ** 2, expected) |