.. _sphx_glr_auto_examples_plot_WDA.py:


=================================
Wasserstein Discriminant Analysis
=================================

@author: rflamary



.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
    :align: center


.. rst-class:: sphx-glr-script-out

 Out::

    Compiling cost function...
    Computing gradient of cost function...
     iter              cost val     grad. norm
        1   +5.2427396265941129e-01 8.16627951e-01
        2   +1.7904850059627236e-01 1.91366819e-01
        3   +1.6985797253002377e-01 1.70940682e-01
        4   +1.3903474972292729e-01 1.28606342e-01
        5   +7.4961734618782416e-02 6.41973980e-02
        6   +7.1900245222486239e-02 4.25693592e-02
        7   +7.0472023318269614e-02 2.34599232e-02
        8   +6.9917568641317152e-02 5.66542766e-03
        9   +6.9885086242452696e-02 4.05756115e-04
       10   +6.9884967432653489e-02 2.16836017e-04
       11   +6.9884923649884148e-02 5.74961622e-05
       12   +6.9884921818258436e-02 3.83257203e-05
       13   +6.9884920459612282e-02 9.97486224e-06
       14   +6.9884920414414409e-02 7.33567875e-06
       15   +6.9884920388431387e-02 5.23889187e-06
       16   +6.9884920385183902e-02 4.91959084e-06
       17   +6.9884920373983223e-02 3.56451669e-06
       18   +6.9884920369701245e-02 2.88858709e-06
       19   +6.9884920361621208e-02 1.82294279e-07
    Terminated - min grad norm reached after 19 iterations, 9.65 seconds.




.. code-block:: python


    import numpy as np
    import matplotlib.pylab as pl
    import ot
    from ot.dr import wda  # ot.dr requires the optional pymanopt and autograd packages


    #%% parameters

    n = 1000  # nb samples in source and target datasets
    nz = 0.2  # noise level of the '3gauss' toy dataset

    # source (training) and target (test) samples: 3 Gaussian classes in 2D
    xs, ys = ot.datasets.get_data_classif('3gauss', n, nz)
    xt, yt = ot.datasets.get_data_classif('3gauss', n, nz)

    # append nbnoise dimensions of pure Gaussian noise to both datasets
    nbnoise = 8

    xs = np.hstack((xs, np.random.randn(n, nbnoise)))
    xt = np.hstack((xt, np.random.randn(n, nbnoise)))

    #%% plot samples (the 2 informative dimensions)

    pl.figure(1)

    pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples')
    pl.legend(loc=0)
    pl.title('Discriminant dimensions')


    #%% compute the Wasserstein Discriminant Analysis projection
    p = 2  # dimension of the projected space
    reg = 1  # entropic regularization
    k = 10  # number of Sinkhorn iterations
    maxiter = 100  # maximum number of solver iterations

    P, proj = wda(xs, ys, p, reg, k, maxiter=maxiter)

    #%% plot projected samples

    xsp = proj(xs)
    xtp = proj(xt)

    pl.figure(2, (10, 5))

    pl.subplot(1, 2, 1)
    pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected training samples')

    pl.subplot(1, 2, 2)
    # test samples are colored with their own labels yt
    pl.scatter(xtp[:, 0], xtp[:, 1], c=yt, marker='+', label='Projected samples')
    pl.legend(loc=0)
    pl.title('Projected test samples')

    pl.show()

**Total running time of the script:** ( 0 minutes 16.902 seconds)



.. rst-class:: sphx-glr-signature

    Generated by Sphinx-Gallery
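

To make the roles of ``p``, ``reg`` and ``k`` more concrete, the NumPy sketch below evaluates, for a fixed projection ``P``, the ratio of within-class to between-class entropic transport costs in the projected space. It assumes, based on our reading of POT's ``ot.dr.wda`` rather than on this example, that the solver minimizes this ratio over projections with orthonormal columns, using squared Euclidean ground costs, uniform sample weights and exactly ``k`` Sinkhorn iterations. It is not POT's implementation: the helpers ``sq_dist``, ``sinkhorn_plan`` and ``wda_ratio`` and the synthetic three-class data are hypothetical and only meant for illustration.

.. code-block:: python

    import numpy as np


    def sq_dist(a, b):
        # Hypothetical helper: squared Euclidean cost matrix between rows of a and b.
        return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)


    def sinkhorn_plan(M, reg, k):
        # Hypothetical helper: entropic OT plan between two uniform histograms,
        # computed with exactly k Sinkhorn iterations (no convergence test).
        a = np.full(M.shape[0], 1.0 / M.shape[0])
        b = np.full(M.shape[1], 1.0 / M.shape[1])
        K = np.exp(-M / reg)
        u = np.ones_like(a)
        for _ in range(k):
            v = b / (K.T @ u)
            u = a / (K @ v)
        return u[:, None] * K * v[None, :]


    def wda_ratio(X, y, P, reg=1.0, k=10):
        # Hypothetical helper: within-class / between-class transport cost
        # after projecting the samples with P (shape d x p).
        parts = [X[y == c] @ P for c in np.unique(y)]
        within, between = 0.0, 0.0
        for i, xi in enumerate(parts):
            for j in range(i, len(parts)):
                M = sq_dist(xi, parts[j])
                cost = np.sum(sinkhorn_plan(M, reg, k) * M)
                if i == j:
                    within += cost  # class transported onto itself
                else:
                    between += cost  # transport between two different classes
        return within / between


    # Synthetic data (an assumption, not POT's '3gauss' dataset): three Gaussian
    # classes in 2 informative dimensions plus 8 pure-noise dimensions.
    rng = np.random.default_rng(0)
    n_per_class, n_noise = 50, 8
    means = np.array([[-3.0, 0.0], [3.0, 0.0], [0.0, 5.0]])
    X = np.vstack([
        np.hstack([m + 0.5 * rng.standard_normal((n_per_class, 2)),
                   rng.standard_normal((n_per_class, n_noise))])
        for m in means
    ])
    y = np.repeat(np.arange(3), n_per_class)

    d = X.shape[1]
    P_disc = np.eye(d)[:, :2]    # orthonormal projection onto the informative axes
    P_noise = np.eye(d)[:, 2:4]  # orthonormal projection onto two noise axes

    r_disc = wda_ratio(X, y, P_disc)
    r_noise = wda_ratio(X, y, P_noise)
    print(r_disc, r_noise)

A projection that keeps the two informative coordinates should give a much smaller ratio than one that keeps two noise coordinates, since the classes only differ along the informative axes. This is the contrast a WDA solver would exploit when searching for the discriminant dimensions.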