diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 16:23:57 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-28 16:23:57 +0200 |
commit | d3d8689b9230ea6066409ff44969817da6f5af50 (patch) | |
tree | b03de6e0717b8deac8286863b9a6c9e79bb1cd14 /ot/da.py | |
parent | 996c6681513184c837c2c1f17af2cae9e5106676 (diff) |
first class for DA
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 45 |
1 files changed, 44 insertions, 1 deletions
@@ -5,6 +5,8 @@ Domain adaptation with optimal transport import numpy as np from .bregman import sinkhorn +from .lp import emd +from .utils import unif,dist @@ -118,4 +120,45 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter W[indices_labels[0],t]=np.min(all_maj) return transp -
\ No newline at end of file + + + +class OTDA(): + """Class for optimal transport with domain adaptation""" + + def __init__(self): + self.xs=0 + self.xt=0 + self.G=0 + + + def fit(self,xs,xt,ws=None,wt=None): + self.xs=xs + self.xt=xt + + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt) + self.G=emd(ws,wt,self.M) + + def interp(self,direction=1): + """Barycentric interpolation for the source (1) or target (-1)""" + + if self.G and direction>0: + return (self.G/self.ws).dot(self.xt) + elif self.G and direction<0: + return (self.G.T/self.wt).dot(self.xs) + + + + + + + +
\ No newline at end of file |