summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/demo_OT_2D_samples.py2
-rw-r--r--ot/__init__.py4
-rw-r--r--ot/bregman.py18
3 files changed, 14 insertions, 10 deletions
diff --git a/examples/demo_OT_2D_samples.py b/examples/demo_OT_2D_samples.py
index f91bbb2..992352c 100644
--- a/examples/demo_OT_2D_samples.py
+++ b/examples/demo_OT_2D_samples.py
@@ -62,7 +62,7 @@ pl.title('OT matrix')
#%% sinkhorn
-lambd=.8e-1
+lambd=1e-1
Gs=ot.sinkhorn(a,b,M,lambd)
diff --git a/ot/__init__.py b/ot/__init__.py
index ce9d157..dd74590 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -8,9 +8,9 @@ import bregman
# OT functions
from emd import emd
-from bregman import sinkhorn
+from bregman import sinkhorn,barycenter
# utils functions
from utils import dist,unif
-__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif']
+__all__ = ["emd","sinkhorn","utils",'datasets','bregman','plot','dist','unif','barycenter']
diff --git a/ot/bregman.py b/ot/bregman.py
index 1761029..e46506b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -116,7 +116,7 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict
#compute Mmax once for all
#M = M/np.median(M) # suggested by G. Peyre
- K = np.exp(-reg*M)
+ K = np.exp(-M/reg)
cpt = 0
err=1
@@ -124,7 +124,8 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict
UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T)
u = (geometricMean(UKv)/UKv.T).T
- log = {'niter':0, 'all_err':[]}
+ log['niter']=0
+ log['all_err']=[]
while (err>tol_error and cpt<numItermax):
cpt = cpt +1
@@ -149,16 +150,19 @@ def unmixBregman(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1
alpha : how much should we trust the prior ? ([0,1])
"""
- M = M/np.median(M)
- K = np.exp(-reg*M)
+ #M = M/np.median(M)
+ K = np.exp(-M/reg)
- M0 = M0/np.median(M0)
- K0 = np.exp(-reg0*M0)
+ #M0 = M0/np.median(M0)
+ K0 = np.exp(-M0/reg0)
old = h0
err=1
cpt=0
- log = {'niter':0, 'all_err':[]}
+ #log = {'niter':0, 'all_err':[]}
+ log['niter']=0
+ log['all_err']=[]
+
while (err>tol_error and cpt<numItermax):
K = projC(K,distrib)