summaryrefslogtreecommitdiff
path: root/examples/others
diff options
context:
space:
mode:
Diffstat (limited to 'examples/others')
-rw-r--r--examples/others/plot_WeakOT_VS_OT.py98
-rw-r--r--examples/others/plot_factored_coupling.py86
-rw-r--r--examples/others/plot_logo.py112
-rw-r--r--examples/others/plot_screenkhorn_1D.py71
-rw-r--r--examples/others/plot_stochastic.py189
5 files changed, 556 insertions, 0 deletions
diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py
new file mode 100644
index 0000000..a29c875
--- /dev/null
+++ b/examples/others/plot_WeakOT_VS_OT.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Weak Optimal Transport VS exact Optimal Transport
+====================================================
+
+Illustration of 2D optimal transport between distributions that are weighted
+sum of diracs. The OT matrix is plotted with the samples.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data an plot it
+# ------------------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+pl.figure(2)
+pl.imshow(M, interpolation='nearest')
+pl.title('Cost matrix M')
+
+
+##############################################################################
+# Compute Weak OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+#%% Weak OT
+
+Gweak = ot.weak_optimal_transport(xs, xt, a, b)
+
+
+##############################################################################
+# Plot weak OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(3, (8, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(G0, interpolation='nearest')
+pl.title('OT matrix')
+
+pl.subplot(1, 2, 2)
+pl.imshow(Gweak, interpolation='nearest')
+pl.title('Weak OT matrix')
+
+pl.figure(4, (8, 5))
+
+pl.subplot(1, 2, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('OT matrix with samples')
+
+pl.subplot(1, 2, 2)
+ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Weak OT matrix with samples')
diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py
new file mode 100644
index 0000000..b5b1c9f
--- /dev/null
+++ b/examples/others/plot_factored_coupling.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+Optimal transport with factored couplings
+==========================================
+
+Illustration of the factored coupling OT between 2D empirical distributions
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+# %%
+# Generate data an plot it
+# ------------------------
+
+# parameters and data generation
+
+np.random.seed(42)
+
+n = 100 # nb samples
+
+xs = np.random.rand(n, 2) - .5
+
+xs = xs + np.sign(xs)
+
+xt = np.random.rand(n, 2) - .5
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+# %%
+# Compute Factore OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+M = ot.dist(xs, xt)
+G0 = ot.emd(a, b, M)
+
+#%% factored OT OT
+
+Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4)
+
+
+# %%
+# Plot factored OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(2, (14, 4))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Exact OT with samples')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5)
+ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples')
+pl.title('Factored OT with template samples')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Factored OT low rank OT plan')
diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py
new file mode 100644
index 0000000..bb4f640
--- /dev/null
+++ b/examples/others/plot_logo.py
@@ -0,0 +1,112 @@
+
+# -*- coding: utf-8 -*-
+r"""
+=======================
+Logo of the POT toolbox
+=======================
+
+In this example we plot the logo of the POT toolbox.
+
+This logo is that it is done 100% in Python and generated using
+matplotlib and ploting teh solution of the EMD solver from POT.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+# %% Load modules
+import numpy as np
+import matplotlib.pyplot as pl
+import ot
+
+# %%
+# Data for logo
+# -------------
+
+
+# Letter P
+p1 = np.array([[0, 6.], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], ])
+p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ])
+
+# Letter O
+o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ])
+o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ])
+
+# Scaling and translation for letter O
+o1[:, 0] += 6.4
+o2[:, 0] += 6.4
+o1[:, 0] *= 0.6
+o2[:, 0] *= 0.6
+
+# Letter T
+t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ])
+t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ])
+
+# Translating the T
+t1[:, 0] += 7.1
+t2[:, 0] += 7.1
+
+# Concatenate all letters
+x1 = np.concatenate((p1, o1, t1), axis=0)
+x2 = np.concatenate((p2, o2, t2), axis=0)
+
+# Horizontal and vertical scaling
+sx = 1.0
+sy = .5
+x1[:, 0] *= sx
+x1[:, 1] *= sy
+x2[:, 0] *= sx
+x2[:, 1] *= sy
+
+# %%
+# Plot the logo (clear background)
+# --------------------------------
+
+# Solve OT problem between the points
+M = ot.dist(x1, x2, metric='euclidean')
+T = ot.emd([], [], M)
+
+pl.figure(1, (3.5, 1.1))
+pl.clf()
+# plot the OT plan
+for i in range(M.shape[0]):
+ for j in range(M.shape[1]):
+ if T[i, j] > 1e-8:
+ pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='k', alpha=0.6, linewidth=3, zorder=1)
+# plot the samples
+pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='C3', markeredgecolor='k')
+pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='b', markeredgecolor='k')
+
+
+pl.axis('equal')
+pl.axis('off')
+
+# Save logo file
+# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight')
+
+# %%
+# Plot the logo (dark background)
+# --------------------------------
+
+pl.figure(2, (3.5, 1.1), facecolor='darkgray')
+pl.clf()
+# plot the OT plan
+for i in range(M.shape[0]):
+ for j in range(M.shape[1]):
+ if T[i, j] > 1e-8:
+ pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='w', alpha=0.8, linewidth=3, zorder=1)
+# plot the samples
+pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='w', markeredgecolor='w')
+pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='w', markeredgecolor='w')
+
+pl.axis('equal')
+pl.axis('off')
+
+# Save logo file
+# pl.savefig('logo_dark.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo_dark.png', dpi=150, transparent=True, bbox_inches='tight')
diff --git a/examples/others/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py
new file mode 100644
index 0000000..2023649
--- /dev/null
+++ b/examples/others/plot_screenkhorn_1D.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+"""
+========================================
+Screened optimal transport (Screenkhorn)
+========================================
+
+This example illustrates the computation of Screenkhorn [26].
+
+[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+Screening Sinkhorn Algorithm for Regularized Optimal Transport,
+Advances in Neural Information Processing Systems 33 (NeurIPS).
+"""
+
+# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+from ot.bregman import screenkhorn
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve Screenkhorn
+# -----------------------
+
+# Screenkhorn
+lambd = 2e-03 # entropy parameter
+ns_budget = 30 # budget number of points to be keeped in the source distribution
+nt_budget = 30 # budget number of points to be keeped in the target distribution
+
+G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True)
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn')
+pl.show()
diff --git a/examples/others/plot_stochastic.py b/examples/others/plot_stochastic.py
new file mode 100644
index 0000000..3a1ef31
--- /dev/null
+++ b/examples/others/plot_stochastic.py
@@ -0,0 +1,189 @@
+"""
+===================
+Stochastic examples
+===================
+
+This example is designed to show how to use the stochatic optimization
+algorithms for discrete and semi-continuous measures from the POT library.
+
+[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F.
+Stochastic Optimization for Large-scale Optimal Transport.
+Advances in Neural Information Processing Systems (2016).
+
+[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. &
+Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
+International Conference on Learning Representation (2018)
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+import ot
+import ot.plot
+
+
+#############################################################################
+# Compute the Transportation Matrix for the Semi-Dual Problem
+# -----------------------------------------------------------
+#
+# Discrete case
+# `````````````
+#
+# Sample two discrete measures for the discrete case and compute their cost
+# matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+# Call the "SAG" method to find the transportation matrix in the discrete case
+
+method = "SAG"
+sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax)
+print(sag_pi)
+
+#############################################################################
+# Semi-Continuous Case
+# ````````````````````
+#
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case, the points where source and target measures are defined and compute the
+# cost matrix.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+# Call the "ASGD" method to find the transportation matrix in the semicontinous
+# case.
+
+method = "ASGD"
+asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax, log=log)
+print(log_asgd['alpha'], log_asgd['beta'])
+print(asgd_pi)
+
+#############################################################################
+# Compare the results with the Sinkhorn algorithm
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+
+##############################################################################
+# Plot Transportation Matrices
+# ````````````````````````````
+#
+# For SAG
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
+pl.show()
+
+
+##############################################################################
+# For ASGD
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
+pl.show()
+
+
+##############################################################################
+# For Sinkhorn
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
+
+
+#############################################################################
+# Compute the Transportation Matrix for the Dual Problem
+# ------------------------------------------------------
+#
+# Semi-continuous case
+# ````````````````````
+#
+# Sample one general measure a, one discrete measures b for the semi-continuous
+# case and compute the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 100000
+lr = 0.1
+batch_size = 3
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SGD" dual method to find the transportation matrix in the
+# semi-continuous case
+
+sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
+ batch_size, numItermax,
+ lr, log=log)
+print(log_sgd['alpha'], log_sgd['beta'])
+print(sgd_dual_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ```````````````````````````````````````````````
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+##############################################################################
+# Plot Transportation Matrices
+# ````````````````````````````
+#
+# For SGD
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
+pl.show()
+
+
+##############################################################################
+# For Sinkhorn
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()