author     Rémi Flamary <remi.flamary@gmail.com>  2021-11-16 13:07:38 +0100
committer  GitHub <noreply@github.com>            2021-11-16 13:07:38 +0100
commit     f4b363d865a79c07248176c1e36990e0cb6814ea
tree       37f51d94a01ae495e28cec55a78e1c9404ac48d9 /test
parent     0c589912800b23609c730871c080ade0c807cdc1
[WIP] Fix gradient scaling bug in emd (#310)
* correct gradient bug in emd2
* small comment in test
* deploy properly on tag release
* subplot fail
Diffstat (limited to 'test')
 test/test_ot.py | 16 ++++++++++++++++
1 file changed, 16 insertions(+), 0 deletions(-)
```diff
diff --git a/test/test_ot.py b/test/test_ot.py
index 92f26a7..c4d7713 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -126,6 +126,22 @@ def test_emd2_gradients():
         assert b1.shape == b1.grad.shape
         assert M1.shape == M1.grad.shape
 
+        # Testing for bug #309, checking for scaling of gradient
+        a2 = torch.tensor(a, requires_grad=True)
+        b2 = torch.tensor(a, requires_grad=True)
+        M2 = torch.tensor(M, requires_grad=True)
+
+        val = 10.0 * ot.emd2(a2, b2, M2)
+
+        val.backward()
+
+        assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
+                           a2.grad.cpu().detach().numpy())
+        assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
+                           b2.grad.cpu().detach().numpy())
+        assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
+                           M2.grad.cpu().detach().numpy())
+
 
 def test_emd_emd2():
     # test emd and emd2 for simple identity
```
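For context, the property the new test exercises is easy to reproduce outside the test suite: `ot.emd2` accepts PyTorch tensors and supports autograd, so multiplying the loss by a constant must multiply every gradient by the same constant, which is the property bug #309 broke. The sketch below is illustrative only and assumes a POT version with backend support for `emd2` (0.8 or later) plus PyTorch; the problem size, random data, and variable names are assumptions for the example, not taken from the repository.

```python
# Minimal sketch of the gradient-scaling check performed by the new test.
# Assumes POT >= 0.8 (backend/autograd support for emd2) and PyTorch; the
# data below is made up for illustration.
import numpy as np
import torch
import ot

n = 30
rng = np.random.RandomState(42)
x = rng.randn(n, 2)  # source samples
y = rng.randn(n, 2)  # target samples

a = ot.unif(n)     # uniform weights on the source samples
b = ot.unif(n)     # uniform weights on the target samples
M = ot.dist(x, y)  # squared Euclidean cost matrix

# Gradients of the plain EMD loss.
a1 = torch.tensor(a, requires_grad=True)
b1 = torch.tensor(b, requires_grad=True)
M1 = torch.tensor(M, requires_grad=True)
ot.emd2(a1, b1, M1).backward()

# Gradients of the same loss scaled by 10: they must scale by 10 as well.
a2 = torch.tensor(a, requires_grad=True)
b2 = torch.tensor(b, requires_grad=True)
M2 = torch.tensor(M, requires_grad=True)
(10.0 * ot.emd2(a2, b2, M2)).backward()

assert np.allclose(10.0 * a1.grad.numpy(), a2.grad.numpy())
assert np.allclose(10.0 * b1.grad.numpy(), b2.grad.numpy())
assert np.allclose(10.0 * M1.grad.numpy(), M2.grad.numpy())
```

Before this fix, the gradients of the scaled loss did not pick up the factor of 10, so the last three checks are exactly what guards against a regression of #309.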