From a54775103541ea37f54269de1ba1e1396a6d7b30 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 24 Apr 2020 17:32:57 +0200 Subject: exmaples in sections --- examples/plot_WDA.py | 129 --------------------------------------------------- 1 file changed, 129 deletions(-) delete mode 100644 examples/plot_WDA.py (limited to 'examples/plot_WDA.py') diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py deleted file mode 100644 index 5e17433..0000000 --- a/examples/plot_WDA.py +++ /dev/null @@ -1,129 +0,0 @@ -# -*- coding: utf-8 -*- -""" -================================= -Wasserstein Discriminant Analysis -================================= - -This example illustrate the use of WDA as proposed in [11]. - - -[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). -Wasserstein Discriminant Analysis. - -""" - -# Author: Remi Flamary -# -# License: MIT License - -# sphinx_gallery_thumbnail_number = 2 - -import numpy as np -import matplotlib.pylab as pl - -from ot.dr import wda, fda - - -############################################################################## -# Generate data -# ------------- - -#%% parameters - -n = 1000 # nb samples in source and target datasets -nz = 0.2 - -# generate circle dataset -t = np.random.rand(n) * 2 * np.pi -ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 -xs = np.concatenate( - (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) -xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2) - -t = np.random.rand(n) * 2 * np.pi -yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 -xt = np.concatenate( - (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1) -xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2) - -nbnoise = 8 - -xs = np.hstack((xs, np.random.randn(n, nbnoise))) -xt = np.hstack((xt, np.random.randn(n, nbnoise))) - -############################################################################## -# Plot data -# --------- - -#%% plot samples -pl.figure(1, figsize=(6.4, 3.5)) - -pl.subplot(1, 2, 1) -pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples') -pl.legend(loc=0) -pl.title('Discriminant dimensions') - -pl.subplot(1, 2, 2) -pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples') -pl.legend(loc=0) -pl.title('Other dimensions') -pl.tight_layout() - -############################################################################## -# Compute Fisher Discriminant Analysis -# ------------------------------------ - -#%% Compute FDA -p = 2 - -Pfda, projfda = fda(xs, ys, p) - -############################################################################## -# Compute Wasserstein Discriminant Analysis -# ----------------------------------------- - -#%% Compute WDA -p = 2 -reg = 1e0 -k = 10 -maxiter = 100 - -Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) - - -############################################################################## -# Plot 2D projections -# ------------------- - -#%% plot samples - -xsp = projfda(xs) -xtp = projfda(xt) - -xspw = projwda(xs) -xtpw = projwda(xt) - -pl.figure(2) - -pl.subplot(2, 2, 1) -pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples') -pl.legend(loc=0) -pl.title('Projected training samples FDA') - -pl.subplot(2, 2, 2) -pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples') -pl.legend(loc=0) -pl.title('Projected test samples FDA') - -pl.subplot(2, 2, 3) -pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples') -pl.legend(loc=0) -pl.title('Projected training samples WDA') - -pl.subplot(2, 2, 4) -pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples') -pl.legend(loc=0) -pl.title('Projected test samples WDA') -pl.tight_layout() - -pl.show() -- cgit v1.2.3