summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorAdrienCorenflos <adrien.corenflos@gmail.com>2020-10-22 09:28:53 +0100
committerGitHub <noreply@github.com>2020-10-22 10:28:53 +0200
commit78b44af2434f494c8f9e4c8c91003fbc0e1d4415 (patch)
tree013002f0a65918cee5eb95648965d4361f0c3dc2 /examples
parent7adc1b1aa73c55dc07983ff08dcb23fd71e9e8b6 (diff)
[MRG] Sliced wasserstein (#203)
* example for log treatment in bregman.py * Improve doc * Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e * Add comments by Flamary * Delete repetitive description * Added raw string to avoid pbs with backslashes * Implements sliced wasserstein * Changed formatting of string for py3.5 support * Docstest, expected 0.0 and not 0. * Adressed comments by @rflamary * No 3d plot here * add sliced to the docs * Incorporate comments by @rflamary * add link to pdf Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r--examples/sliced-wasserstein/README.txt4
-rw-r--r--examples/sliced-wasserstein/plot_variance.py84
2 files changed, 88 insertions, 0 deletions
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
new file mode 100644
index 0000000..a575345
--- /dev/null
+++ b/examples/sliced-wasserstein/README.txt
@@ -0,0 +1,4 @@
+
+
+Sliced Wasserstein Distance
+--------------------------- \ No newline at end of file
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
new file mode 100644
index 0000000..f3deeff
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+2D Sliced Wasserstein Distance
+==============================
+
+This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+###################################################################################
+# Compute Sliced Wasserstein distance for different seeds and number of projections
+# -----------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, 25))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
+
+###################################################################################
+# Plot Sliced Wasserstein Distance
+# -----------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label="SWD")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Sliced Wasserstein Distance with 95% confidence inverval')
+
+pl.show()