diff options
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r-- | test/test_bregman.py | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 1ebd21f..7c5162a 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -9,6 +9,10 @@ import numpy as np import pytest import ot +from ot.backend import get_backend_list +from ot.backend import torch + +backend_list = get_backend_list() def test_sinkhorn(): @@ -30,6 +34,76 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn2(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +def test_sinkhorn2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.sinkhorn2(a1, b1, M1, 1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + + def test_sinkhorn_empty(): # test sinkhorn n = 100 |