summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 1ebd21f..7c5162a 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -9,6 +9,10 @@ import numpy as np
import pytest
import ot
+from ot.backend import get_backend_list
+from ot.backend import torch
+
+backend_list = get_backend_list()
def test_sinkhorn():
@@ -30,6 +34,76 @@ def test_sinkhorn():
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn(ab, ab, Mb, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn2(ab, ab, Mb, 1)
+
+ np.allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.sinkhorn2(a1, b1, M1, 1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
def test_sinkhorn_empty():
# test sinkhorn
n = 100