summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OTDA_classes.rst
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_OTDA_classes.rst')
-rw-r--r--docs/source/auto_examples/plot_OTDA_classes.rst190
1 files changed, 190 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_OTDA_classes.rst b/docs/source/auto_examples/plot_OTDA_classes.rst
new file mode 100644
index 0000000..097e9fc
--- /dev/null
+++ b/docs/source/auto_examples/plot_OTDA_classes.rst
@@ -0,0 +1,190 @@
+
+
+.. _sphx_glr_auto_examples_plot_OTDA_classes.py:
+
+
+========================
+OT for domain adaptation
+========================
+
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_004.png
+ :scale: 47
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|9.171271e+00|0.000000e+00
+ 1|2.133783e+00|-3.298127e+00
+ 2|1.895941e+00|-1.254484e-01
+ 3|1.844628e+00|-2.781709e-02
+ 4|1.824983e+00|-1.076467e-02
+ 5|1.815453e+00|-5.249337e-03
+ 6|1.808104e+00|-4.064733e-03
+ 7|1.803558e+00|-2.520475e-03
+ 8|1.801061e+00|-1.386155e-03
+ 9|1.799391e+00|-9.279565e-04
+ 10|1.797176e+00|-1.232778e-03
+ 11|1.795465e+00|-9.529479e-04
+ 12|1.795316e+00|-8.322362e-05
+ 13|1.794523e+00|-4.418932e-04
+ 14|1.794444e+00|-4.390599e-05
+ 15|1.794395e+00|-2.710318e-05
+ 16|1.793713e+00|-3.804028e-04
+ 17|1.793110e+00|-3.359479e-04
+ 18|1.792829e+00|-1.569563e-04
+ 19|1.792621e+00|-1.159469e-04
+ It. |Loss |Delta loss
+ --------------------------------
+ 20|1.791334e+00|-7.187689e-04
+
+
+
+
+|
+
+
+.. code-block:: python
+
+
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+ #%% parameters
+
+ n=150 # nb samples in source and target datasets
+
+ xs,ys=ot.datasets.get_data_classif('3gauss',n)
+ xt,yt=ot.datasets.get_data_classif('3gauss2',n)
+
+
+
+
+ #%% plot samples
+
+ pl.figure(1)
+
+ pl.subplot(2,2,1)
+ pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
+ pl.legend(loc=0)
+ pl.title('Source distributions')
+
+ pl.subplot(2,2,2)
+ pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
+ pl.legend(loc=0)
+ pl.title('target distributions')
+
+
+ #%% OT estimation
+
+ # LP problem
+ da_emd=ot.da.OTDA() # init class
+ da_emd.fit(xs,xt) # fit distributions
+ xst0=da_emd.interp() # interpolation of source samples
+
+
+ # sinkhorn regularization
+ lambd=1e-1
+ da_entrop=ot.da.OTDA_sinkhorn()
+ da_entrop.fit(xs,xt,reg=lambd)
+ xsts=da_entrop.interp()
+
+ # non-convex Group lasso regularization
+ reg=1e-1
+ eta=1e0
+ da_lpl1=ot.da.OTDA_lpl1()
+ da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
+ xstg=da_lpl1.interp()
+
+
+ # True Group lasso regularization
+ reg=1e-1
+ eta=2e0
+ da_l1l2=ot.da.OTDA_l1l2()
+ da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
+ xstgl=da_l1l2.interp()
+
+
+ #%% plot interpolated source samples
+ pl.figure(4,(15,8))
+
+ param_img={'interpolation':'nearest','cmap':'jet'}
+
+ pl.subplot(2,4,1)
+ pl.imshow(da_emd.G,**param_img)
+ pl.title('OT matrix')
+
+
+ pl.subplot(2,4,2)
+ pl.imshow(da_entrop.G,**param_img)
+ pl.title('OT matrix sinkhorn')
+
+ pl.subplot(2,4,3)
+ pl.imshow(da_lpl1.G,**param_img)
+ pl.title('OT matrix non-convex Group Lasso')
+
+ pl.subplot(2,4,4)
+ pl.imshow(da_l1l2.G,**param_img)
+ pl.title('OT matrix Group Lasso')
+
+
+ pl.subplot(2,4,5)
+ pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
+ pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
+ pl.title('Interp samples')
+ pl.legend(loc=0)
+
+ pl.subplot(2,4,6)
+ pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
+ pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
+ pl.title('Interp samples Sinkhorn')
+
+ pl.subplot(2,4,7)
+ pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
+ pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
+ pl.title('Interp samples non-convex Group Lasso')
+
+ pl.subplot(2,4,8)
+ pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
+ pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
+ pl.title('Interp samples Group Lasso')
+**Total running time of the script:** ( 0 minutes 2.225 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_OTDA_classes.py <plot_OTDA_classes.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_OTDA_classes.ipynb <plot_OTDA_classes.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_