summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:07:22 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:07:22 +0100
commit7e16b7a80f3a2896351262a02af27a60401b6a5e (patch)
tree47d2ba913bc5ddc726441d284638db9d7c2cff86 /ot/da.py
parent86b1c88eb0c2c43853bca38de96d2278cc90ceba (diff)
add mapping estimation with kernels (still debugging)
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py48
1 files changed, 44 insertions, 4 deletions
diff --git a/ot/da.py b/ot/da.py
index e4aa0be..49fa79e 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -247,7 +247,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
def loss(L,G):
"""Compute full loss"""
- return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2)
+ return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
@@ -450,11 +450,11 @@ class OTDA_lpl1(OTDA):
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
self.computed=True
-class OTDA_mapping(OTDA):
+class OTDA_mapping_linear(OTDA):
"""Class for optimal transport with joint linear mapping estimation"""
- def __init__(self,metric='sqeuclidean'):
+ def __init__(self):
""" Class initialization"""
@@ -463,8 +463,8 @@ class OTDA_mapping(OTDA):
self.G=0
self.L=0
self.bias=False
- self.metric=metric
self.computed=False
+ self.metric='sqeuclidean'
def fit(self,xs,xt,mu=1,eta=1,bias=False,**kwargs):
""" Fit domain adaptation between samples is xs and xt (with optional
@@ -473,6 +473,7 @@ class OTDA_mapping(OTDA):
self.xt=xt
self.bias=bias
+
self.ws=unif(xs.shape[0])
self.wt=unif(xt.shape[0])
@@ -498,3 +499,42 @@ class OTDA_mapping(OTDA):
print("Warning, model not fitted yet, returning None")
return None
+class OTDA_mapping_kernel(OTDA_mapping_linear):
+ """Class for optimal transport with joint linear mapping estimation"""
+
+
+
+ def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs):
+ """ Fit domain adaptation between samples is xs and xt (with optional
+ weights)"""
+ self.xs=xs
+ self.xt=xt
+ self.bias=bias
+
+ self.ws=unif(xs.shape[0])
+ self.wt=unif(xt.shape[0])
+ self.kernel=kerneltype
+ self.sigma=sigma
+ self.kwargs=kwargs
+
+
+ self.G,self.L=joint_OT_mapping_kernel(xs,xt,mu=mu,eta=eta,bias=bias,**kwargs)
+ self.computed=True
+
+
+ def predict(self,x):
+ """ Out of sample mapping using the formulation from Ferradans
+
+ It basically find the source sample the nearset to the nex sample and
+ apply the difference to the displaced source sample.
+
+ """
+
+ if self.computed:
+ K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs)
+ if self.bias:
+ K=np.hstack((K,np.ones((x.shape[0],1))))
+ return K.dot(self.L)
+ else:
+ print("Warning, model not fitted yet, returning None")
+ return None \ No newline at end of file