summaryrefslogtreecommitdiff
path: root/examples/others/plot_screenkhorn_1D.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2022-04-05 11:57:10 +0200
committerGitHub <noreply@github.com>2022-04-05 11:57:10 +0200
commitad02112d4288f3efdd5bc6fc6e45444313bba871 (patch)
treef6cd539450c2ed36cf5d7014debfd82e8b9fddfb /examples/others/plot_screenkhorn_1D.py
parent0afd84d744a472903d427e3c7ae32e55fdd7b9a7 (diff)
[MRG] Update examples in the doc (#359)
* add transparent color logo * add transparent color logo * move screenkhorn * move stochastic and install ffmpeg on circleci * try something * add sudo * install ffmpeg before python * cleanup examples * test svg scrapper * add animation for reg path * better example OT sivergence * update ttles and add plots * update free support * proper figure indexes * have less frame sin animation * update readme and release file * add tests for python 3.10
Diffstat (limited to 'examples/others/plot_screenkhorn_1D.py')
-rw-r--r--examples/others/plot_screenkhorn_1D.py71
1 files changed, 71 insertions, 0 deletions
diff --git a/examples/others/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py
new file mode 100644
index 0000000..2023649
--- /dev/null
+++ b/examples/others/plot_screenkhorn_1D.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+"""
+========================================
+Screened optimal transport (Screenkhorn)
+========================================
+
+This example illustrates the computation of Screenkhorn [26].
+
+[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+Screening Sinkhorn Algorithm for Regularized Optimal Transport,
+Advances in Neural Information Processing Systems 33 (NeurIPS).
+"""
+
+# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+from ot.bregman import screenkhorn
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve Screenkhorn
+# -----------------------
+
+# Screenkhorn
+lambd = 2e-03 # entropy parameter
+ns_budget = 30 # budget number of points to be keeped in the source distribution
+nt_budget = 30 # budget number of points to be keeped in the target distribution
+
+G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True)
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn')
+pl.show()