diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-31 14:23:40 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-10-31 14:23:40 +0100 |
commit | 104627b2f69eb22d3f9010955e6765ac2b179faa (patch) | |
tree | a26f407fb3d036d8d02c31f54889c6c911685b3a | |
parent | 9e40820bee3570354ffdd35a69e18bd16703719a (diff) |
commit doc
-rw-r--r-- | docs/source/conf.py | 7 | ||||
-rw-r--r-- | ot/da.py | 39 |
2 files changed, 38 insertions, 8 deletions
diff --git a/docs/source/conf.py b/docs/source/conf.py index b56df38..ffc0da8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -14,9 +14,14 @@ import sys import os -from unittest.mock import MagicMock +try: + from unittest.mock import MagicMock +except ImportError: + from mock import MagicMock sys.path.insert(0, os.path.abspath("../..")) +sys.setrecursionlimit(1500) + class Mock(MagicMock): @@ -188,21 +188,46 @@ class OTDA(): return None - def predict(x,direction=1): + def predict(self,x,direction=1): """ Out of sample mapping using the formulation from Ferradans + It basically find the source sample the nearset to the nex sample and + apply the difference to the displaced source sample. + """ if direction>0: # >0 then source to target - G=self.G - w=self.ws.reshape((self.xs.shape[0],1)) - x=self.xt + xf=self.xt + x0=self.xs else: - G=self.G.T - w=self.wt.reshape((self.xt.shape[0],1)) - x=self.xs + xf=self.xs + x0=self.xt + + D0=dist(x,x0) # dist netween new samples an source + idx=np.argmin(D0,1) # closest one + xf=self.interp(direction)# interp the source samples + return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation + +class OTDA_sinkhorn(OTDA): + + def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt,metric=self.metric) + self.G=sinkhorn(ws,wt,self.M,reg,**kwargs) + self.computed=True |