summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index abf7fe0..390c32d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -517,7 +517,8 @@ def emd2(a, b, M, processes=1,
log['warning'] = result_code_string
log['result_code'] = result_code
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (log['u'], log['v'], G))
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -540,8 +541,8 @@ def emd2(a, b, M, processes=1,
)
G = nx.from_numpy(G, type_as=type_as)
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (nx.from_numpy(u, type_as=type_as),
- nx.from_numpy(v, type_as=type_as), G))
+ (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
+ nx.from_numpy(v - np.mean(v), type_as=type_as), G))
check_result(result_code)
return cost