summaryrefslogtreecommitdiff
path: root/test/test_ot.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_ot.py')
-rw-r--r--test/test_ot.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 92f26a7..c4d7713 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -126,6 +126,22 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ # Testing for bug #309, checking for scaling of gradient
+ a2 = torch.tensor(a, requires_grad=True)
+ b2 = torch.tensor(a, requires_grad=True)
+ M2 = torch.tensor(M, requires_grad=True)
+
+ val = 10.0 * ot.emd2(a2, b2, M2)
+
+ val.backward()
+
+ assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
+ a2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
+ b2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
+ M2.grad.cpu().detach().numpy())
+
def test_emd_emd2():
# test emd and emd2 for simple identity