summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 16:23:57 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 16:23:57 +0200
commitd3d8689b9230ea6066409ff44969817da6f5af50 (patch)
treeb03de6e0717b8deac8286863b9a6c9e79bb1cd14 /ot/da.py
parent996c6681513184c837c2c1f17af2cae9e5106676 (diff)
first class for DA
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py45
1 files changed, 44 insertions, 1 deletions
diff --git a/ot/da.py b/ot/da.py
index 083138f..19d41bc 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -5,6 +5,8 @@ Domain adaptation with optimal transport
import numpy as np
from .bregman import sinkhorn
+from .lp import emd
+from .utils import unif,dist
@@ -118,4 +120,45 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
W[indices_labels[0],t]=np.min(all_maj)
return transp
- \ No newline at end of file
+
+
+
+class OTDA():
+ """Class for optimal transport with domain adaptation"""
+
+ def __init__(self):
+ self.xs=0
+ self.xt=0
+ self.G=0
+
+
+ def fit(self,xs,xt,ws=None,wt=None):
+ self.xs=xs
+ self.xt=xt
+
+ if wt is None:
+ wt=unif(xt.shape[0])
+ if ws is None:
+ ws=unif(xs.shape[0])
+
+ self.ws=ws
+ self.wt=wt
+
+ self.M=dist(xs,xt)
+ self.G=emd(ws,wt,self.M)
+
+ def interp(self,direction=1):
+ """Barycentric interpolation for the source (1) or target (-1)"""
+
+ if self.G and direction>0:
+ return (self.G/self.ws).dot(self.xt)
+ elif self.G and direction<0:
+ return (self.G.T/self.wt).dot(self.xs)
+
+
+
+
+
+
+
+ \ No newline at end of file