summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2021-11-16 13:07:38 +0100
committerGitHub <noreply@github.com>2021-11-16 13:07:38 +0100
commitf4b363d865a79c07248176c1e36990e0cb6814ea (patch)
tree37f51d94a01ae495e28cec55a78e1c9404ac48d9 /test
parent0c589912800b23609c730871c080ade0c807cdc1 (diff)
[WIP] Fix gradient scaling bug in emd (#310)
* orrect gradient bug in emd2 * small comment in test * deploy properly on tag release * subplot fail
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index 92f26a7..c4d7713 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -126,6 +126,22 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ # Testing for bug #309, checking for scaling of gradient
+ a2 = torch.tensor(a, requires_grad=True)
+ b2 = torch.tensor(a, requires_grad=True)
+ M2 = torch.tensor(M, requires_grad=True)
+
+ val = 10.0 * ot.emd2(a2, b2, M2)
+
+ val.backward()
+
+ assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
+ a2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
+ b2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
+ M2.grad.cpu().detach().numpy())
+
def test_emd_emd2():
# test emd and emd2 for simple identity