summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati@inria.fr>2019-10-19 12:05:59 +0200
committerHicham Janati <hicham.janati@inria.fr>2019-10-19 12:05:59 +0200
commita919f960df57b4ea4beff57a3f7262b8064d8159 (patch)
treea19f4b94657f35fbf62cb21cdd5e5c36b5667a30
parent2f3741299989ffb105bed986f7a85d567fa6cb6a (diff)
same for unbalanced
-rw-r--r--ot/unbalanced.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index d516dfc..978df08 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -1002,12 +1002,14 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
if method.lower() == 'sinkhorn':
return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ weights=weights,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
@@ -1015,6 +1017,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
return barycenter_unbalanced(A, M, reg, reg_m,
+ weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)