From 7ffd4fef3260e086b0b1ed050f5cb4b83195b122 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Sep 2018 10:14:44 +0200 Subject: remove @ for python compatibility+ comments alexandre --- ot/bregman.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) (limited to 'ot') diff --git a/ot/bregman.py b/ot/bregman.py index 8538c92..faa6365 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -480,7 +480,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.sinkhorn(a,b,M,1) + >>> ot.bregman.greenkhorn(a,b,M,1) array([[ 0.36552929, 0.13447071], [ 0.13447071, 0.36552929]]) @@ -505,18 +505,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= m = b.shape[0] # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) + K = np.empty_like(M) np.divide(M, -reg, out=K) np.exp(K, out=K) - u = np.ones(n) / n - v = np.ones(m) / m - G = np.diag(u)@K@np.diag(v) + u = np.full(n, 1. / n) + v = np.full(m, 1. / m) + G = u[:, np.newaxis] * K * v[np.newaxis, :] one_n = np.ones(n) one_m = np.ones(m) - viol = G@one_m - a - viol_2 = G.T@one_n - b + viol = G.sum(1) - a + viol_2 = G.sum(0) - b stopThr_val = 1 if log: log['u'] = u @@ -532,26 +532,26 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= if m_viol_1 > m_viol_2: old_u = u[i_1] - u[i_1] = a[i_1] / (K[i_1, :]@v) + u[i_1] = a[i_1] / (K[i_1, :].dot(v)) G[i_1, :] = u[i_1] * K[i_1, :] * v - viol[i_1] = u[i_1] * K[i_1, :]@v - a[i_1] + viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1] viol_2 = viol_2 + (K[i_1, :].T * (u[i_1] - old_u) * v) else: old_v = v[i_2] - v[i_2] = b[i_2] / (K[:, i_2].T@u) + v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) G[:, i_2] = u * K[:, i_2] * v[i_2] #aviol = (G@one_m - a) #aviol_2 = (G.T@one_n - b) viol = viol + (-old_v + v[i_2]) * K[:, i_2] * u - viol_2[i_2] = v[i_2] * K[:, i_2]@u - b[i_2] + viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) - if log: - log['u'] = u - log['v'] = v + if log: + log['u'] = u + log['v'] = v if log: return G, log -- cgit v1.2.3