From f4b363d865a79c07248176c1e36990e0cb6814ea Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 16 Nov 2021 13:07:38 +0100 Subject: [WIP] Fix gradient scaling bug in emd (#310) * orrect gradient bug in emd2 * small comment in test * deploy properly on tag release * subplot fail --- .circleci/config.yml | 33 +++++++++++++++------------------ examples/plot_Intro_OT.py | 2 +- ot/backend.py | 2 +- test/test_ot.py | 16 ++++++++++++++++ 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 85f8073..96c1fbf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -134,24 +134,21 @@ jobs: name: Deploy docs command: | set -e; - if [ "${CIRCLE_BRANCH}" == "master" ]; then - git config --global user.email "circle@PythonOT.com"; - git config --global user.name "Circle CI"; - cd ~/PythonOT.github.io; - git checkout master - git remote -v - git fetch origin - git reset --hard origin/master - git clean -xdf - echo "Deploying dev docs for ${CIRCLE_BRANCH}."; - cp -a /tmp/build/html/* .; - touch .nojekyll; - git add -A; - git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM})."; - git push origin master; - else - echo "No deployment (build: ${CIRCLE_BRANCH})."; - fi + git config --global user.email "circle@PythonOT.com"; + git config --global user.name "Circle CI"; + cd ~/PythonOT.github.io; + git checkout master + git remote -v + git fetch origin + git reset --hard origin/master + git clean -xdf + echo "Deploying dev docs for ${CIRCLE_BRANCH}."; + cp -a /tmp/build/html/* .; + touch .nojekyll; + git add -A; + git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM})."; + git push origin master; + workflows: diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py index 2e2c6fd..f282950 100644 --- a/examples/plot_Intro_OT.py +++ b/examples/plot_Intro_OT.py @@ -327,7 +327,7 @@ for k in range(len(reg_parameter)): time_sinkhorn_reg[k] = time.time() - start if k % 4 == 0 and k > 0: # we only plot a few - ax = pl.subplot(1, 5, k / 4) + ax = pl.subplot(1, 5, k // 4) im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot) pl.title('reg={0:.2g}'.format(reg_parameter[k])) pl.xlabel('Cafés') diff --git a/ot/backend.py b/ot/backend.py index a044f84..fa164c3 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1203,7 +1203,7 @@ class TorchBackend(Backend): @staticmethod def backward(ctx, grad_output): # the gradients are grad - return (None, None) + ctx.grads + return (None, None) + tuple(g * grad_output for g in ctx.grads) self.ValFunction = ValFunction diff --git a/test/test_ot.py b/test/test_ot.py index 92f26a7..c4d7713 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -126,6 +126,22 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape + # Testing for bug #309, checking for scaling of gradient + a2 = torch.tensor(a, requires_grad=True) + b2 = torch.tensor(a, requires_grad=True) + M2 = torch.tensor(M, requires_grad=True) + + val = 10.0 * ot.emd2(a2, b2, M2) + + val.backward() + + assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(), + a2.grad.cpu().detach().numpy()) + assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(), + b2.grad.cpu().detach().numpy()) + assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(), + M2.grad.cpu().detach().numpy()) + def test_emd_emd2(): # test emd and emd2 for simple identity -- cgit v1.2.3