summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-03-11 10:39:03 +0100
committerRémi Flamary <remi.flamary@gmail.com>2019-03-11 10:39:03 +0100
commit42a501c5d839c010bbfa3a4440b43cb4f9775fc7 (patch)
treef885bb6b6edd9b00e02a35b20d1afbcb749e9923 /ot
parent90d04e0f9a3e70d76c9a42b9bbc9c6f6a168269c (diff)
add test sinkhorn+log
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 43340f7..013bc33 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -120,7 +120,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
print('Warning : unknown method using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
return sink()
@@ -499,6 +500,15 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
"""
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
n = a.shape[0]
m = b.shape[0]
@@ -514,7 +524,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
viol = G.sum(1) - a
viol_2 = G.sum(0) - b
stopThr_val = 1
+
if log:
+ log = dict()
log['u'] = u
log['v'] = v