summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index aa3cf1a..c33c92c 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -323,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if len(b.shape) < 2:
if method.lower() == 'sinkhorn':
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_log':
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -344,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)