summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py52
1 files changed, 51 insertions, 1 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 92f26a7..53edf4f 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -11,7 +11,7 @@ import pytest
import ot
from ot.datasets import make_1D_gauss as gauss
-from ot.backend import torch
+from ot.backend import torch, tf
def test_emd_dimension_and_mass_mismatch():
@@ -101,6 +101,40 @@ def test_emd_emd2_types_devices(nx):
nx.assert_same_dtype_device(Mb, w)
+@pytest.mark.skipif(not tf, reason="tf not installed")
+def test_emd_emd2_devices_tf():
+ if not tf:
+ return
+ nx = ot.backend.TensorflowBackend()
+
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+ M = ot.dist(x, y)
+
+ # Check that everything stays on the CPU
+ with tf.device("/CPU:0"):
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+
+ if len(tf.config.list_physical_devices('GPU')) > 0:
+ # Check that everything happens on the GPU
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+ Gb = ot.emd(ab, ab, Mb)
+ w = ot.emd2(ab, ab, Mb)
+ nx.assert_same_dtype_device(Mb, Gb)
+ nx.assert_same_dtype_device(Mb, w)
+ assert nx.dtype_device(Gb)[1].startswith("GPU")
+
+
def test_emd2_gradients():
n_samples = 100
n_features = 2
@@ -126,6 +160,22 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ # Testing for bug #309, checking for scaling of gradient
+ a2 = torch.tensor(a, requires_grad=True)
+ b2 = torch.tensor(a, requires_grad=True)
+ M2 = torch.tensor(M, requires_grad=True)
+
+ val = 10.0 * ot.emd2(a2, b2, M2)
+
+ val.backward()
+
+ assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
+ a2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
+ b2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
+ M2.grad.cpu().detach().numpy())
+
def test_emd_emd2():
# test emd and emd2 for simple identity