From 9e40820bee3570354ffdd35a69e18bd16703719a Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 31 Oct 2016 13:43:48 +0100 Subject: firt DA class --- ot/da.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 9 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index 19d41bc..11e420b 100644 --- a/ot/da.py +++ b/ot/da.py @@ -9,7 +9,6 @@ from .lp import emd from .utils import unif,dist - def indices(a, func): return [i for (i, val) in enumerate(a) if func(val)] @@ -124,15 +123,20 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter class OTDA(): - """Class for optimal transport with domain adaptation""" + """Class for domain adaptation with optimal transport""" - def __init__(self): + def __init__(self,metric='sqeuclidean'): + """ Class initialization""" self.xs=0 self.xt=0 self.G=0 + self.metric=metric + self.computed=False def fit(self,xs,xt,ws=None,wt=None): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" self.xs=xs self.xt=xt @@ -144,17 +148,58 @@ class OTDA(): self.ws=ws self.wt=wt - self.M=dist(xs,xt) + self.M=dist(xs,xt,metric=self.metric) self.G=emd(ws,wt,self.M) + self.computed=True def interp(self,direction=1): - """Barycentric interpolation for the source (1) or target (-1)""" + """Barycentric interpolation for the source (1) or target (-1) + + This Barycentric interpolation solves for each source (resp target) + sample xs (resp xt) the following optimization problem: + + .. math:: + arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t) + + where k is the index of the sample in xs + + For the moment only squared euclidean distance is provided but more + metric c can be used in teh future. + + """ + if direction>0: # >0 then source to target + G=self.G + w=self.ws.reshape((self.xs.shape[0],1)) + x=self.xt + else: + G=self.G.T + w=self.wt.reshape((self.xt.shape[0],1)) + x=self.xs + + if self.computed: + if self.metric=='sqeuclidean': + return np.dot(G/w,x) # weighted mean + else: + print("Warning, metric not handled yet, using weighted average") + return np.dot(G/w,x) # weighted mean + return None + else: + print("Warning, model not fitted yet, returning None") + return None + - if self.G and direction>0: - return (self.G/self.ws).dot(self.xt) - elif self.G and direction<0: - return (self.G.T/self.wt).dot(self.xs) + def predict(x,direction=1): + """ Out of sample mapping using the formulation from Ferradans + """ + if direction>0: # >0 then source to target + G=self.G + w=self.ws.reshape((self.xs.shape[0],1)) + x=self.xt + else: + G=self.G.T + w=self.wt.reshape((self.xt.shape[0],1)) + x=self.xs -- cgit v1.2.3