author     Rémi Flamary <remi.flamary@gmail.com>    2021-11-16 13:07:38 +0100
committer  GitHub <noreply@github.com>              2021-11-16 13:07:38 +0100
commit     f4b363d865a79c07248176c1e36990e0cb6814ea (patch)
tree       37f51d94a01ae495e28cec55a78e1c9404ac48d9
parent     0c589912800b23609c730871c080ade0c807cdc1 (diff)
[WIP] Fix gradient scaling bug in emd (#310)
* Correct gradient bug in emd2
* Small comment in test
* Deploy docs properly on tag release
* Fix failing subplot call in example
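The emd2 change addresses the behaviour checked by the new test for bug #309: when the returned loss is multiplied by a constant, the gradients propagated to the inputs must be multiplied by the same constant. A minimal sketch of the expected usage, with illustrative data only (uniform histograms and a random cost matrix, not taken from the test suite):

    import torch
    import ot

    n = 5
    a = torch.tensor(ot.unif(n), requires_grad=True)   # source histogram
    b = torch.tensor(ot.unif(n), requires_grad=True)   # target histogram
    M = torch.rand(n, n, dtype=torch.float64, requires_grad=True)  # cost matrix

    # Chain rule: scaling the loss by 10 must scale a.grad, b.grad and
    # M.grad by 10 as well. Before this fix, the TorchBackend backward
    # ignored grad_output, so downstream scaling was silently dropped.
    val = 10.0 * ot.emd2(a, b, M)
    val.backward()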
-rw-r--r--  .circleci/config.yml         33
-rw-r--r--  examples/plot_Intro_OT.py     2
-rw-r--r--  ot/backend.py                 2
-rw-r--r--  test/test_ot.py              16
4 files changed, 33 insertions, 20 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 85f8073..96c1fbf 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -134,24 +134,21 @@ jobs:
name: Deploy docs
command: |
set -e;
- if [ "${CIRCLE_BRANCH}" == "master" ]; then
- git config --global user.email "circle@PythonOT.com";
- git config --global user.name "Circle CI";
- cd ~/PythonOT.github.io;
- git checkout master
- git remote -v
- git fetch origin
- git reset --hard origin/master
- git clean -xdf
- echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
- cp -a /tmp/build/html/* .;
- touch .nojekyll;
- git add -A;
- git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
- git push origin master;
- else
- echo "No deployment (build: ${CIRCLE_BRANCH}).";
- fi
+ git config --global user.email "circle@PythonOT.com";
+ git config --global user.name "Circle CI";
+ cd ~/PythonOT.github.io;
+ git checkout master
+ git remote -v
+ git fetch origin
+ git reset --hard origin/master
+ git clean -xdf
+ echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
+ cp -a /tmp/build/html/* .;
+ touch .nojekyll;
+ git add -A;
+ git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
+ git push origin master;
+
workflows:
diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py
index 2e2c6fd..f282950 100644
--- a/examples/plot_Intro_OT.py
+++ b/examples/plot_Intro_OT.py
@@ -327,7 +327,7 @@ for k in range(len(reg_parameter)):
time_sinkhorn_reg[k] = time.time() - start
if k % 4 == 0 and k > 0: # we only plot a few
- ax = pl.subplot(1, 5, k / 4)
+ ax = pl.subplot(1, 5, k // 4)
im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)
pl.title('reg={0:.2g}'.format(reg_parameter[k]))
pl.xlabel('Cafés')
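The subplot fix is plain Python 3 integer division: `/` is true division and returns a float even for integer operands, while the subplot index must be an integer. A small illustration with a hypothetical value of k (`pl` is the pylab alias used in the example):

    k = 8
    print(k / 4)    # 2.0 -- a float; recent Matplotlib versions reject non-integer subplot indices
    print(k // 4)   # 2   -- an int, safe to pass as pl.subplot(1, 5, k // 4)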
diff --git a/ot/backend.py b/ot/backend.py
index a044f84..fa164c3 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1203,7 +1203,7 @@ class TorchBackend(Backend):
@staticmethod
def backward(ctx, grad_output):
# the gradients are grad
- return (None, None) + ctx.grads
+ return (None, None) + tuple(g * grad_output for g in ctx.grads)
self.ValFunction = ValFunction
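For context, a minimal sketch (not the actual POT ValFunction) of the pattern this hunk corrects: a custom torch.autograd.Function that caches precomputed gradients in forward must still multiply them by grad_output in backward, otherwise any scaling applied downstream of the loss is lost.

    import torch

    class ScaledVal(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, grad_x):
            # grad_x is a precomputed gradient of the scalar value w.r.t. x
            ctx.save_for_backward(grad_x)
            return (x * grad_x).sum()

        @staticmethod
        def backward(ctx, grad_output):
            grad_x, = ctx.saved_tensors
            # Multiplying by grad_output applies the chain rule; returning
            # grad_x alone reproduces the bug fixed in this commit.
            return grad_output * grad_x, None

    x = torch.randn(5, requires_grad=True)
    g = torch.randn(5)
    (10.0 * ScaledVal.apply(x, g)).backward()
    assert torch.allclose(x.grad, 10.0 * g)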
diff --git a/test/test_ot.py b/test/test_ot.py
index 92f26a7..c4d7713 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -126,6 +126,22 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ # Testing for bug #309, checking for scaling of gradient
+ a2 = torch.tensor(a, requires_grad=True)
+ b2 = torch.tensor(a, requires_grad=True)
+ M2 = torch.tensor(M, requires_grad=True)
+
+ val = 10.0 * ot.emd2(a2, b2, M2)
+
+ val.backward()
+
+ assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
+ a2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
+ b2.grad.cpu().detach().numpy())
+ assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
+ M2.grad.cpu().detach().numpy())
+
def test_emd_emd2():
# test emd and emd2 for simple identity