summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-04 12:21:53 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-04 12:21:53 +0100
commitf5e9a1336af4959fcf4fee0dddd6257cc5cfa064 (patch)
tree7bd2c471cb82285e06579c9951d918ecbb5e7fc3 /ot/da.py
parenta5f2569859424a00100f7ad29ae4d715ee90c29f (diff)
etter doc for classes
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py61
1 files changed, 31 insertions, 30 deletions
diff --git a/ot/da.py b/ot/da.py
index 72ca3ac..90b96ba 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -151,7 +151,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
The algorithm used for solving the problem is the block coordinate
descent that alternates between updates of G (using conditionnal gradient)
- abd the update of L using a classical least square solver.
+ and the update of L using a classical least square solver.
Parameters
@@ -320,7 +320,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
The algorithm used for solving the problem is the block coordinate
descent that alternates between updates of G (using conditionnal gradient)
- abd the update of L using a classical kernel least square solver.
+ and the update of L using a classical kernel least square solver.
Parameters
@@ -492,7 +492,15 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
class OTDA(object):
- """Class for domain adaptation with optimal transport"""
+ """Class for domain adaptation with optimal transport as proposed in [5]
+
+
+ References
+ ----------
+
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
+ """
def __init__(self,metric='sqeuclidean'):
""" Class initialization"""
@@ -504,8 +512,7 @@ class OTDA(object):
def fit(self,xs,xt,ws=None,wt=None):
- """ Fit domain adaptation between samples is xs and xt (with optional
- weights)"""
+ """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
self.xs=xs
self.xt=xt
@@ -522,7 +529,7 @@ class OTDA(object):
self.computed=True
def interp(self,direction=1):
- """Barycentric interpolation for the source (1) or target (-1)
+ """Barycentric interpolation for the source (1) or target (-1) samples
This Barycentric interpolation solves for each source (resp target)
sample xs (resp xt) the following optimization problem:
@@ -558,10 +565,16 @@ class OTDA(object):
def predict(self,x,direction=1):
- """ Out of sample mapping using the formulation from Ferradans
+ """ Out of sample mapping using the formulation from [6]
+
+ For each sample x to map, it finds the nearest source sample xs and
+ map the samle x to the position xst+(x-xs) wher xst is the barycentric
+ interpolation of source sample xs.
+
+ References
+ ----------
- It basically find the source sample the nearset to the nex sample and
- apply the difference to the displaced source sample.
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
"""
if direction>0: # >0 then source to target
@@ -582,8 +595,7 @@ class OTDA_sinkhorn(OTDA):
"""Class for domain adaptation with optimal transport with entropic regularization"""
def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
- weights)"""
+ """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
self.xs=xs
self.xt=xt
@@ -601,12 +613,12 @@ class OTDA_sinkhorn(OTDA):
class OTDA_lpl1(OTDA):
- """Class for domain adaptation with optimal transport with entropic an group regularization"""
+ """Class for domain adaptation with optimal transport with entropic and group regularization"""
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
- weights)"""
+ """ Fit regularized domain adaptation between samples is xs and xt (with optional weights),
+ See ot.da.sinkhorn_lpl1_mm for fit parameters""""
self.xs=xs
self.xt=xt
@@ -623,7 +635,7 @@ class OTDA_lpl1(OTDA):
self.computed=True
class OTDA_mapping_linear(OTDA):
- """Class for optimal transport with joint linear mapping estimation"""
+ """Class for optimal transport with joint linear mapping estimation as in [8]"""
def __init__(self):
@@ -657,12 +669,7 @@ class OTDA_mapping_linear(OTDA):
def predict(self,x):
- """ Out of sample mapping using the formulation from Ferradans
-
- It basically find the source sample the nearset to the nex sample and
- apply the difference to the displaced source sample.
-
- """
+ """ Out of sample mapping estimated during the call to fit"""
if self.computed:
if self.bias:
x=np.hstack((x,np.ones((x.shape[0],1))))
@@ -672,13 +679,12 @@ class OTDA_mapping_linear(OTDA):
return None
class OTDA_mapping_kernel(OTDA_mapping_linear):
- """Class for optimal transport with joint linear mapping estimation"""
+ """Class for optimal transport with joint nonlinear mapping estimation as in [8]"""
def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
- weights)"""
+ """ Fit domain adaptation between samples is xs and xt """
self.xs=xs
self.xt=xt
self.bias=bias
@@ -695,12 +701,7 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
def predict(self,x):
- """ Out of sample mapping using the formulation from Ferradans
-
- It basically find the source sample the nearset to the nex sample and
- apply the difference to the displaced source sample.
-
- """
+ """ Out of sample mapping estimated during the call to fit"""
if self.computed:
K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs)