From 7e16b7a80f3a2896351262a02af27a60401b6a5e Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Thu, 3 Nov 2016 17:07:22 +0100
Subject: add mapping estimation with kernels (still debugging)

---
 ot/da.py       | 48 ++++++++++++++++++++++++++++++++++++++++++++----
 ot/datasets.py |  6 +++---
 2 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/ot/da.py b/ot/da.py
index e4aa0be..49fa79e 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -247,7 +247,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
 
     def loss(L,G):
         """Compute full loss"""
-        return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2)
+        return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L))
 
     def solve_L_nobias(G):
         """ solve L problem with fixed G (least square)"""
@@ -450,11 +450,11 @@ class OTDA_lpl1(OTDA):
         self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
         self.computed=True
 
-class OTDA_mapping(OTDA):
+class OTDA_mapping_linear(OTDA):
     """Class for optimal transport with joint linear mapping estimation"""
 
 
-    def __init__(self,metric='sqeuclidean'):
+    def __init__(self):
        """ Class initialization"""
 
 
@@ -463,8 +463,8 @@ class OTDA_mapping(OTDA):
         self.G=0
         self.L=0
         self.bias=False
-        self.metric=metric
         self.computed=False
+        self.metric='sqeuclidean'
 
     def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
         """ Fit domain adaptation between samples is xs and xt (with optional
@@ -473,6 +473,7 @@ class OTDA_mapping(OTDA):
         self.xt=xt
         self.bias=bias
 
+
         self.ws=unif(xs.shape[0])
         self.wt=unif(xt.shape[0])
 
@@ -498,3 +499,42 @@ class OTDA_mapping(OTDA):
             print("Warning, model not fitted yet, returning None")
             return None
 
+class OTDA_mapping_kernel(OTDA_mapping_linear):
+    """Class for optimal transport with joint kernel mapping estimation"""
+
+
+
+    def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs):
+        """ Fit domain adaptation between samples in xs and xt (with optional
+            weights)"""
+        self.xs=xs
+        self.xt=xt
+        self.bias=bias
+
+        self.ws=unif(xs.shape[0])
+        self.wt=unif(xt.shape[0])
+        self.kernel=kerneltype
+        self.sigma=sigma
+        self.kwargs=kwargs
+
+
+        self.G,self.L=joint_OT_mapping_kernel(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs)
+        self.computed=True
+
+
+    def predict(self,x):
+        """ Out of sample mapping using the estimated kernel mapping
+
+        It computes the kernel between the new samples x and the source
+        samples and applies the estimated coefficients L (and bias if any).
+
+        """
+
+        if self.computed:
+            K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs)
+            if self.bias:
+                K=np.hstack((K,np.ones((x.shape[0],1))))
+            return K.dot(self.L)
+        else:
+            print("Warning, model not fitted yet, returning None")
+            return None
\ No newline at end of file
diff --git a/ot/datasets.py b/ot/datasets.py
index 588f501..c750812 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -108,9 +108,9 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
         x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
 
     elif dataset.lower()=='gaussrot' :
-        rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
-        m1=np.array([-1,-1])
-        m2=np.array([1,1])
+        rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]])
+        m1=np.array([-1,1])
+        m2=np.array([1,-1])
         y=np.floor((np.arange(n)*1.0/n*2))+1
         n1=np.sum(y==1)
         n2=np.sum(y==2)
-- 
cgit v1.2.3
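
For reference, here is a minimal usage sketch of the OTDA_mapping_kernel class introduced by this patch. It is illustrative only: the parameter values are made up, get_data_classif is assumed to return an (x, y) pair, and the subject line notes the kernel path is still being debugged.

import numpy as np
from ot.datasets import get_data_classif
from ot.da import OTDA_mapping_kernel

# Two rotated Gaussian classification datasets as source/target domains
# (signature of get_data_classif taken from the hunk above; values are illustrative)
np.random.seed(0)
xs, ys = get_data_classif('gaussrot', 200, theta=0)
xt, yt = get_data_classif('gaussrot', 200, theta=np.pi / 4)

# Jointly estimate the OT coupling G and the kernel mapping coefficients L,
# then apply the learned mapping to samples with predict()
da = OTDA_mapping_kernel()
da.fit(xs, xt, mu=1, eta=1e-2, bias=True, kerneltype='gaussian', sigma=1)
xs_mapped = da.predict(xs)  # K(xs, xs).dot(L), with an extra bias column since bias=True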