diff options
-rw-r--r-- | docs/source/all.rst | 8 | ||||
-rw-r--r-- | ot/__init__.py | 15 | ||||
-rw-r--r-- | ot/bregman.py | 8 | ||||
-rw-r--r-- | ot/da.py | 91 | ||||
-rw-r--r-- | ot/lp/__init__.py | 77 |
5 files changed, 142 insertions, 57 deletions
diff --git a/docs/source/all.rst b/docs/source/all.rst index 30f5add..d5733b8 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -6,8 +6,6 @@ Python modules ot -- -This module provide easy access to solvers for the most common OT problems - .. automodule:: ot :members: @@ -28,6 +26,12 @@ ot.optim .. automodule:: ot.optim :members: +ot.da +-------- + +.. automodule:: ot.da + :members: + ot.utils -------- diff --git a/ot/__init__.py b/ot/__init__.py index 0602eed..72c820a 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,22 +1,21 @@ -# Python Optimal Transport toolbox +"""Python Optimal Transport toolbox""" # All submodules and packages -from . import lp +from . import lp from . import bregman -from . import optim +from . import optim from . import utils from . import datasets from . import plot from . import da - - # OT functions from .lp import emd -from .bregman import sinkhorn,barycenter +from .bregman import sinkhorn, barycenter from .da import sinkhorn_lpl1_mm # utils functions -from .utils import dist,unif +from .utils import dist, unif -__all__ = ["emd","sinkhorn","utils",'datasets','bregman','lp','plot','dist','unif','barycenter','sinkhorn_lpl1_mm','da','optim'] +__all__ = ["emd", "sinkhorn", "utils", 'datasets', 'bregman', 'lp', 'plot', + 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim'] diff --git a/ot/bregman.py b/ot/bregman.py index ad9a67a..2d82ae4 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -6,12 +6,12 @@ Bregman projections for regularized OT import numpy as np -def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False): +def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False): """ - Solve the entropic regularization optimal transport problem and return the OT matrix - + Solve the entropic regularization optimal transport problem + The function solves the following optimization problem: - + .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) @@ -1,6 +1,8 @@ +# -*- coding: utf-8 -*- """ -domain adaptation with optimal transport +Domain adaptation with optimal transport """ + import numpy as np from .bregman import sinkhorn @@ -9,7 +11,88 @@ from .bregman import sinkhorn def indices(a, func): return [i for (i, val) in enumerate(a) if func(val)] -def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1): +def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False): + """ + Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain. + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + labels_a : np.ndarray (ns,) + labels of samples in the source domain + b : np.ndarray (nt,) + samples in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg: float + Regularization term for entropic regularization >0 + eta: float, optional + Regularization term for group lasso regularization >0 + numItermax: int, optional + Max number of iterations + numInnerItermax: int, optional + Max number of iterations (inner sinkhorn solver) + stopInnerThr: float, optional + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma: (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.sinkhorn(a,b,M,1) + array([[ 0.36552929, 0.13447071], + [ 0.13447071, 0.36552929]]) + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + + """ p=0.5 epsilon = 1e-3 @@ -25,9 +108,9 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1): W=np.zeros(M.shape) - for cpt in range(10): + for cpt in range(numItermax): Mreg = M + eta*W - transp=sinkhorn(a,b,Mreg,reg,numItermax = 200) + transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr) # the transport has been computed. Check if classes are really separated W = np.ones((Nini,Nfin)) for t in range(Nfin): diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 6c7822a..2adf937 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -1,31 +1,30 @@ +# -*- coding: utf-8 -*- """ Solvers for the original linear program OT problem """ +import numpy as np # import compiled emd from .emd import emd_c -import numpy as np -def emd(a,b,M): - """ - Solves the Earth Movers distance problem and returns the optimal transport matrix - - + +def emd(a, b, M): + """Solves the Earth Movers distance problem and returns the OT matrix + + .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F - + \gamma = arg\min_\gamma <\gamma,M>_F + s.t. \gamma 1 = a - - \gamma^T 1= b - + \gamma^T 1= b \gamma\geq 0 where : - + - M is the metric cost matrix - a and b are the sample weights - + Uses the algorithm proposed in [1]_ - + Parameters ---------- a : (ns,) ndarray, float64 @@ -33,47 +32,47 @@ def emd(a,b,M): b : (nt,) ndarray, float64 Target histogram (uniform weigth if empty list) M : (ns,nt) ndarray, float64 - loss matrix - + loss matrix + Returns ------- gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters - - + + Examples -------- - + Simple example with obvious solution. The function emd accepts lists and - perform automatic conversion to numpy arrays - + perform automatic conversion to numpy arrays + >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] >>> ot.emd(a,b,M) array([[ 0.5, 0. ], [ 0. , 0.5]]) - + References ---------- - - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. - + + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + See Also -------- ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - - - """ - a=np.asarray(a,dtype=np.float64) - b=np.asarray(b,dtype=np.float64) - M=np.asarray(M,dtype=np.float64) - - if len(a)==0: - a=np.ones((M.shape[0],),dtype=np.float64)/M.shape[0] - if len(b)==0: - b=np.ones((M.shape[1],),dtype=np.float64)/M.shape[1] - - return emd_c(a,b,M) + ot.optim.cg : General regularized OT""" + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + + if len(a) == 0: + a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0] + if len(b) == 0: + b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1] + return emd_c(a, b, M) |