summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:07:22 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:07:22 +0100
commit7e16b7a80f3a2896351262a02af27a60401b6a5e (patch)
tree47d2ba913bc5ddc726441d284638db9d7c2cff86 /ot
parent86b1c88eb0c2c43853bca38de96d2278cc90ceba (diff)
add mapping estimation with kernels (still debugging)
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py48
-rw-r--r--ot/datasets.py6
2 files changed, 47 insertions, 7 deletions
diff --git a/ot/da.py b/ot/da.py
index e4aa0be..49fa79e 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -247,7 +247,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
def loss(L,G):
"""Compute full loss"""
- return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2)
+ return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
@@ -450,11 +450,11 @@ class OTDA_lpl1(OTDA):
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
self.computed=True
-class OTDA_mapping(OTDA):
+class OTDA_mapping_linear(OTDA):
"""Class for optimal transport with joint linear mapping estimation"""
- def __init__(self,metric='sqeuclidean'):
+ def __init__(self):
""" Class initialization"""
@@ -463,8 +463,8 @@ class OTDA_mapping(OTDA):
self.G=0
self.L=0
self.bias=False
- self.metric=metric
self.computed=False
+ self.metric='sqeuclidean'
def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
""" Fit domain adaptation between samples is xs and xt (with optional
@@ -473,6 +473,7 @@ class OTDA_mapping(OTDA):
self.xt=xt
self.bias=bias
+
self.ws=unif(xs.shape[0])
self.wt=unif(xt.shape[0])
@@ -498,3 +499,42 @@ class OTDA_mapping(OTDA):
print("Warning, model not fitted yet, returning None")
return None
+class OTDA_mapping_kernel(OTDA_mapping_linear):
+ """Class for optimal transport with joint linear mapping estimation"""
+
+
+
+ def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs):
+ """ Fit domain adaptation between samples is xs and xt (with optional
+ weights)"""
+ self.xs=xs
+ self.xt=xt
+ self.bias=bias
+
+ self.ws=unif(xs.shape[0])
+ self.wt=unif(xt.shape[0])
+ self.kernel=kerneltype
+ self.sigma=sigma
+ self.kwargs=kwargs
+
+
+ self.G,self.L=joint_OT_mapping_kernel(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs)
+ self.computed=True
+
+
+ def predict(self,x):
+ """ Out of sample mapping using the formulation from Ferradans
+
+ It basically find the source sample the nearset to the nex sample and
+ apply the difference to the displaced source sample.
+
+ """
+
+ if self.computed:
+ K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs)
+ if self.bias:
+ K=np.hstack((K,np.ones((x.shape[0],1))))
+ return K.dot(self.L)
+ else:
+ print("Warning, model not fitted yet, returning None")
+ return None \ No newline at end of file
diff --git a/ot/datasets.py b/ot/datasets.py
index 588f501..c750812 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -108,9 +108,9 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
elif dataset.lower()=='gaussrot' :
- rot=np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
- m1=np.array([-1,-1])
- m2=np.array([1,1])
+ rot=np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]])
+ m1=np.array([-1,1])
+ m2=np.array([1,-1])
y=np.floor((np.arange(n)*1.0/n*2))+1
n1=np.sum(y==1)
n2=np.sum(y==2)