summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-07-24 15:54:56 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-07-24 15:54:56 +0200
commitf4bfeb73da098384aa67599e7f729fb683a1bcc9 (patch)
treec80fc4099322ff430992b0ab1d6465b90fb6b058 /ot
parentace77962d2ae6407916ee7e4377f5c7ed0a8d8f2 (diff)
ensum tets marginals sinkhorn
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py5
1 files changed, 1 insertions, 4 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 58e74de..c755f51 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -396,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if nbb: # return only loss
- res = np.zeros((nbb))
- for i in range(nbb):
- res[i] = np.sum(
- u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
+ res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else: