From 486b0d6397182a57cd53651dca87fcea89747490 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 11 Apr 2022 16:26:30 +0200 Subject: [MRG] Center gradients for mass of emd2 and gw2 (#363) * center gradients for mass of emd2 and gw2 * debug fgw gradient * debug fgw --- test/test_ot.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'test/test_ot.py') diff --git a/test/test_ot.py b/test/test_ot.py index bb258e2..bf832f6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -147,7 +147,7 @@ def test_emd2_gradients(): b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) - val = ot.emd2(a1, b1, M1) + val, log = ot.emd2(a1, b1, M1, log=True) val.backward() @@ -155,6 +155,12 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape + assert np.allclose(a1.grad.cpu().detach().numpy(), + log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean()) + + assert np.allclose(b1.grad.cpu().detach().numpy(), + log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean()) + # Testing for bug #309, checking for scaling of gradient a2 = torch.tensor(a, requires_grad=True) b2 = torch.tensor(a, requires_grad=True) -- cgit v1.2.3