summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-31 13:43:48 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-10-31 13:43:48 +0100
commit9e40820bee3570354ffdd35a69e18bd16703719a (patch)
tree297de970a2b221ed9c07d154db293c9f79362ea6 /ot/da.py
parente3b115001ccde177d5d73e4f9dd0f52ed964371d (diff)
firt DA class
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py63
1 files changed, 54 insertions, 9 deletions
diff --git a/ot/da.py b/ot/da.py
index 19d41bc..11e420b 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -9,7 +9,6 @@ from .lp import emd
from .utils import unif,dist
-
def indices(a, func):
return [i for (i, val) in enumerate(a) if func(val)]
@@ -124,15 +123,20 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
class OTDA():
- """Class for optimal transport with domain adaptation"""
+ """Class for domain adaptation with optimal transport"""
- def __init__(self):
+ def __init__(self,metric='sqeuclidean'):
+ """ Class initialization"""
self.xs=0
self.xt=0
self.G=0
+ self.metric=metric
+ self.computed=False
def fit(self,xs,xt,ws=None,wt=None):
+ """ Fit domain adaptation between samples in xs and xt (with optional
+ weights)"""
self.xs=xs
self.xt=xt
@@ -144,17 +148,58 @@ class OTDA():
self.ws=ws
self.wt=wt
- self.M=dist(xs,xt)
+ self.M=dist(xs,xt,metric=self.metric)
self.G=emd(ws,wt,self.M)
+ self.computed=True
def interp(self,direction=1):
- """Barycentric interpolation for the source (1) or target (-1)"""
+ """Barycentric interpolation for the source (1) or target (-1)
+
+ This barycentric interpolation solves, for each source (resp. target)
+ sample xs (resp. xt), the following optimization problem:
+
+ .. math::
+ arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
+
+ where k is the index of the sample in xs
+
+ For the moment only the squared Euclidean distance is provided, but more
+ metrics c can be used in the future.
+
+ """
+ if direction>0: # >0 then source to target
+ G=self.G
+ w=self.ws.reshape((self.xs.shape[0],1))
+ x=self.xt
+ else:
+ G=self.G.T
+ w=self.wt.reshape((self.xt.shape[0],1))
+ x=self.xs
+
+ if self.computed:
+ if self.metric=='sqeuclidean':
+ return np.dot(G/w,x) # weighted mean
+ else:
+ print("Warning, metric not handled yet, using weighted average")
+ return np.dot(G/w,x) # weighted mean
+ return None
+ else:
+ print("Warning, model not fitted yet, returning None")
+ return None
+
- if self.G and direction>0:
- return (self.G/self.ws).dot(self.xt)
- elif self.G and direction<0:
- return (self.G.T/self.wt).dot(self.xs)
+ def predict(x,direction=1):
+ """ Out of sample mapping using the formulation from Ferradans
+ """
+ if direction>0: # >0 then source to target
+ G=self.G
+ w=self.ws.reshape((self.xs.shape[0],1))
+ x=self.xt
+ else:
+ G=self.G.T
+ w=self.wt.reshape((self.xt.shape[0],1))
+ x=self.xs