summaryrefslogtreecommitdiff
path: root/examples/plot_WDA.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2017-09-15 14:54:21 +0200
committerGitHub <noreply@github.com>2017-09-15 14:54:21 +0200
commit81b2796226f3abde29fc024752728444da77509a (patch)
treec52cec3c38552f9f8c15361758aa9a80c30c3ef3 /examples/plot_WDA.py
parente70d5420204db78691af2d0fbe04cc3d4416a8f4 (diff)
parent7fea2cd3e8ad29bf3fa442d7642bae124ee2bab0 (diff)
Merge pull request #27 from rflamary/autonb
auto notebooks + release update (fixes #16)
Diffstat (limited to 'examples/plot_WDA.py')
-rw-r--r--examples/plot_WDA.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py
index 42789f2..93cc237 100644
--- a/examples/plot_WDA.py
+++ b/examples/plot_WDA.py
@@ -4,6 +4,12 @@
Wasserstein Discriminant Analysis
=================================
+This example illustrate the use of WDA as proposed in [11].
+
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+Wasserstein Discriminant Analysis.
+
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -16,6 +22,10 @@ import matplotlib.pylab as pl
from ot.dr import wda, fda
+##############################################################################
+# Generate data
+# -------------
+
#%% parameters
n = 1000 # nb samples in source and target datasets
@@ -39,6 +49,10 @@ nbnoise = 8
xs = np.hstack((xs, np.random.randn(n, nbnoise)))
xt = np.hstack((xt, np.random.randn(n, nbnoise)))
+##############################################################################
+# Plot data
+# ---------
+
#%% plot samples
pl.figure(1, figsize=(6.4, 3.5))
@@ -53,11 +67,19 @@ pl.legend(loc=0)
pl.title('Other dimensions')
pl.tight_layout()
+##############################################################################
+# Compute Fisher Discriminant Analysis
+# ------------------------------------
+
#%% Compute FDA
p = 2
Pfda, projfda = fda(xs, ys, p)
+##############################################################################
+# Compute Wasserstein Discriminant Analysis
+# -----------------------------------------
+
#%% Compute WDA
p = 2
reg = 1e0
@@ -66,6 +88,11 @@ maxiter = 100
Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+
+##############################################################################
+# Plot 2D projections
+# -------------------
+
#%% plot samples
xsp = projfda(xs)