diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-04 12:21:53 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-04 12:21:53 +0100 |
commit | f5e9a1336af4959fcf4fee0dddd6257cc5cfa064 (patch) | |
tree | 7bd2c471cb82285e06579c9951d918ecbb5e7fc3 /ot/da.py | |
parent | a5f2569859424a00100f7ad29ae4d715ee90c29f (diff) |
etter doc for classes
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 61 |
1 files changed, 31 insertions, 30 deletions
@@ -151,7 +151,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos The algorithm used for solving the problem is the block coordinate descent that alternates between updates of G (using conditionnal gradient) - abd the update of L using a classical least square solver. + and the update of L using a classical least square solver. Parameters @@ -320,7 +320,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b The algorithm used for solving the problem is the block coordinate descent that alternates between updates of G (using conditionnal gradient) - abd the update of L using a classical kernel least square solver. + and the update of L using a classical kernel least square solver. Parameters @@ -492,7 +492,15 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b class OTDA(object): - """Class for domain adaptation with optimal transport""" + """Class for domain adaptation with optimal transport as proposed in [5] + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + + """ def __init__(self,metric='sqeuclidean'): """ Class initialization""" @@ -504,8 +512,7 @@ class OTDA(object): def fit(self,xs,xt,ws=None,wt=None): - """ Fit domain adaptation between samples is xs and xt (with optional - weights)""" + """ Fit domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs self.xt=xt @@ -522,7 +529,7 @@ class OTDA(object): self.computed=True def interp(self,direction=1): - """Barycentric interpolation for the source (1) or target (-1) + """Barycentric interpolation for the source (1) or target (-1) samples This Barycentric interpolation solves for each source (resp target) sample xs (resp xt) the following optimization problem: @@ -558,10 +565,16 @@ class OTDA(object): def predict(self,x,direction=1): - """ Out of sample mapping using the formulation from Ferradans + """ Out of sample mapping using the formulation from [6] + + For each sample x to map, it finds the nearest source sample xs and + map the samle x to the position xst+(x-xs) wher xst is the barycentric + interpolation of source sample xs. + + References + ---------- - It basically find the source sample the nearset to the nex sample and - apply the difference to the displaced source sample. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. """ if direction>0: # >0 then source to target @@ -582,8 +595,7 @@ class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic regularization""" def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs): - """ Fit domain adaptation between samples is xs and xt (with optional - weights)""" + """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)""" self.xs=xs self.xt=xt @@ -601,12 +613,12 @@ class OTDA_sinkhorn(OTDA): class OTDA_lpl1(OTDA): - """Class for domain adaptation with optimal transport with entropic an group regularization""" + """Class for domain adaptation with optimal transport with entropic and group regularization""" def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): - """ Fit domain adaptation between samples is xs and xt (with optional - weights)""" + """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), + See ot.da.sinkhorn_lpl1_mm for fit parameters"""" self.xs=xs self.xt=xt @@ -623,7 +635,7 @@ class OTDA_lpl1(OTDA): self.computed=True class OTDA_mapping_linear(OTDA): - """Class for optimal transport with joint linear mapping estimation""" + """Class for optimal transport with joint linear mapping estimation as in [8]""" def __init__(self): @@ -657,12 +669,7 @@ class OTDA_mapping_linear(OTDA): def predict(self,x): - """ Out of sample mapping using the formulation from Ferradans - - It basically find the source sample the nearset to the nex sample and - apply the difference to the displaced source sample. - - """ + """ Out of sample mapping estimated during the call to fit""" if self.computed: if self.bias: x=np.hstack((x,np.ones((x.shape[0],1)))) @@ -672,13 +679,12 @@ class OTDA_mapping_linear(OTDA): return None class OTDA_mapping_kernel(OTDA_mapping_linear): - """Class for optimal transport with joint linear mapping estimation""" + """Class for optimal transport with joint nonlinear mapping estimation as in [8]""" def fit(self,xs,xt,mu=1,eta=1,bias=False,kerneltype='gaussian',sigma=1,**kwargs): - """ Fit domain adaptation between samples is xs and xt (with optional - weights)""" + """ Fit domain adaptation between samples is xs and xt """ self.xs=xs self.xt=xt self.bias=bias @@ -695,12 +701,7 @@ class OTDA_mapping_kernel(OTDA_mapping_linear): def predict(self,x): - """ Out of sample mapping using the formulation from Ferradans - - It basically find the source sample the nearset to the nex sample and - apply the difference to the displaced source sample. - - """ + """ Out of sample mapping estimated during the call to fit""" if self.computed: K=kernel(x,self.xs,method=self.kernel,sigma=self.sigma,**self.kwargs) |