summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 10:58:04 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 10:58:04 +0200
commit8cd50c55f398cc371db2ef334c803dec99cc209a (patch)
treed6084ced937c38603dab8a72d0cc5e64aaf83480 /ot/bregman.py
parenta0d8139af3407e567e1dc9a5e8c10d9218ddd185 (diff)
update doc optim+bregman; add log to sinkhorn
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py27
1 files changed, 21 insertions, 6 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index b749b13..08f965b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
"""
-Bregman projection for regularized Otimal transport
+Bregman projections for regularized OT
"""
import numpy as np
-def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
+def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -43,14 +43,18 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
Max number of iterations
stopThr: float, optional
Stop threshol on error (>0)
-
+ verbose : int, optional
+ Print information along iterations
+ log : int, optional
+ record log if True
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
-
+ log: dict
+ log dictionary return only if log==True in parameters
Examples
--------
@@ -91,6 +95,8 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
cpt = 0
+ if log:
+ log={'loss':[]}
# we assume that no distances are null except those of the diagonal of distances
u = np.ones(Nini)/Nini
@@ -124,10 +130,19 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
# we can speed up the process by checking for the error only all the 10th iterations
transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ if log:
+ log['loss'].append(err)
+
+ if verbose:
+ if cpt%200 ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
cpt = cpt +1
#print 'err=',err,' cpt=',cpt
-
- return np.dot(np.diag(u),np.dot(K,np.diag(v)))
+ if log:
+ return np.dot(np.diag(u),np.dot(K,np.diag(v))),log
+ else:
+ return np.dot(np.diag(u),np.dot(K,np.diag(v)))
def geometricBar(weights,alldistribT):