summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- examples/demo_OTDA_classes.py 37
-rw-r--r-- examples/demo_optim_OTreg.py 4
-rw-r--r-- ot/da.py 116
-rw-r--r-- ot/optim.py 13
4 files changed, 154 insertions, 16 deletions
diff --git a/examples/demo_OTDA_classes.py b/examples/demo_OTDA_classes.py
index 8a97c80..43fd37e 100644
--- a/examples/demo_OTDA_classes.py
+++ b/examples/demo_OTDA_classes.py
@@ -48,43 +48,62 @@ da_entrop=ot.da.OTDA_sinkhorn()
da_entrop.fit(xs,xt,reg=lambd)
xsts=da_entrop.interp()
-# Group lasso regularization
+# non-convex Group lasso regularization
reg=1e-1
eta=1e0
da_lpl1=ot.da.OTDA_lpl1()
-da_lpl1.fit(xs,ys,xt,reg=lambd,eta=eta)
+da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
xstg=da_lpl1.interp()
+
+# True Group lasso regularization
+reg=1e-1
+eta=1e1
+da_l1l2=ot.da.OTDA_l1l2()
+da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
+xstgl=da_l1l2.interp()
+
+
#%% plot interpolated source samples
-pl.figure(4,(15,10))
+pl.figure(4,(15,8))
param_img={'interpolation':'nearest','cmap':'jet'}
-pl.subplot(2,3,1)
+pl.subplot(2,4,1)
pl.imshow(da_emd.G,**param_img)
pl.title('OT matrix')
-pl.subplot(2,3,2)
+pl.subplot(2,4,2)
pl.imshow(da_entrop.G,**param_img)
pl.title('OT matrix sinkhorn')
-pl.subplot(2,3,3)
+pl.subplot(2,4,3)
pl.imshow(da_lpl1.G,**param_img)
+pl.title('OT matrix non-convex Group Lasso')
+
+pl.subplot(2,4,4)
+pl.imshow(da_l1l2.G,**param_img)
pl.title('OT matrix Group Lasso')
-pl.subplot(2,3,4)
+
+pl.subplot(2,4,5)
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
pl.title('Interp samples')
pl.legend(loc=0)
-pl.subplot(2,3,5)
+pl.subplot(2,4,6)
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
pl.title('Interp samples Sinkhorn')
-pl.subplot(2,3,6)
+pl.subplot(2,4,7)
pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
+pl.title('Interp samples non-convex Group Lasso')
+
+pl.subplot(2,4,8)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
+pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
pl.title('Interp samples Group Lasso') \ No newline at end of file
diff --git a/examples/demo_optim_OTreg.py b/examples/demo_optim_OTreg.py
index a49b00c..0de6b08 100644
--- a/examples/demo_optim_OTreg.py
+++ b/examples/demo_optim_OTreg.py
@@ -60,8 +60,8 @@ ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg')
def f(G): return 0.5*np.sum(G**2)
def df(G): return G
-reg1=1e-3
-reg2=1e-3
+reg1=1e-1
+reg2=1e-1
Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)
diff --git a/ot/da.py b/ot/da.py
index fb40782..81b6a35 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -8,6 +8,7 @@ from .bregman import sinkhorn
from .lp import emd
from .utils import unif,dist,kernel
from .optim import cg
+from .optim import gcg
def indices(a, func):
@@ -122,6 +123,100 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
return transp
+def sinkhorn_l1l2_gl(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
+ """
+ Solve the entropic regularization optimal transport problem with group lasso regularization
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{\mathcal{I}_c,i}\|_2` where :math:`\mathcal{I}_c` are the indices of samples from class c in the source domain and i indexes the target samples.
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ labels_a : np.ndarray (ns,)
+ labels of samples in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term for entropic regularization >0
+ eta : float, optional
+ Regularization term for group lasso regularization >0
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations (inner sinkhorn solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
+ See Also
+ --------
+ ot.optim.gcg : Generalized conditional gradient for OT problems
+
+ """
+ lstlab=np.unique(labels_a)
+
+ def f(G):
+ res=0
+ for i in range(G.shape[1]):
+ for lab in lstlab:
+ temp=G[labels_a==lab,i]
+ res+=np.linalg.norm(temp)
+ return res
+
+ def df(G):
+ W=np.zeros(G.shape)
+ for i in range(G.shape[1]):
+ for lab in lstlab:
+ temp=G[labels_a==lab,i]
+ n=np.linalg.norm(temp)
+ if n:
+ W[labels_a==lab,i]=temp/n
+ return W
+
+
+ return gcg(a,b,M,reg,eta,f,df,G0=None,numItermax = numItermax,numInnerItermax=numInnerItermax, stopThr=stopInnerThr,verbose=verbose,log=log)
+
+
+
def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
"""Joint OT and linear mapping estimation as proposed in [8]
@@ -632,6 +727,27 @@ class OTDA_lpl1(OTDA):
self.M=dist(xs,xt,metric=self.metric)
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
self.computed=True
+
+class OTDA_l1l2(OTDA):
+ """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
+
+
+ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
+ """ Fit regularized domain adaptation between samples in xs and xt (with optional weights). See ot.da.sinkhorn_l1l2_gl for fit parameters"""
+ self.xs=xs
+ self.xt=xt
+
+ if wt is None:
+ wt=unif(xt.shape[0])
+ if ws is None:
+ ws=unif(xs.shape[0])
+
+ self.ws=ws
+ self.wt=wt
+
+ self.M=dist(xs,xt,metric=self.metric)
+ self.G=sinkhorn_l1l2_gl(ws,ys,wt,self.M,reg,eta,**kwargs)
+ self.computed=True
class OTDA_mapping_linear(OTDA):
"""Class for optimal transport with joint linear mapping estimation as in [8]"""
diff --git a/ot/optim.py b/ot/optim.py
index 598e23f..d807824 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -7,6 +7,7 @@ import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
from .bregman import sinkhorn_stabilized
+from .bregman import sinkhorn
# The corresponding scipy function does not work for matrices
def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
@@ -195,7 +196,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
else:
return G
-def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=False):
+def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 10,numInnerItermax = 200,stopThr=1e-9,verbose=False,log=False):
"""
Solve the general regularized OT problem with the generalized conditional gradient
@@ -235,6 +236,8 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations of Sinkhorn
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
@@ -293,16 +296,16 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False
# problem linearization
Mi=M+reg2*df(G)
- # set M positive
- Mi+=Mi.min()
# solve linear program with Sinkhorn
- Gc = sinkhorn_stabilized(a,b, Mi, reg1)
+ #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
+ Gc = sinkhorn(a,b, Mi, reg1, numItermax = numInnerItermax)
deltaG=Gc-G
# line search
- alpha,fc,f_val = line_search_armijo(cost,G,deltaG,Mi,f_val)
+ dcost=Mi+reg1*np.sum(deltaG*(1+np.log(G))) #??
+ alpha,fc,f_val = line_search_armijo(cost,G,deltaG,dcost,f_val)
G=G+alpha*deltaG