summaryrefslogtreecommitdiff
path: root/examples/plot_WDA.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-08-31 09:28:37 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-08-31 09:28:37 +0200
commit212f3889b1114026765cda0134e02766daa82af2 (patch)
treef9ea2d2566d1544b3409152f8ebbc8ca706c96e2 /examples/plot_WDA.py
parentec67362de5ec785e3871eac75a8aa477857092c4 (diff)
update tests
Diffstat (limited to 'examples/plot_WDA.py')
-rw-r--r--examples/plot_WDA.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py
index 42789f2..06a2e38 100644
--- a/examples/plot_WDA.py
+++ b/examples/plot_WDA.py
@@ -4,6 +4,12 @@
Wasserstein Discriminant Analysis
=================================
+This example illustrate the use of WDA as proposed in [11].
+
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+Wasserstein Discriminant Analysis.
+
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -16,6 +22,10 @@ import matplotlib.pylab as pl
from ot.dr import wda, fda
+##############################################################################
+# Generate data
+##############################################################################
+
#%% parameters
n = 1000 # nb samples in source and target datasets
@@ -39,6 +49,10 @@ nbnoise = 8
xs = np.hstack((xs, np.random.randn(n, nbnoise)))
xt = np.hstack((xt, np.random.randn(n, nbnoise)))
+##############################################################################
+# Plot data
+##############################################################################
+
#%% plot samples
pl.figure(1, figsize=(6.4, 3.5))
@@ -53,11 +67,19 @@ pl.legend(loc=0)
pl.title('Other dimensions')
pl.tight_layout()
+##############################################################################
+# Compute Fisher Discriminant Analysis
+##############################################################################
+
#%% Compute FDA
p = 2
Pfda, projfda = fda(xs, ys, p)
+##############################################################################
+# Compute Wasserstein Discriminant Analysis
+##############################################################################
+
#%% Compute WDA
p = 2
reg = 1e0
@@ -66,6 +88,11 @@ maxiter = 100
Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+
+##############################################################################
+# Plot 2D projections
+##############################################################################
+
#%% plot samples
xsp = projfda(xs)