diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 14:53:52 +0100 |
commit | 566645ad184e1205f7f666ea2f19021254c33d74 (patch) | |
tree | 1d740a6771ab515d0cbfe9f21fde801398eb19b6 /ot/da.py | |
parent | 981351165dbab740145d109b00782f0c41f2244b (diff) |
add mapping estimation (still debugging)
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 174 |
1 files changed, 165 insertions, 9 deletions
@@ -6,13 +6,15 @@ Domain adaptation with optimal transport import numpy as np from .bregman import sinkhorn from .lp import emd -from .utils import unif,dist +from .utils import unif,dist,kernel from .optim import cg def indices(a, func): return [i for (i, val) in enumerate(a) if func(val)] + + def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False): """ Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization @@ -129,13 +131,15 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos if bias: xs1=np.hstack((xs,np.ones((ns,1)))) - I=eta*np.eye(d+1) + xstxs=xs1.T.dot(xs1) + I=np.eye(d+1) I[-1]=0 I0=I[:,:-1] sel=lambda x : x[:-1,:] else: xs1=xs - I=eta*np.eye(d) + xstxs=xs1.T.dot(xs1) + I=np.eye(d) I0=I sel=lambda x : x @@ -143,20 +147,22 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos log={'err':[]} a,b=unif(ns),unif(nt) - M=dist(xs,xt) + M=dist(xs,xt)*ns G=emd(a,b,M) vloss=[] def loss(L,G): + """Compute full loss""" return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2) def solve_L(G): - """ solve problem with fixed G""" + """ solve L problem with fixed G (least square)""" xst=ns*G.dot(xt) - return np.linalg.solve(xs1.T.dot(xs1)+I,xs1.T.dot(xst)+I0) + return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0) def solve_G(L,G0): + """Update G with CG algorithm""" xsi=xs1.dot(L) def f(G): return np.sum((xsi-ns*G.dot(xt))**2) @@ -175,8 +181,11 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0)) - # regul matrix - loop=1 + # init loop + if numItermax>0: + loop=1 + else: + loop=0 it=0 while loop: @@ -191,6 +200,9 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos vloss.append(loss(L,G)) + if it>=numItermax: + loop=0 + if abs(vloss[-1]-vloss[-2])<stopThr: loop=0 @@ -198,11 +210,106 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos if it%20==0: print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2]))) + if log: + log['loss']=vloss + return G,L,log + else: + return G,L - return G,L +def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs): + """Joint Ot and mapping estimation (uniform weights and ) + """ + ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1] + if bias: + K= + xs1=np.hstack((xs,np.ones((ns,1)))) + xstxs=xs1.T.dot(xs1) + I=np.eye(d+1) + I[-1]=0 + I0=I[:,:-1] + sel=lambda x : x[:-1,:] + else: + xs1=xs + xstxs=xs1.T.dot(xs1) + I=np.eye(d) + I0=I + sel=lambda x : x + + if log: + log={'err':[]} + + a,b=unif(ns),unif(nt) + M=dist(xs,xt)*ns + G=emd(a,b,M) + + vloss=[] + + def loss(L,G): + """Compute full loss""" + return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2) + + def solve_L(G): + """ solve L problem with fixed G (least square)""" + xst=ns*G.dot(xt) + return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0) + + def solve_G(L,G0): + """Update G with CG algorithm""" + xsi=xs1.dot(L) + def f(G): + return np.sum((xsi-ns*G.dot(xt))**2) + def df(G): + return -2*ns*(xsi-ns*G.dot(xt)).dot(xt.T) + G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr) + return G + + + L=solve_L(G) + + vloss.append(loss(L,G)) + + if verbose: + print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) + print('{:5d}|{:8e}|{:8e}'.format(0,vloss[-1],0)) + + + # init loop + if numItermax>0: + loop=1 + else: + loop=0 + it=0 + + while loop: + + it+=1 + + # update G + G=solve_G(L,G) + + #update L + L=solve_L(G) + + vloss.append(loss(L,G)) + + if it>=numItermax: + loop=0 + + if abs(vloss[-1]-vloss[-2])<stopThr: + loop=0 + + if verbose: + if it%20==0: + print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) + print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2]))) + if log: + log['loss']=vloss + return G,L,log + else: + return G,L class OTDA(object): @@ -294,6 +401,7 @@ class OTDA(object): class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic regularization""" + def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs): """ Fit domain adaptation between samples is xs and xt (with optional weights)""" @@ -335,3 +443,51 @@ class OTDA_lpl1(OTDA): self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) self.computed=True +class OTDA_mapping(OTDA): + """Class for optimal transport with joint linear mapping estimation""" + + + def __init__(self,metric='sqeuclidean'): + """ Class initialization""" + + + self.xs=0 + self.xt=0 + self.G=0 + self.L=0 + self.bias=False + self.metric=metric + self.computed=False + + def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + self.bias=bias + + self.ws=unif(xs.shape[0]) + self.wt=unif(xt.shape[0]) + + self.G,self.L=joint_OT_mapping_linear(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs) + self.computed=True + + def mapping(self): + return lambda x: self.predict(x) + + + def predict(self,x): + """ Out of sample mapping using the formulation from Ferradans + + It basically find the source sample the nearset to the nex sample and + apply the difference to the displaced source sample. + + """ + if self.computed: + if self.bias: + x=np.hstack((x,np.ones((x.shape[0],1)))) + return x.dot(self.L) # aply the delta to the interpolation + else: + print("Warning, model not fitted yet, returning None") + return None + |