summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2019-03-11 16:59:57 +0100
committerGitHub <noreply@github.com>2019-03-11 16:59:57 +0100
commite757b75976ece1e6e53e655852b9f8863e7b6f5a (patch)
treefc0383786c3b997a1c205749018cfc52114fc2fd
parent90d04e0f9a3e70d76c9a42b9bbc9c6f6a168269c (diff)
parentfea0c38c4260788c0359547f7caf75a3d92a2b42 (diff)
Merge pull request #76 from rflamary/bug_log_greenkhorn
Bug Greenkhorn with log=True closes Issue #75
-rw-r--r--.travis.yml2
-rw-r--r--ot/bregman.py14
-rw-r--r--test/test_bregman.py25
3 files changed, 39 insertions, 2 deletions
diff --git a/.travis.yml b/.travis.yml
index 90a0ff4..50ff22c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -26,7 +26,7 @@ before_script: # configure a headless display to test plot generation
# command to install dependencies
install:
- pip install -r requirements.txt
- - pip install flake8 pytest pytest-cov
+ - pip install flake8 pytest "pytest-cov<2.6"
- pip install .
# command to run tests + check syntax style
script:
diff --git a/ot/bregman.py b/ot/bregman.py
index 43340f7..013bc33 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -120,7 +120,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
print('Warning : unknown method using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp(a, b, M, reg, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
return sink()
@@ -499,6 +500,15 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
"""
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ if len(a) == 0:
+ a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ if len(b) == 0:
+ b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
n = a.shape[0]
m = b.shape[0]
@@ -514,7 +524,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
viol = G.sum(1) - a
viol_2 = G.sum(0) - b
stopThr_val = 1
+
if log:
+ log = dict()
log['u'] = u
log['v'] = v
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 14edaf5..90eaf27 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -81,6 +81,31 @@ def test_sinkhorn_variants():
print(G0, G_green)
+def test_sinkhorn_variants_log():
+ # test sinkhorn
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ M = ot.dist(x, x)
+
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ Ges, loges = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+ Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+
+ # check values
+ np.testing.assert_allclose(G0, Gs, atol=1e-05)
+ np.testing.assert_allclose(G0, Ges, atol=1e-05)
+ np.testing.assert_allclose(G0, Gerr)
+ np.testing.assert_allclose(G0, G_green, atol=1e-5)
+ print(G0, G_green)
+
+
def test_bary():
n_bins = 100 # nb bins