diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-04-11 16:26:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-11 16:26:30 +0200 |
commit | 486b0d6397182a57cd53651dca87fcea89747490 (patch) | |
tree | 15ce87f3b2a215038454b940b528ad7328e2058f /test/test_ot.py | |
parent | ac4cf442735ed4c0d5405ad861eddaa02afd4edd (diff) |
[MRG] Center gradients for mass of emd2 and gw2 (#363)
* center gradients for mass of emd2 and gw2
* debug fgw gradient
* debug fgw
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index bb258e2..bf832f6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -147,7 +147,7 @@ def test_emd2_gradients(): b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) - val = ot.emd2(a1, b1, M1) + val, log = ot.emd2(a1, b1, M1, log=True) val.backward() @@ -155,6 +155,12 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape + assert np.allclose(a1.grad.cpu().detach().numpy(), + log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean()) + + assert np.allclose(b1.grad.cpu().detach().numpy(), + log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean()) + # Testing for bug #309, checking for scaling of gradient a2 = torch.tensor(a, requires_grad=True) b2 = torch.tensor(a, requires_grad=True) |