summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/demo_barycenter_1D.py2
-rw-r--r--ot/bregman.py91
-rw-r--r--ot/lp/__init__.py5
-rw-r--r--ot/optim.py4
4 files changed, 83 insertions, 19 deletions
diff --git a/examples/demo_barycenter_1D.py b/examples/demo_barycenter_1D.py
index c9f63a2..2376f7b 100644
--- a/examples/demo_barycenter_1D.py
+++ b/examples/demo_barycenter_1D.py
@@ -43,7 +43,7 @@ bary_l2=A.mean(1)
# wasserstein
reg=1e-3
-bary_wass,log=ot.bregman.barycenter(A,M,reg)
+bary_wass=ot.bregman.barycenter(A,M,reg)
pl.figure(2)
pl.clf()
diff --git a/ot/bregman.py b/ot/bregman.py
index 08f965b..b6cdf80 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -43,9 +43,9 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
Max number of iterations
stopThr: float, optional
Stop threshol on error (>0)
- verbose : int, optional
+ verbose : bool, optional
Print information along iterations
- log : int, optional
+ log : bool, optional
record log if True
@@ -96,7 +96,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
cpt = 0
if log:
- log={'loss':[]}
+ log={'err':[]}
# we assume that no distances are null except those of the diagonal of distances
u = np.ones(Nini)/Nini
@@ -131,7 +131,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
if log:
- log['loss'].append(err)
+ log['err'].append(err)
if verbose:
if cpt%200 ==0:
@@ -146,10 +146,12 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
def geometricBar(weights,alldistribT):
+ """return the weighted geometric mean of distributions"""
assert(len(weights)==alldistribT.shape[1])
return np.exp(np.dot(np.log(alldistribT),weights.T))
def geometricMean(alldistribT):
+ """return the geometric mean of distributions"""
return np.exp(np.mean(np.log(alldistribT),axis=1))
def projR(gamma,p):
@@ -161,16 +163,66 @@ def projC(gamma,q):
return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10))
-def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict()):
- """Compute the Regularizzed wassersteien barycenter of distributions A"""
+def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=False,log=False):
+ """Compute the entropic regularized Wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : np.ndarray (d,n)
+ n training distributions of size d
+ M : np.ndarray (d,d)
+ loss matrix for OT
+ reg: float
+ Regularization term >0
+ numItermax: int, optional
+ Max number of iterations
+ stopThr: float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a: (d,) ndarray
+ Wasserstein barycenter
+ log: dict
+ log dictionary, returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+
+
+ """
if weights is None:
weights=np.ones(A.shape[1])/A.shape[1]
else:
assert(len(weights)==A.shape[1])
+
+ if log:
+ log={'err':[]}
- #compute Mmax once for all
#M = M/np.median(M) # suggested by G. Peyre
K = np.exp(-M/reg)
@@ -180,19 +232,28 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict
UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T)
u = (geometricMean(UKv)/UKv.T).T
- log['niter']=0
- log['all_err']=[]
-
- while (err>tol_error and cpt<numItermax):
+ while (err>stopThr and cpt<numItermax):
cpt = cpt +1
UKv=u*np.dot(K,np.divide(A,np.dot(K,u)))
u = (u.T*geometricBar(weights,UKv)).T/UKv
+
if cpt%10==1:
err=np.sum(np.std(UKv,axis=1))
- log['all_err'].append(err)
-
- log['niter']=cpt
- return geometricBar(weights,UKv),log
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt%200 ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
+
+ if log:
+ log['niter']=cpt
+ return geometricBar(weights,UKv),log
+ else:
+ return geometricBar(weights,UKv)
def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()):
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 72b4cb8..6c7822a 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -1,5 +1,8 @@
+"""
+Solvers for the original linear program OT problem
+"""
-
+# import compiled emd
from .emd import emd_c
import numpy as np
diff --git a/ot/optim.py b/ot/optim.py
index d1bf672..632eac1 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -103,9 +103,9 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
Max number of iterations
stopThr : float, optional
Stop threshol on error (>0)
- verbose : int, optional
+ verbose : bool, optional
Print information along iterations
- log : int, optional
+ log : bool, optional
record log if True
Returns