summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-04-11 16:26:30 +0200
committerGitHub <noreply@github.com>2022-04-11 16:26:30 +0200
commit486b0d6397182a57cd53651dca87fcea89747490 (patch)
tree15ce87f3b2a215038454b940b528ad7328e2058f /test
parentac4cf442735ed4c0d5405ad861eddaa02afd4edd (diff)
[MRG] Center gradients for mass of emd2 and gw2 (#363)
* center gradients for mass of emd2 and gw2 * debug fgw gradient * debug fgw
Diffstat (limited to 'test')
-rw-r--r--test/test_ot.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/test/test_ot.py b/test/test_ot.py
index bb258e2..bf832f6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -147,7 +147,7 @@ def test_emd2_gradients():
b1 = torch.tensor(a, requires_grad=True)
M1 = torch.tensor(M, requires_grad=True)
- val = ot.emd2(a1, b1, M1)
+ val, log = ot.emd2(a1, b1, M1, log=True)
val.backward()
@@ -155,6 +155,12 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ assert np.allclose(a1.grad.cpu().detach().numpy(),
+ log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+ assert np.allclose(b1.grad.cpu().detach().numpy(),
+ log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
# Testing for bug #309, checking for scaling of gradient
a2 = torch.tensor(a, requires_grad=True)
b2 = torch.tensor(a, requires_grad=True)