diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2019-03-11 16:59:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-03-11 16:59:57 +0100 |
commit | e757b75976ece1e6e53e655852b9f8863e7b6f5a (patch) | |
tree | fc0383786c3b997a1c205749018cfc52114fc2fd /ot | |
parent | 90d04e0f9a3e70d76c9a42b9bbc9c6f6a168269c (diff) | |
parent | fea0c38c4260788c0359547f7caf75a3d92a2b42 (diff) |
Merge pull request #76 from rflamary/bug_log_greenkhorn
Bug Greenkhorn with log=True
closes Issue #75
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 43340f7..013bc33 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -120,7 +120,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, print('Warning : unknown method using classic Sinkhorn Knopp') def sink(): - return sinkhorn_knopp(a, b, M, reg, **kwargs) + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) return sink() @@ -499,6 +500,15 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= """ + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + if len(a) == 0: + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + n = a.shape[0] m = b.shape[0] @@ -514,7 +524,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= viol = G.sum(1) - a viol_2 = G.sum(0) - b stopThr_val = 1 + if log: + log = dict() log['u'] = u log['v'] = v |