summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_WDA.rst
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.rst')
-rw-r--r--  docs/source/auto_examples/plot_WDA.rst  168
1 file changed, 116 insertions, 52 deletions
diff --git a/docs/source/auto_examples/plot_WDA.rst b/docs/source/auto_examples/plot_WDA.rst
index 540555d..76ebaf5 100644
--- a/docs/source/auto_examples/plot_WDA.rst
+++ b/docs/source/auto_examples/plot_WDA.rst
@@ -7,13 +7,22 @@
Wasserstein Discriminant Analysis
=================================
-@author: rflamary
-.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
- :align: center
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_WDA_002.png
+ :scale: 47
.. rst-class:: sphx-glr-script-out
@@ -23,26 +32,43 @@ Wasserstein Discriminant Analysis
Compiling cost function...
Computing gradient of cost function...
iter cost val grad. norm
- 1 +5.2427396265941129e-01 8.16627951e-01
- 2 +1.7904850059627236e-01 1.91366819e-01
- 3 +1.6985797253002377e-01 1.70940682e-01
- 4 +1.3903474972292729e-01 1.28606342e-01
- 5 +7.4961734618782416e-02 6.41973980e-02
- 6 +7.1900245222486239e-02 4.25693592e-02
- 7 +7.0472023318269614e-02 2.34599232e-02
- 8 +6.9917568641317152e-02 5.66542766e-03
- 9 +6.9885086242452696e-02 4.05756115e-04
- 10 +6.9884967432653489e-02 2.16836017e-04
- 11 +6.9884923649884148e-02 5.74961622e-05
- 12 +6.9884921818258436e-02 3.83257203e-05
- 13 +6.9884920459612282e-02 9.97486224e-06
- 14 +6.9884920414414409e-02 7.33567875e-06
- 15 +6.9884920388431387e-02 5.23889187e-06
- 16 +6.9884920385183902e-02 4.91959084e-06
- 17 +6.9884920373983223e-02 3.56451669e-06
- 18 +6.9884920369701245e-02 2.88858709e-06
- 19 +6.9884920361621208e-02 1.82294279e-07
- Terminated - min grad norm reached after 19 iterations, 9.65 seconds.
+ 1 +8.9741888001949222e-01 3.71269078e-01
+ 2 +4.9103998133976140e-01 3.46687543e-01
+ 3 +4.2142651893148553e-01 1.04789602e-01
+ 4 +4.1573609749588841e-01 5.21726648e-02
+ 5 +4.1486046805261961e-01 5.35335513e-02
+ 6 +4.1315953904635105e-01 2.17803599e-02
+ 7 +4.1313030162717523e-01 6.06901182e-02
+ 8 +4.1301511591963386e-01 5.88598758e-02
+ 9 +4.1258349404769817e-01 5.14307874e-02
+ 10 +4.1139242901051226e-01 2.03198793e-02
+ 11 +4.1113798965164017e-01 1.18944721e-02
+ 12 +4.1103446820878486e-01 2.21783648e-02
+ 13 +4.1076586830791861e-01 9.51495863e-03
+ 14 +4.1036935287519144e-01 3.74973214e-02
+ 15 +4.0958729714575060e-01 1.23810902e-02
+ 16 +4.0898266309095005e-01 4.01999918e-02
+ 17 +4.0816076944357715e-01 2.27240277e-02
+ 18 +4.0788116701894767e-01 4.42815945e-02
+ 19 +4.0695443744952403e-01 3.28464304e-02
+ 20 +4.0293834480911150e-01 7.76000681e-02
+ 21 +3.8488003705202750e-01 1.49378022e-01
+ 22 +3.0767344927282614e-01 2.15432117e-01
+ 23 +2.3849425361868334e-01 1.07942382e-01
+ 24 +2.3845125762548214e-01 1.08953278e-01
+ 25 +2.3828007730494005e-01 1.07934830e-01
+ 26 +2.3760839060570119e-01 1.03822134e-01
+ 27 +2.3514215179705886e-01 8.67263481e-02
+ 28 +2.2978886197588613e-01 9.26609306e-03
+ 29 +2.2972671019495342e-01 2.59476089e-03
+ 30 +2.2972355865247496e-01 1.57205146e-03
+ 31 +2.2972296662351968e-01 1.29300760e-03
+ 32 +2.2972181557051569e-01 8.82375756e-05
+ 33 +2.2972181277025336e-01 6.20536544e-05
+ 34 +2.2972181023486152e-01 7.01884014e-06
+ 35 +2.2972181020400181e-01 1.60415765e-06
+ 36 +2.2972181020236590e-01 2.44290966e-07
+ Terminated - min grad norm reached after 36 iterations, 13.41 seconds.
@@ -53,62 +79,100 @@ Wasserstein Discriminant Analysis
.. code-block:: python
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
- import ot
- from ot.datasets import get_1D_gauss as gauss
- from ot.dr import wda
+
+ from ot.dr import wda, fda
#%% parameters
- n=1000 # nb samples in source and target datasets
- nz=0.2
- xs,ys=ot.datasets.get_data_classif('3gauss',n,nz)
- xt,yt=ot.datasets.get_data_classif('3gauss',n,nz)
+ n = 1000 # nb samples in source and target datasets
+ nz = 0.2
- nbnoise=8
+ # generate circle dataset
+ t = np.random.rand(n) * 2 * np.pi
+ ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+ xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
- xs=np.hstack((xs,np.random.randn(n,nbnoise)))
- xt=np.hstack((xt,np.random.randn(n,nbnoise)))
+ t = np.random.rand(n) * 2 * np.pi
+ yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+ xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
- #%% plot samples
+ nbnoise = 8
- pl.figure(1)
+ xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+ xt = np.hstack((xt, np.random.randn(n, nbnoise)))
+ #%% plot samples
+ pl.figure(1, figsize=(6.4, 3.5))
- pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+ pl.subplot(1, 2, 1)
+ pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Discriminant dimensions')
+ pl.subplot(1, 2, 2)
+ pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
+ pl.legend(loc=0)
+ pl.title('Other dimensions')
+ pl.tight_layout()
+
+ #%% Compute FDA
+ p = 2
- #%% plot distributions and loss matrix
- p=2
- reg=1
- k=10
- maxiter=100
+ Pfda, projfda = fda(xs, ys, p)
- P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
+ #%% Compute WDA
+ p = 2
+ reg = 1e0
+ k = 10
+ maxiter = 100
+
+ Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
#%% plot samples
- xsp=proj(xs)
- xtp=proj(xt)
+ xsp = projfda(xs)
+ xtp = projfda(xt)
+
+ xspw = projwda(xs)
+ xtpw = projwda(xt)
- pl.figure(1,(10,5))
+ pl.figure(2)
- pl.subplot(1,2,1)
- pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+ pl.subplot(2, 2, 1)
+ pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
- pl.title('Projected training samples')
+ pl.title('Projected training samples FDA')
+ pl.subplot(2, 2, 2)
+ pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected test samples FDA')
+
+ pl.subplot(2, 2, 3)
+ pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected training samples WDA')
- pl.subplot(1,2,2)
- pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+ pl.subplot(2, 2, 4)
+ pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
- pl.title('Projected test samples')
+ pl.title('Projected test samples WDA')
+ pl.tight_layout()
+
+ pl.show()
-**Total running time of the script:** ( 0 minutes 16.902 seconds)
+**Total running time of the script:** ( 0 minutes 19.853 seconds)