diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-12-02 15:38:59 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-12-02 15:38:59 +0100 |
commit | e458b7a58d9790e7c5ff40dea235402d9c4c8662 (patch) | |
tree | ac9da575654c78aa04a177723603935051b5d42d /docs/source/auto_examples/plot_OTDA_mapping.py | |
parent | 7609f9e6a4103e13beb294873f4dac562b1d45e1 (diff) |
add doc for gallery
Diffstat (limited to 'docs/source/auto_examples/plot_OTDA_mapping.py')
-rw-r--r-- | docs/source/auto_examples/plot_OTDA_mapping.py | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_OTDA_mapping.py b/docs/source/auto_examples/plot_OTDA_mapping.py new file mode 100644 index 0000000..78b57e7 --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_mapping.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +""" +=============================================== +OT mapping estimation for domain adaptation [8] +=============================================== + +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for + discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% dataset generation + +np.random.seed(0) # makes example reproducible + +n=100 # nb samples in source and target datasets +theta=2*np.pi/20 +nz=0.1 +xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) +xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) + +# one of the target mode changes its variance (no linear mapping) +xt[yt==2]*=3 +xt=xt+4 + + +#%% plot samples + +pl.figure(1,(8,5)) +pl.clf() + +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.legend(loc=0) +pl.title('Source and target distributions') + + + +#%% OT linear mapping estimation + +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +xst=ot_mapping.predict(xs) # use the estimated mapping +xst0=ot_mapping.interp() # use barycentric mapping + + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("barycentric mapping") + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Learned mapping") + + + +#%% Kernel mapping estimation + +eta=1e-5 # quadratic regularization for regression +mu=1e-1 # weight of the OT linear term +bias=True # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping +xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping + + +#%% Plotting the mapped samples + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') +pl.title("Bary. mapping (linear)") +pl.legend(loc=0) + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (linear)") + +pl.subplot(2,2,3) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("Bary. mapping (kernel)") + +pl.subplot(2,2,4) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (kernel)") |