diff options
Diffstat (limited to 'docs/source/auto_examples/plot_WDA.py')
-rw-r--r-- | docs/source/auto_examples/plot_WDA.py | 122 |
1 files changed, 93 insertions, 29 deletions
diff --git a/docs/source/auto_examples/plot_WDA.py b/docs/source/auto_examples/plot_WDA.py index bbe3888..93cc237 100644 --- a/docs/source/auto_examples/plot_WDA.py +++ b/docs/source/auto_examples/plot_WDA.py @@ -4,60 +4,124 @@ Wasserstein Discriminant Analysis ================================= -@author: rflamary +This example illustrate the use of WDA as proposed in [11]. + + +[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). +Wasserstein Discriminant Analysis. + """ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License + import numpy as np import matplotlib.pylab as pl -import ot -from ot.datasets import get_1D_gauss as gauss -from ot.dr import wda +from ot.dr import wda, fda + + +############################################################################## +# Generate data +# ------------- #%% parameters -n=1000 # nb samples in source and target datasets -nz=0.2 -xs,ys=ot.datasets.get_data_classif('3gauss',n,nz) -xt,yt=ot.datasets.get_data_classif('3gauss',n,nz) +n = 1000 # nb samples in source and target datasets +nz = 0.2 -nbnoise=8 +# generate circle dataset +t = np.random.rand(n) * 2 * np.pi +ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 +xs = np.concatenate( + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) +xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2) -xs=np.hstack((xs,np.random.randn(n,nbnoise))) -xt=np.hstack((xt,np.random.randn(n,nbnoise))) +t = np.random.rand(n) * 2 * np.pi +yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 +xt = np.concatenate( + (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) +xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2) -#%% plot samples +nbnoise = 8 + +xs = np.hstack((xs, np.random.randn(n, nbnoise))) +xt = np.hstack((xt, np.random.randn(n, nbnoise))) -pl.figure(1) +############################################################################## +# Plot data +# --------- +#%% plot samples +pl.figure(1, figsize=(6.4, 3.5)) -pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples') +pl.subplot(1, 2, 1) +pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples') pl.legend(loc=0) pl.title('Discriminant dimensions') +pl.subplot(1, 2, 2) +pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples') +pl.legend(loc=0) +pl.title('Other dimensions') +pl.tight_layout() + +############################################################################## +# Compute Fisher Discriminant Analysis +# ------------------------------------ -#%% plot distributions and loss matrix -p=2 -reg=1 -k=10 -maxiter=100 +#%% Compute FDA +p = 2 -P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter) +Pfda, projfda = fda(xs, ys, p) + +############################################################################## +# Compute Wasserstein Discriminant Analysis +# ----------------------------------------- + +#%% Compute WDA +p = 2 +reg = 1e0 +k = 10 +maxiter = 100 + +Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) + + +############################################################################## +# Plot 2D projections +# ------------------- #%% plot samples -xsp=proj(xs) -xtp=proj(xt) +xsp = projfda(xs) +xtp = projfda(xt) -pl.figure(1,(10,5)) +xspw = projwda(xs) +xtpw = projwda(xt) -pl.subplot(1,2,1) -pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples') +pl.figure(2) + +pl.subplot(2, 2, 1) +pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples') pl.legend(loc=0) -pl.title('Projected training samples') +pl.title('Projected training samples FDA') +pl.subplot(2, 2, 2) +pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples') +pl.legend(loc=0) +pl.title('Projected test samples FDA') -pl.subplot(1,2,2) -pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples') +pl.subplot(2, 2, 3) +pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples') pl.legend(loc=0) -pl.title('Projected test samples') +pl.title('Projected training samples WDA') + +pl.subplot(2, 2, 4) +pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples') +pl.legend(loc=0) +pl.title('Projected test samples WDA') +pl.tight_layout() + +pl.show() |