summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 10:58:04 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 10:58:04 +0200
commit8cd50c55f398cc371db2ef334c803dec99cc209a (patch)
treed6084ced937c38603dab8a72d0cc5e64aaf83480
parenta0d8139af3407e567e1dc9a5e8c10d9218ddd185 (diff)
update doc optim+bregman; add log to sinkhorn
-rw-r--r--ot/__init__.py7
-rw-r--r--ot/bregman.py27
-rw-r--r--ot/lp/emd.cpp2
-rw-r--r--ot/optim.py58
4 files changed, 78 insertions, 16 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 87119e5..863f408 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,13 +1,14 @@
# Python Optimal Transport toolbox
# All submodules and packages
+from . import lp
+from . import bregman
+from . import optim
from . import utils
from . import datasets
from . import plot
-from . import bregman
-from . import lp
from . import da
-from . import optim
+
# OT functions
diff --git a/ot/bregman.py b/ot/bregman.py
index b749b13..08f965b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
"""
-Bregman projection for regularized Otimal transport
+Bregman projections for regularized OT
"""
import numpy as np
-def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
+def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -43,14 +43,18 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
Max number of iterations
stopThr: float, optional
Stop threshol on error (>0)
-
+ verbose : int, optional
+ Print information along iterations
+ log : int, optional
+ record log if True
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
-
+ log: dict
+ log dictionary, returned only if log==True in parameters
Examples
--------
@@ -91,6 +95,8 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
cpt = 0
+ if log:
+ log={'loss':[]}
# we assume that no distances are null except those of the diagonal of distances
u = np.ones(Nini)/Nini
@@ -124,10 +130,19 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
# we can speed up the process by checking for the error only all the 10th iterations
transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
+ if log:
+ log['loss'].append(err)
+
+ if verbose:
+ if cpt%200 ==0:
+ print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
+ print('{:5d}|{:8e}|'.format(cpt,err))
cpt = cpt +1
#print 'err=',err,' cpt=',cpt
-
- return np.dot(np.diag(u),np.dot(K,np.diag(v)))
+ if log:
+ return np.dot(np.diag(u),np.dot(K,np.diag(v))),log
+ else:
+ return np.dot(np.diag(u),np.dot(K,np.diag(v)))
def geometricBar(weights,alldistribT):
diff --git a/ot/lp/emd.cpp b/ot/lp/emd.cpp
index 26d243f..6db54bb 100644
--- a/ot/lp/emd.cpp
+++ b/ot/lp/emd.cpp
@@ -1229,7 +1229,7 @@ static PyObject *__pyx_codeobj__8;
/* Python wrapper */
static PyObject *__pyx_pw_2ot_2lp_3emd_1emd_c(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
-static char __pyx_doc_2ot_2lp_3emd_emd_c[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n gamm=emd(a,b,M)\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. \\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray\n source histogram \n b : (nt,) ndarray\n target histogram\n M : (ns,nt) ndarray\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n ";
+static char __pyx_doc_2ot_2lp_3emd_emd_c[] = "\n Solves the Earth Movers distance problem and returns the optimal transport matrix\n \n gamm=emd(a,b,M)\n \n .. math::\n \\gamma = arg\\min_\\gamma <\\gamma,M>_F \n \n s.t. \\gamma 1 = a\n \n \\gamma^T 1= b \n \n \\gamma\\geq 0\n where :\n \n - M is the metric cost matrix\n - a and b are the sample weights\n \n Parameters\n ----------\n a : (ns,) ndarray, float64\n source histogram \n b : (nt,) ndarray, float64\n target histogram\n M : (ns,nt) ndarray, float64\n loss matrix \n \n \n Returns\n -------\n gamma: (ns x nt) ndarray\n Optimal transportation matrix for the given parameters\n \n ";
static PyMethodDef __pyx_mdef_2ot_2lp_3emd_1emd_c = {"emd_c", (PyCFunction)__pyx_pw_2ot_2lp_3emd_1emd_c, METH_VARARGS|METH_KEYWORDS, __pyx_doc_2ot_2lp_3emd_emd_c};
static PyObject *__pyx_pw_2ot_2lp_3emd_1emd_c(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
PyArrayObject *__pyx_v_a = 0;
diff --git a/ot/optim.py b/ot/optim.py
index e6373ce..d1bf672 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Created on Wed Oct 26 15:08:19 2016
-
-@author: rflamary
+Optimization algorithms for OT
"""
import numpy as np
@@ -12,6 +10,42 @@ from lp import emd
# The corresponding scipy function does not work for matrices
def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
+ """
+ Armijo linesearch function that works with matrices
+
+ find an approximate minimum of f(xk+alpha*pk) that satisfies the
+ armijo conditions.
+
+ Parameters
+ ----------
+
+ f : function
+ loss function
+ xk : np.ndarray
+ initial position
+ pk : np.ndarray
+ descent direction
+ gfk : np.ndarray
+ gradient of f at xk
+ old_fval: float
+ loss value at xk
+ args : tuple, optional
+ arguments given to f
+ c1 : float, optional
+ c1 const in armijo rule (>0)
+ alpha0 : float, optional
+ initial step (>0)
+
+ Returns
+ -------
+ alpha : float
+ step that satisfies the armijo conditions
+ fc : int
+ nb of function call
+ fa : float
+ loss value at step alpha
+
+ """
xk = np.atleast_1d(xk)
fc = [0]
@@ -61,14 +95,26 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
samples in the target domain
M : np.ndarray (ns,nt)
loss matrix
- reg: float()
+ reg : float
Regularization term >0
-
+ G0 : np.ndarray (ns,nt), optional
+ initial guess (default is indep joint density)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : int, optional
+ Print information along iterations
+ log : int, optional
+ record log if True
Returns
-------
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ log: dict
+ log dictionary, returned only if log==True in parameters
+
References
----------
@@ -77,7 +123,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
See Also
--------
- ot.emd.emd : Unregularized optimal ransport
+ ot.lp.emd : Unregularized optimal transport
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""