diff options
author | arincbulgur <37184019+arincbulgur@users.noreply.github.com> | 2022-12-23 11:45:23 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-23 17:45:23 +0100 |
commit | 058d275565f0f65c23e06853812d5eb3a6ebdcef (patch) | |
tree | 669668c3df3f556f9af885f00f4bb2c81ccf4929 | |
parent | c9578b4cc29b58d9cde9ff586870140021471fc1 (diff) |
[MRG] Fix warning bug in sinkhorn2 (#417)
* Pass warn argument downstream in sinkhorn2 method.
* releases.md
* Fix unittest.
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | RELEASES.md | 3 | ||||
-rw-r--r-- | ot/bregman.py | 18 | ||||
-rw-r--r-- | test/test_bregman.py | 6 |
3 files changed, 19 insertions, 8 deletions
diff --git a/RELEASES.md b/RELEASES.md index 4e41af6..c78319d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -30,8 +30,9 @@ roughly 2^31) (PR #381) - Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402) - Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409) - Fixed weak optimal transport docstring (Issue #404, PR #410) -- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, +- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, PR #413) +- Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417) - Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls that explicitly specified `stopThr=1e-9` (Issue #421, PR #422). - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425) diff --git a/ot/bregman.py b/ot/bregman.py index aa3cf1a..c33c92c 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -323,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if len(b.shape) < 2: if method.lower() == 'sinkhorn': res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -344,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) diff --git a/test/test_bregman.py b/test/test_bregman.py index 0f47c3f..ce15642 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -7,6 +7,7 @@ # # License: MIT License +import warnings from itertools import product import numpy as np @@ -58,7 +59,10 @@ def test_convergence_warning(method): with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False) def test_not_implemented_method(): |