summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-09-24 10:14:44 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-09-24 10:14:44 +0200
commit7ffd4fef3260e086b0b1ed050f5cb4b83195b122 (patch)
treea0b6acf2ba9adf585eef4faf3e6c190bf22a19fb /ot/bregman.py
parentf3433fda3e8f5c58ec1d7e5623825d4627435ebc (diff)
remove @ for python compatibility+ comments alexandre
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 8538c92..faa6365 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -480,7 +480,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.sinkhorn(a,b,M,1)
+ >>> ot.bregman.greenkhorn(a,b,M,1)
array([[ 0.36552929, 0.13447071],
[ 0.13447071, 0.36552929]])
@@ -505,18 +505,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
m = b.shape[0]
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
+ K = np.empty_like(M)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- u = np.ones(n) / n
- v = np.ones(m) / m
- G = np.diag(u)@K@np.diag(v)
+ u = np.full(n, 1. / n)
+ v = np.full(m, 1. / m)
+ G = u[:, np.newaxis] * K * v[np.newaxis, :]
one_n = np.ones(n)
one_m = np.ones(m)
- viol = G@one_m - a
- viol_2 = G.T@one_n - b
+ viol = G.sum(1) - a
+ viol_2 = G.sum(0) - b
stopThr_val = 1
if log:
log['u'] = u
@@ -532,26 +532,26 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- u[i_1] = a[i_1] / (K[i_1, :]@v)
+ u[i_1] = a[i_1] / (K[i_1, :].dot(v))
G[i_1, :] = u[i_1] * K[i_1, :] * v
- viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1]
+ viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v)
else:
old_v = v[i_2]
- v[i_2] = b[i_2] / (K[:, i_2].T@u)
+ v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
G[:, i_2] = u * K[:, i_2] * v[i_2]
#aviol = (G@one_m - a)
#aviol_2 = (G.T@one_n - b)
viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u
- viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2]
+ viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
- if log:
- log['u'] = u
- log['v'] = v
+ if log:
+ log['u'] = u
+ log['v'] = v
if log:
return G, log