diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_ot.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 92f26a7..c4d7713 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -126,6 +126,22 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape + # Testing for bug #309, checking for scaling of gradient + a2 = torch.tensor(a, requires_grad=True) + b2 = torch.tensor(a, requires_grad=True) + M2 = torch.tensor(M, requires_grad=True) + + val = 10.0 * ot.emd2(a2, b2, M2) + + val.backward() + + assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(), + a2.grad.cpu().detach().numpy()) + assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(), + b2.grad.cpu().detach().numpy()) + assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(), + M2.grad.cpu().detach().numpy()) + def test_emd_emd2(): # test emd and emd2 for simple identity |