diff options
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 52 |
1 files changed, 51 insertions, 1 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 92f26a7..53edf4f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -11,7 +11,7 @@ import pytest import ot from ot.datasets import make_1D_gauss as gauss -from ot.backend import torch +from ot.backend import torch, tf def test_emd_dimension_and_mass_mismatch(): @@ -101,6 +101,40 @@ def test_emd_emd2_types_devices(nx): nx.assert_same_dtype_device(Mb, w) +@pytest.mark.skipif(not tf, reason="tf not installed") +def test_emd_emd2_devices_tf(): + if not tf: + return + nx = ot.backend.TensorflowBackend() + + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + M = ot.dist(x, y) + + # Check that everything stays on the CPU + with tf.device("/CPU:0"): + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + Gb = ot.emd(ab, ab, Mb) + w = ot.emd2(ab, ab, Mb) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, w) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + Gb = ot.emd(ab, ab, Mb) + w = ot.emd2(ab, ab, Mb) + nx.assert_same_dtype_device(Mb, Gb) + nx.assert_same_dtype_device(Mb, w) + assert nx.dtype_device(Gb)[1].startswith("GPU") + + def test_emd2_gradients(): n_samples = 100 n_features = 2 @@ -126,6 +160,22 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape + # Testing for bug #309, checking for scaling of gradient + a2 = torch.tensor(a, requires_grad=True) + b2 = torch.tensor(a, requires_grad=True) + M2 = torch.tensor(M, requires_grad=True) + + val = 10.0 * ot.emd2(a2, b2, M2) + + val.backward() + + assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(), + a2.grad.cpu().detach().numpy()) + assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(), + b2.grad.cpu().detach().numpy()) + assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(), + M2.grad.cpu().detach().numpy()) + def test_emd_emd2(): # test emd and emd2 for simple identity |