summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-31 14:23:40 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-10-31 14:23:40 +0100
commit104627b2f69eb22d3f9010955e6765ac2b179faa (patch)
treea26f407fb3d036d8d02c31f54889c6c911685b3a
parent9e40820bee3570354ffdd35a69e18bd16703719a (diff)
commit doc
-rw-r--r--docs/source/conf.py7
-rw-r--r--ot/da.py39
2 files changed, 38 insertions, 8 deletions
diff --git a/docs/source/conf.py b/docs/source/conf.py
index b56df38..ffc0da8 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -14,9 +14,14 @@
import sys
import os
-from unittest.mock import MagicMock
+try:
+ from unittest.mock import MagicMock
+except ImportError:
+ from mock import MagicMock
sys.path.insert(0, os.path.abspath("../.."))
+sys.setrecursionlimit(1500)
+
class Mock(MagicMock):
diff --git a/ot/da.py b/ot/da.py
index 11e420b..87354b9 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -188,21 +188,46 @@ class OTDA():
return None
- def predict(x,direction=1):
+ def predict(self,x,direction=1):
""" Out of sample mapping using the formulation from Ferradans
+ It basically find the source sample the nearset to the nex sample and
+ apply the difference to the displaced source sample.
+
"""
if direction>0: # >0 then source to target
- G=self.G
- w=self.ws.reshape((self.xs.shape[0],1))
- x=self.xt
+ xf=self.xt
+ x0=self.xs
else:
- G=self.G.T
- w=self.wt.reshape((self.xt.shape[0],1))
- x=self.xs
+ xf=self.xs
+ x0=self.xt
+
+ D0=dist(x,x0) # dist netween new samples an source
+ idx=np.argmin(D0,1) # closest one
+ xf=self.interp(direction)# interp the source samples
+ return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation
+
+class OTDA_sinkhorn(OTDA):
+
+ def fit(self,xs,xt,ws=None,wt=None,reg=1,**kwargs):
+ """ Fit domain adaptation between samples is xs and xt (with optional
+ weights)"""
+ self.xs=xs
+ self.xt=xt
+ if wt is None:
+ wt=unif(xt.shape[0])
+ if ws is None:
+ ws=unif(xs.shape[0])
+
+ self.ws=ws
+ self.wt=wt
+
+ self.M=dist(xs,xt,metric=self.metric)
+ self.G=sinkhorn(ws,wt,self.M,reg,**kwargs)
+ self.computed=True