diff options
Diffstat (limited to 'docs/source/auto_examples/plot_OTDA_classes.rst')
-rw-r--r-- | docs/source/auto_examples/plot_OTDA_classes.rst | 190 |
1 files changed, 190 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_OTDA_classes.rst b/docs/source/auto_examples/plot_OTDA_classes.rst new file mode 100644 index 0000000..097e9fc --- /dev/null +++ b/docs/source/auto_examples/plot_OTDA_classes.rst @@ -0,0 +1,190 @@ + + +.. _sphx_glr_auto_examples_plot_OTDA_classes.py: + + +======================== +OT for domain adaptation +======================== + + + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_001.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_004.png + :scale: 47 + + +.. rst-class:: sphx-glr-script-out + + Out:: + + It. |Loss |Delta loss + -------------------------------- + 0|9.171271e+00|0.000000e+00 + 1|2.133783e+00|-3.298127e+00 + 2|1.895941e+00|-1.254484e-01 + 3|1.844628e+00|-2.781709e-02 + 4|1.824983e+00|-1.076467e-02 + 5|1.815453e+00|-5.249337e-03 + 6|1.808104e+00|-4.064733e-03 + 7|1.803558e+00|-2.520475e-03 + 8|1.801061e+00|-1.386155e-03 + 9|1.799391e+00|-9.279565e-04 + 10|1.797176e+00|-1.232778e-03 + 11|1.795465e+00|-9.529479e-04 + 12|1.795316e+00|-8.322362e-05 + 13|1.794523e+00|-4.418932e-04 + 14|1.794444e+00|-4.390599e-05 + 15|1.794395e+00|-2.710318e-05 + 16|1.793713e+00|-3.804028e-04 + 17|1.793110e+00|-3.359479e-04 + 18|1.792829e+00|-1.569563e-04 + 19|1.792621e+00|-1.159469e-04 + It. |Loss |Delta loss + -------------------------------- + 20|1.791334e+00|-7.187689e-04 + + + + +| + + +.. code-block:: python + + + import matplotlib.pylab as pl + import ot + + + + + #%% parameters + + n=150 # nb samples in source and target datasets + + xs,ys=ot.datasets.get_data_classif('3gauss',n) + xt,yt=ot.datasets.get_data_classif('3gauss2',n) + + + + + #%% plot samples + + pl.figure(1) + + pl.subplot(2,2,1) + pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') + pl.legend(loc=0) + pl.title('Source distributions') + + pl.subplot(2,2,2) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + pl.legend(loc=0) + pl.title('target distributions') + + + #%% OT estimation + + # LP problem + da_emd=ot.da.OTDA() # init class + da_emd.fit(xs,xt) # fit distributions + xst0=da_emd.interp() # interpolation of source samples + + + # sinkhorn regularization + lambd=1e-1 + da_entrop=ot.da.OTDA_sinkhorn() + da_entrop.fit(xs,xt,reg=lambd) + xsts=da_entrop.interp() + + # non-convex Group lasso regularization + reg=1e-1 + eta=1e0 + da_lpl1=ot.da.OTDA_lpl1() + da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta) + xstg=da_lpl1.interp() + + + # True Group lasso regularization + reg=1e-1 + eta=2e0 + da_l1l2=ot.da.OTDA_l1l2() + da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True) + xstgl=da_l1l2.interp() + + + #%% plot interpolated source samples + pl.figure(4,(15,8)) + + param_img={'interpolation':'nearest','cmap':'jet'} + + pl.subplot(2,4,1) + pl.imshow(da_emd.G,**param_img) + pl.title('OT matrix') + + + pl.subplot(2,4,2) + pl.imshow(da_entrop.G,**param_img) + pl.title('OT matrix sinkhorn') + + pl.subplot(2,4,3) + pl.imshow(da_lpl1.G,**param_img) + pl.title('OT matrix non-convex Group Lasso') + + pl.subplot(2,4,4) + pl.imshow(da_l1l2.G,**param_img) + pl.title('OT matrix Group Lasso') + + + pl.subplot(2,4,5) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples') + pl.legend(loc=0) + + pl.subplot(2,4,6) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples Sinkhorn') + + pl.subplot(2,4,7) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples non-convex Group Lasso') + + pl.subplot(2,4,8) + pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3) + pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30) + pl.title('Interp samples Group Lasso') +**Total running time of the script:** ( 0 minutes 2.225 seconds) + + + +.. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_OTDA_classes.py <plot_OTDA_classes.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_OTDA_classes.ipynb <plot_OTDA_classes.ipynb>` + +.. rst-class:: sphx-glr-signature + + `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_ |