summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OTDA_classes.rst
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_OTDA_classes.rst')
-rw-r--r--docs/source/auto_examples/plot_OTDA_classes.rst190
1 files changed, 0 insertions, 190 deletions
diff --git a/docs/source/auto_examples/plot_OTDA_classes.rst b/docs/source/auto_examples/plot_OTDA_classes.rst
deleted file mode 100644
index 097e9fc..0000000
--- a/docs/source/auto_examples/plot_OTDA_classes.rst
+++ /dev/null
@@ -1,190 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_OTDA_classes.py:
-
-
-========================
-OT for domain adaptation
-========================
-
-
-
-
-
-.. rst-class:: sphx-glr-horizontal
-
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_001.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_004.png
- :scale: 47
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out::
-
- It. |Loss |Delta loss
- --------------------------------
- 0|9.171271e+00|0.000000e+00
- 1|2.133783e+00|-3.298127e+00
- 2|1.895941e+00|-1.254484e-01
- 3|1.844628e+00|-2.781709e-02
- 4|1.824983e+00|-1.076467e-02
- 5|1.815453e+00|-5.249337e-03
- 6|1.808104e+00|-4.064733e-03
- 7|1.803558e+00|-2.520475e-03
- 8|1.801061e+00|-1.386155e-03
- 9|1.799391e+00|-9.279565e-04
- 10|1.797176e+00|-1.232778e-03
- 11|1.795465e+00|-9.529479e-04
- 12|1.795316e+00|-8.322362e-05
- 13|1.794523e+00|-4.418932e-04
- 14|1.794444e+00|-4.390599e-05
- 15|1.794395e+00|-2.710318e-05
- 16|1.793713e+00|-3.804028e-04
- 17|1.793110e+00|-3.359479e-04
- 18|1.792829e+00|-1.569563e-04
- 19|1.792621e+00|-1.159469e-04
- It. |Loss |Delta loss
- --------------------------------
- 20|1.791334e+00|-7.187689e-04
-
-
-
-
-|
-
-
-.. code-block:: python
-
-
- import matplotlib.pylab as pl
- import ot
-
-
-
-
- #%% parameters
-
- n=150 # nb samples in source and target datasets
-
- xs,ys=ot.datasets.get_data_classif('3gauss',n)
- xt,yt=ot.datasets.get_data_classif('3gauss2',n)
-
-
-
-
- #%% plot samples
-
- pl.figure(1)
-
- pl.subplot(2,2,1)
- pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
- pl.legend(loc=0)
- pl.title('Source distributions')
-
- pl.subplot(2,2,2)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
- pl.legend(loc=0)
- pl.title('target distributions')
-
-
- #%% OT estimation
-
- # LP problem
- da_emd=ot.da.OTDA() # init class
- da_emd.fit(xs,xt) # fit distributions
- xst0=da_emd.interp() # interpolation of source samples
-
-
- # sinkhorn regularization
- lambd=1e-1
- da_entrop=ot.da.OTDA_sinkhorn()
- da_entrop.fit(xs,xt,reg=lambd)
- xsts=da_entrop.interp()
-
- # non-convex Group lasso regularization
- reg=1e-1
- eta=1e0
- da_lpl1=ot.da.OTDA_lpl1()
- da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
- xstg=da_lpl1.interp()
-
-
- # True Group lasso regularization
- reg=1e-1
- eta=2e0
- da_l1l2=ot.da.OTDA_l1l2()
- da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
- xstgl=da_l1l2.interp()
-
-
- #%% plot interpolated source samples
- pl.figure(4,(15,8))
-
- param_img={'interpolation':'nearest','cmap':'jet'}
-
- pl.subplot(2,4,1)
- pl.imshow(da_emd.G,**param_img)
- pl.title('OT matrix')
-
-
- pl.subplot(2,4,2)
- pl.imshow(da_entrop.G,**param_img)
- pl.title('OT matrix sinkhorn')
-
- pl.subplot(2,4,3)
- pl.imshow(da_lpl1.G,**param_img)
- pl.title('OT matrix non-convex Group Lasso')
-
- pl.subplot(2,4,4)
- pl.imshow(da_l1l2.G,**param_img)
- pl.title('OT matrix Group Lasso')
-
-
- pl.subplot(2,4,5)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples')
- pl.legend(loc=0)
-
- pl.subplot(2,4,6)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples Sinkhorn')
-
- pl.subplot(2,4,7)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples non-convex Group Lasso')
-
- pl.subplot(2,4,8)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples Group Lasso')
-**Total running time of the script:** ( 0 minutes 2.225 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_OTDA_classes.py <plot_OTDA_classes.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_OTDA_classes.ipynb <plot_OTDA_classes.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_