diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/da.py | 45 |
1 files changed, 44 insertions, 1 deletions
@@ -5,6 +5,8 @@ Domain adaptation with optimal transport import numpy as np from .bregman import sinkhorn +from .lp import emd +from .utils import unif,dist @@ -118,4 +120,45 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter W[indices_labels[0],t]=np.min(all_maj) return transp -
\ No newline at end of file + + + +class OTDA(): + """Class for optimal transport with domain adaptation""" + + def __init__(self): + self.xs=0 + self.xt=0 + self.G=0 + + + def fit(self,xs,xt,ws=None,wt=None): + self.xs=xs + self.xt=xt + + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt) + self.G=emd(ws,wt,self.M) + + def interp(self,direction=1): + """Barycentric interpolation for the source (1) or target (-1)""" + + if self.G and direction>0: + return (self.G/self.ws).dot(self.xt) + elif self.G and direction<0: + return (self.G.T/self.wt).dot(self.xs) + + + + + + + +
\ No newline at end of file |