diff options
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r-- | ot/lp/__init__.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index abf7fe0..390c32d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -517,7 +517,8 @@ def emd2(a, b, M, processes=1, log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'], log['v'], G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -540,8 +541,8 @@ def emd2(a, b, M, processes=1, ) G = nx.from_numpy(G, type_as=type_as) cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (nx.from_numpy(u, type_as=type_as), - nx.from_numpy(v, type_as=type_as), G)) + (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), G)) check_result(result_code) return cost |