diff options
Diffstat (limited to 'test/test_utils.py')
-rw-r--r-- | test/test_utils.py | 76 |
1 files changed, 67 insertions, 9 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index db9cda6..76b1faa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,11 +4,47 @@ # # License: MIT License - +import pytest import ot import numpy as np import sys +from ot.backend import get_backend_list + +backend_list = get_backend_list() + + +@pytest.mark.parametrize('nx', backend_list) +def test_proj_simplex(nx): + n = 10 + rng = np.random.RandomState(0) + + # test on matrix when projection is done on axis 0 + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # all projections should sum to 3 + proj = ot.utils.proj_simplex(x1, 3) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = 3 * np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # tets on vector + x = rng.randn(n) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + def test_parmap(): @@ -45,8 +81,8 @@ def test_tic_toc(): def test_kernel(): n = 100 - - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) K = ot.utils.kernel(x, x) @@ -67,7 +103,8 @@ def test_dist(): n = 100 - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) D = np.zeros((n, n)) for i in range(n): @@ -78,8 +115,27 @@ def test_dist(): D3 = ot.dist(x) # dist shoul return squared euclidean - np.testing.assert_allclose(D, D2) - np.testing.assert_allclose(D, D3) + np.testing.assert_allclose(D, D2, atol=1e-14) + np.testing.assert_allclose(D, D3, atol=1e-14) + + +@ pytest.mark.parametrize('nx', backend_list) +def test_dist_backends(nx): + + n = 100 + rng = np.random.RandomState(0) + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + lst_metric = ['euclidean', 'sqeuclidean'] + + for metric in lst_metric: + + D = ot.dist(x, x, metric=metric) + D1 = ot.dist(x1, x1, metric=metric) + + # low atol because jax forces float32 + np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5) def test_dist0(): @@ -95,9 +151,11 @@ def test_dots(): n1, n2, n3, n4 = 100, 50, 200, 100 - A = np.random.randn(n1, n2) - B = np.random.randn(n2, n3) - C = np.random.randn(n3, n4) + rng = np.random.RandomState(0) + + A = rng.randn(n1, n2) + B = rng.randn(n2, n3) + C = rng.randn(n3, n4) X1 = ot.utils.dots(A, B, C) |