summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorarincbulgur <37184019+arincbulgur@users.noreply.github.com>2022-12-23 11:45:23 -0500
committerGitHub <noreply@github.com>2022-12-23 17:45:23 +0100
commit058d275565f0f65c23e06853812d5eb3a6ebdcef (patch)
tree669668c3df3f556f9af885f00f4bb2c81ccf4929
parentc9578b4cc29b58d9cde9ff586870140021471fc1 (diff)
[MRG] Fix warning bug in sinkhorn2 (#417)
* Pass warn argument downstream in sinkhorn2 method. * releases.md * Fix unittest. Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md3
-rw-r--r--ot/bregman.py18
-rw-r--r--test/test_bregman.py6
3 files changed, 19 insertions, 8 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 4e41af6..c78319d 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -30,8 +30,9 @@ roughly 2^31) (PR #381)
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
- Fixed weak optimal transport docstring (Issue #404, PR #410)
-- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
+- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
PR #413)
+- Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417)
- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
diff --git a/ot/bregman.py b/ot/bregman.py
index aa3cf1a..c33c92c 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -323,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if len(b.shape) < 2:
if method.lower() == 'sinkhorn':
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_log':
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -344,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 0f47c3f..ce15642 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -7,6 +7,7 @@
#
# License: MIT License
+import warnings
from itertools import product
import numpy as np
@@ -58,7 +59,10 @@ def test_convergence_warning(method):
with pytest.warns(UserWarning):
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
with pytest.warns(UserWarning):
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+ ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True)
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False)
def test_not_implemented_method():