author    Rémi Flamary <remi.flamary@gmail.com>    2022-04-11 16:26:30 +0200
committer GitHub <noreply@github.com>              2022-04-11 16:26:30 +0200
commit    486b0d6397182a57cd53651dca87fcea89747490 (patch)
tree      15ce87f3b2a215038454b940b528ad7328e2058f
parent    ac4cf442735ed4c0d5405ad861eddaa02afd4edd (diff)
[MRG] Center gradients for mass of emd2 and gw2 (#363)
* center gradients for mass of emd2 and gw2
* debug fgw gradient
* debug fgw
-rw-r--r--  RELEASES.md        |  4
-rw-r--r--  ot/gromov.py       |  7
-rw-r--r--  ot/lp/__init__.py  |  7
-rw-r--r--  test/test_ot.py    |  8
4 files changed, 19 insertions(+), 7 deletions(-)
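
For context on the change below: the dual potentials u and v returned by the linear-program solver are only defined up to an additive constant, so the raw potentials are not a canonical gradient of the cost with respect to the marginal weights. Subtracting their mean selects the zero-sum representative, i.e. the gradient restricted to the probability simplex. A minimal sketch of the resulting user-facing behaviour with ot.emd2 and the PyTorch backend (the toy data is purely illustrative):

import numpy as np
import torch
import ot

# Toy problem: two small point clouds with uniform weights (illustrative only)
rng = np.random.RandomState(0)
xs = rng.randn(5, 2)
xt = rng.randn(6, 2)

a = torch.tensor(ot.unif(5), requires_grad=True)
b = torch.tensor(ot.unif(6), requires_grad=True)
M = torch.tensor(ot.dist(xs, xt))

# emd2 returns the OT cost; after this patch its gradients w.r.t. a and b
# are the centered dual potentials u - mean(u) and v - mean(v)
val = ot.emd2(a, b, M)
val.backward()

# Centered gradients sum to (numerically) zero on each marginal
print(a.grad.sum().item(), b.grad.sum().item())
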
diff --git a/RELEASES.md b/RELEASES.md
index 7942a15..33d1ab6 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,7 +5,7 @@
#### New features
-- remode deprecated `ot.gpu` submodule (PR #361)
+- Remove deprecated `ot.gpu` submodule (PR #361)
- Update examples in the gallery (PR #359).
- Add stochastic loss and OT plan computation for regularized OT and
backend examples(PR #360).
@@ -23,6 +23,8 @@
#### Closed issues
+- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
+ centered (Issue #364, PR #363)
- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
PR #338)
- Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
diff --git a/ot/gromov.py b/ot/gromov.py
index c5a82d1..55ab0bd 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -551,7 +551,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
gC1 = nx.from_numpy(gC1, type_as=C10)
gC2 = nx.from_numpy(gC2, type_as=C10)
gw = nx.set_gradients(gw, (p0, q0, C10, C20),
- (log_gw['u'], log_gw['v'], gC1, gC2))
+ (log_gw['u'] - nx.mean(log_gw['u']),
+ log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
if log:
return gw, log_gw
@@ -793,7 +794,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
gC1 = nx.from_numpy(gC1, type_as=C10)
gC2 = nx.from_numpy(gC2, type_as=C10)
fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
- (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T0))
if log:
return fgw_dist, log_fgw
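
The same centering is applied above to the gradients of gromov_wasserstein2 and fused_gromov_wasserstein2 with respect to the weights p and q. A small sketch of the effect, again with illustrative toy data and the PyTorch backend:

import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
xs = rng.randn(5, 2)
xt = rng.randn(6, 2)

# Intra-domain cost matrices and weights, tracked by autograd (toy data)
C1 = torch.tensor(ot.dist(xs, xs), requires_grad=True)
C2 = torch.tensor(ot.dist(xt, xt), requires_grad=True)
p = torch.tensor(ot.unif(5), requires_grad=True)
q = torch.tensor(ot.unif(6), requires_grad=True)

gw = ot.gromov_wasserstein2(C1, C2, p, q)
gw.backward()

# Gradients w.r.t. the weights are now the centered potentials, so they sum to ~0
print(p.grad.sum().item(), q.grad.sum().item())
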
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index abf7fe0..390c32d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -517,7 +517,8 @@ def emd2(a, b, M, processes=1,
log['warning'] = result_code_string
log['result_code'] = result_code
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (log['u'], log['v'], G))
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -540,8 +541,8 @@ def emd2(a, b, M, processes=1,
)
G = nx.from_numpy(G, type_as=type_as)
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (nx.from_numpy(u, type_as=type_as),
- nx.from_numpy(v, type_as=type_as), G))
+ (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
+ nx.from_numpy(v - np.mean(v), type_as=type_as), G))
check_result(result_code)
return cost
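
Note on why centering is harmless for optimization over the weights: a and b live on the probability simplex, so any admissible perturbation sums to zero, and subtracting a constant from u or v leaves the corresponding directional derivative unchanged. A quick numerical check of that identity in plain numpy (values are illustrative):

import numpy as np

rng = np.random.RandomState(0)
u = rng.randn(5)                    # a dual potential, defined up to a constant
delta = rng.randn(5)
delta -= delta.mean()               # admissible perturbation on the simplex: sums to zero

u_centered = u - u.mean()

# Same directional derivative with or without centering
print(np.allclose(u @ delta, u_centered @ delta))  # True
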
diff --git a/test/test_ot.py b/test/test_ot.py
index bb258e2..bf832f6 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -147,7 +147,7 @@ def test_emd2_gradients():
b1 = torch.tensor(a, requires_grad=True)
M1 = torch.tensor(M, requires_grad=True)
- val = ot.emd2(a1, b1, M1)
+ val, log = ot.emd2(a1, b1, M1, log=True)
val.backward()
@@ -155,6 +155,12 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape
+ assert np.allclose(a1.grad.cpu().detach().numpy(),
+ log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+ assert np.allclose(b1.grad.cpu().detach().numpy(),
+ log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
# Testing for bug #309, checking for scaling of gradient
a2 = torch.tensor(a, requires_grad=True)
b2 = torch.tensor(a, requires_grad=True)