From 058d275565f0f65c23e06853812d5eb3a6ebdcef Mon Sep 17 00:00:00 2001 From: arincbulgur <37184019+arincbulgur@users.noreply.github.com> Date: Fri, 23 Dec 2022 11:45:23 -0500 Subject: [MRG] Fix warning bug in sinkhorn2 (#417) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Pass warn argument downstream in sinkhorn2 method. * releases.md * Fix unittest. Co-authored-by: RĂ©mi Flamary --- test/test_bregman.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'test/test_bregman.py') diff --git a/test/test_bregman.py b/test/test_bregman.py index 0f47c3f..ce15642 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -7,6 +7,7 @@ # # License: MIT License +import warnings from itertools import product import numpy as np @@ -58,7 +59,10 @@ def test_convergence_warning(method): with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False) def test_not_implemented_method(): -- cgit v1.2.3