summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py42
1 files changed, 30 insertions, 12 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 53edf4f..bf832f6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -47,8 +47,7 @@ def test_emd_backends(nx):
G = ot.emd(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
@@ -68,8 +67,7 @@ def test_emd2_backends(nx):
val = ot.emd2(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
valb = ot.emd2(ab, ab, Mb)
@@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ab = nx.from_numpy(a, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ab, Mb = nx.from_numpy(a, M, type_as=tp)
Gb = ot.emd(ab, ab, Mb)
@@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -126,8 +122,7 @@ def test_emd_emd2_devices_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -152,7 +147,7 @@ def test_emd2_gradients():
b1 = torch.tensor(a, requires_grad=True)
M1 = torch.tensor(M, requires_grad=True)
- val = ot.emd2(a1, b1, M1)
+ val, log = ot.emd2(a1, b1, M1, log=True)
val.backward()
@@ -160,6 +155,12 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ assert np.allclose(a1.grad.cpu().detach().numpy(),
+ log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+ assert np.allclose(b1.grad.cpu().detach().numpy(),
+ log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
# Testing for bug #309, checking for scaling of gradient
a2 = torch.tensor(a, requires_grad=True)
b2 = torch.tensor(a, requires_grad=True)
@@ -232,7 +233,7 @@ def test_emd2_multi():
# Gaussian distributions
a = gauss(n, m=20, s=5) # m= mean, s= std
- ls = np.arange(20, 500, 20)
+ ls = np.arange(20, 500, 100)
nb = len(ls)
b = np.zeros((n, nb))
for i in range(nb):
@@ -302,6 +303,23 @@ def test_free_support_barycenter():
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+def test_free_support_barycenter_backends(nx):
+
+ measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
+ measures_weights = [np.array([1.]), np.array([1.])]
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
+
+ measures_locations2 = nx.from_numpy(*measures_locations)
+ measures_weights2 = nx.from_numpy(*measures_weights)
+ X_init2 = nx.from_numpy(X_init)
+
+ X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2)
+
+ np.testing.assert_allclose(X, nx.to_numpy(X2))
+
+
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]