From 142f51c080572dd2bfe2b2e9434df648fd7ab018 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 24 Oct 2016 16:32:01 +0200 Subject: correction barcenter --- examples/demo_OT_2D_samples.py | 2 +- ot/__init__.py | 4 ++-- ot/bregman.py | 18 +++++++++++------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/demo_OT_2D_samples.py b/examples/demo_OT_2D_samples.py index f91bbb2..992352c 100644 --- a/examples/demo_OT_2D_samples.py +++ b/examples/demo_OT_2D_samples.py @@ -62,7 +62,7 @@ pl.title('OT matrix') #%% sinkhorn -lambd=.8e-1 +lambd=1e-1 Gs=ot.sinkhorn(a,b,M,lambd) diff --git a/ot/__init__.py b/ot/__init__.py index ce9d157..dd74590 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -8,9 +8,9 @@ import bregman # OT functions from emd import emd -from bregman import sinkhorn +from bregman import sinkhorn,barycenter # utils functions from utils import dist,unif -__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif'] +__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter'] diff --git a/ot/bregman.py b/ot/bregman.py index 1761029..e46506b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -116,7 +116,7 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict #compute Mmax once for all #M = M/np.median(M) # suggested by G. Peyre - K = np.exp(-reg*M) + K = np.exp(-M/reg) cpt = 0 err=1 @@ -124,7 +124,8 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T) u = (geometricMean(UKv)/UKv.T).T - log = {'niter':0, 'all_err':[]} + log['niter']=0 + log['all_err']=[] while (err>tol_error and cpttol_error and cpt