diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index aa3cf1a..c33c92c 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -323,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if len(b.shape) < 2: if method.lower() == 'sinkhorn': res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -344,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) |