.. _sphx_glr_auto_examples_plot_OTDA_classes.py:


========================
OT for domain adaptation
========================


.. rst-class:: sphx-glr-horizontal


    *

      .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_001.png
            :scale: 47

    *

      .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_004.png
            :scale: 47


.. rst-class:: sphx-glr-script-out

 Out::

    It.  |Loss        |Delta loss
    --------------------------------
        0|9.171271e+00|0.000000e+00
        1|2.133783e+00|-3.298127e+00
        2|1.895941e+00|-1.254484e-01
        3|1.844628e+00|-2.781709e-02
        4|1.824983e+00|-1.076467e-02
        5|1.815453e+00|-5.249337e-03
        6|1.808104e+00|-4.064733e-03
        7|1.803558e+00|-2.520475e-03
        8|1.801061e+00|-1.386155e-03
        9|1.799391e+00|-9.279565e-04
       10|1.797176e+00|-1.232778e-03
       11|1.795465e+00|-9.529479e-04
       12|1.795316e+00|-8.322362e-05
       13|1.794523e+00|-4.418932e-04
       14|1.794444e+00|-4.390599e-05
       15|1.794395e+00|-2.710318e-05
       16|1.793713e+00|-3.804028e-04
       17|1.793110e+00|-3.359479e-04
       18|1.792829e+00|-1.569563e-04
       19|1.792621e+00|-1.159469e-04
    It.  |Loss        |Delta loss
    --------------------------------
       20|1.791334e+00|-7.187689e-04


|


.. code-block:: python


    import matplotlib.pylab as pl
    import ot


    #%% parameters

    n=150 # nb samples in source and target datasets

    xs,ys=ot.datasets.get_data_classif('3gauss',n)
    xt,yt=ot.datasets.get_data_classif('3gauss2',n)


    #%% plot samples

    pl.figure(1)

    pl.subplot(2,2,1)
    pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
    pl.legend(loc=0)
    pl.title('Source distributions')

    pl.subplot(2,2,2)
    pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
    pl.legend(loc=0)
    pl.title('target distributions')


    #%% OT estimation

    # LP problem
    da_emd=ot.da.OTDA()     # init class
    da_emd.fit(xs,xt)       # fit distributions
    xst0=da_emd.interp()    # interpolation of source samples

    # sinkhorn regularization
    lambd=1e-1
    da_entrop=ot.da.OTDA_sinkhorn()
    da_entrop.fit(xs,xt,reg=lambd)
    xsts=da_entrop.interp()

    # non-convex Group lasso regularization
    reg=1e-1
    eta=1e0
    da_lpl1=ot.da.OTDA_lpl1()
    da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
    xstg=da_lpl1.interp()

    # True Group lasso regularization
    reg=1e-1
    eta=2e0
    da_l1l2=ot.da.OTDA_l1l2()
    da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
    xstgl=da_l1l2.interp()


    #%% plot interpolated source samples

    pl.figure(4,(15,8))

    param_img={'interpolation':'nearest','cmap':'jet'}

    pl.subplot(2,4,1)
    pl.imshow(da_emd.G,**param_img)
    pl.title('OT matrix')

    pl.subplot(2,4,2)
    pl.imshow(da_entrop.G,**param_img)
    pl.title('OT matrix sinkhorn')

    pl.subplot(2,4,3)
    pl.imshow(da_lpl1.G,**param_img)
    pl.title('OT matrix non-convex Group Lasso')

    pl.subplot(2,4,4)
    pl.imshow(da_l1l2.G,**param_img)
    pl.title('OT matrix Group Lasso')

    pl.subplot(2,4,5)
    pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
    pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
    pl.title('Interp samples')
    pl.legend(loc=0)

    pl.subplot(2,4,6)
    pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
    pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
    pl.title('Interp samples Sinkhorn')

    pl.subplot(2,4,7)
    pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
    pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
    pl.title('Interp samples non-convex Group Lasso')

    pl.subplot(2,4,8)
    pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
    pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
    pl.title('Interp samples Group Lasso')

**Total running time of the script:** ( 0 minutes 2.225 seconds)


.. container:: sphx-glr-footer


  .. container:: sphx-glr-download

     :download:`Download Python source code: plot_OTDA_classes.py `


  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: plot_OTDA_classes.ipynb `

.. rst-class:: sphx-glr-signature

    `Generated by Sphinx-Gallery `_