From 104627b2f69eb22d3f9010955e6765ac2b179faa Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Mon, 31 Oct 2016 14:23:40 +0100 Subject: commit doc --- docs/source/conf.py | 7 ++++++- ot/da.py | 39 ++++++++++++++++++++++++++++++++------- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b56df38..ffc0da8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -14,9 +14,14 @@ import sys import os -from unittest.mock import MagicMock +try: + from unittest.mock import MagicMock +except ImportError: + from mock import MagicMock sys.path.insert(0, os.path.abspath("../..")) +sys.setrecursionlimit(1500) + class Mock(MagicMock): diff --git a/ot/da.py b/ot/da.py index 11e420b..87354b9 100644 --- a/ot/da.py +++ b/ot/da.py @@ -188,21 +188,46 @@ class OTDA(): return None - def predict(x,direction=1): + def predict(self,x,direction=1): """ Out of sample mapping using the formulation from Ferradans + It basically find the source sample the nearset to the nex sample and + apply the difference to the displaced source sample. + """ if direction>0: # >0 then source to target - G=self.G - w=self.ws.reshape((self.xs.shape[0],1)) - x=self.xt + xf=self.xt + x0=self.xs else: - G=self.G.T - w=self.wt.reshape((self.xt.shape[0],1)) - x=self.xs + xf=self.xs + x0=self.xt + + D0=dist(x,x0) # dist netween new samples an source + idx=np.argmin(D0,1) # closest one + xf=self.interp(direction)# interp the source samples + return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation + +class OTDA_sinkhorn(OTDA): + + def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs): + """ Fit domain adaptation between samples is xs and xt (with optional + weights)""" + self.xs=xs + self.xt=xt + if wt is None: + wt=unif(xt.shape[0]) + if ws is None: + ws=unif(xs.shape[0]) + + self.ws=ws + self.wt=wt + + self.M=dist(xs,xt,metric=self.metric) + self.G=sinkhorn(ws,wt,self.M,reg,**kwargs) + self.computed=True -- cgit v1.2.3