diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-04-11 16:26:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-11 16:26:30 +0200 |
commit | 486b0d6397182a57cd53651dca87fcea89747490 (patch) | |
tree | 15ce87f3b2a215038454b940b528ad7328e2058f /ot | |
parent | ac4cf442735ed4c0d5405ad861eddaa02afd4edd (diff) |
[MRG] Center gradients for mass of emd2 and gw2 (#363)
* center gradients for mass of emd2 and gw2
* debug fgw gradient
* debug fgw
Diffstat (limited to 'ot')
-rw-r--r-- | ot/gromov.py | 7 | ||||
-rw-r--r-- | ot/lp/__init__.py | 7 |
2 files changed, 9 insertions, 5 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index c5a82d1..55ab0bd 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -551,7 +551,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gC1 = nx.from_numpy(gC1, type_as=C10)
gC2 = nx.from_numpy(gC2, type_as=C10)
gw = nx.set_gradients(gw, (p0, q0, C10, C20),
- (log_gw['u'], log_gw['v'], gC1, gC2))
+ (log_gw['u'] - nx.mean(log_gw['u']),
+ log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
if log:
return gw, log_gw
@@ -793,7 +794,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 gC1 = nx.from_numpy(gC1, type_as=C10)
gC2 = nx.from_numpy(gC2, type_as=C10)
fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
- (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T0))
if log:
return fgw_dist, log_fgw
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index abf7fe0..390c32d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -517,7 +517,8 @@ def emd2(a, b, M, processes=1, log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'], log['v'], G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -540,8 +541,8 @@ def emd2(a, b, M, processes=1, ) G = nx.from_numpy(G, type_as=type_as) cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (nx.from_numpy(u, type_as=type_as), - nx.from_numpy(v, type_as=type_as), G)) + (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), G)) check_result(result_code) return cost |