summaryrefslogtreecommitdiff
path: root/examples/sliced-wasserstein/plot_variance.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/sliced-wasserstein/plot_variance.py')
-rw-r--r--examples/sliced-wasserstein/plot_variance.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index f3deeff..27df656 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -4,9 +4,11 @@
2D Sliced Wasserstein Distance
==============================
-This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31].
+This example illustrates the computation of the sliced Wasserstein Distance as
+proposed in [31].
-[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of
+measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
"""
@@ -50,9 +52,9 @@ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')
-###################################################################################
-# Compute Sliced Wasserstein distance for different seeds and number of projections
-# -----------
+###############################################################################
+# Sliced Wasserstein distance for different seeds and number of projections
+# -------------------------------------------------------------------------
n_seed = 50
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
@@ -66,9 +68,9 @@ for seed in range(n_seed):
res_mean = np.mean(res, axis=0)
res_std = np.std(res, axis=0)
-###################################################################################
+###############################################################################
# Plot Sliced Wasserstein Distance
-# -----------
+# --------------------------------
pl.figure(2)
pl.plot(n_projections_arr, res_mean, label="SWD")