summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2019-10-24 14:39:12 +0200
committerGitHub <noreply@github.com>2019-10-24 14:39:12 +0200
commit65ca6bfde77dd11d84cbd151fe9ff98454f8e206 (patch)
treed2c22473d9dbc48ad16ce2b95863eaa2ae6242b3 /ot/bregman.py
parent5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff)
parent161d68a79bc528a0d87e421f67a419cd757c7fba (diff)
Merge pull request #106 from hichamjanati/fix-weighted-bar
MRG: Forgotten weights arg in barycenter funcs
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 2cd832b..ba5c7ba 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1037,11 +1037,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
"""
if method.lower() == 'sinkhorn':
- return barycenter_sinkhorn(A, M, reg, numItermax=numItermax,
+ return barycenter_sinkhorn(A, M, reg, weights=weights,
+ numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- return barycenter_stabilized(A, M, reg, numItermax=numItermax,
+ return barycenter_stabilized(A, M, reg, weights=weights,
+ numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
else: