From a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Wed, 2 Feb 2022 11:53:12 +0100 Subject: [MRG] Add weak OT solver (#341) * add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation --- examples/others/plot_WeakOT_VS_OT.py | 98 ++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 examples/others/plot_WeakOT_VS_OT.py (limited to 'examples/others') diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py new file mode 100644 index 0000000..a29c875 --- /dev/null +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Weak Optimal Transport VS exact Optimal Transport +==================================================== + +Illustration of 2D optimal transport between distributions that are weighted +sum of diracs. The OT matrix is plotted with the samples. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +############################################################################## +# Generate data an plot it +# ------------------------ + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +pl.figure(2) +pl.imshow(M, interpolation='nearest') +pl.title('Cost matrix M') + + +############################################################################## +# Compute Weak OT and exact OT solutions +# -------------------------------------- + +#%% EMD + +G0 = ot.emd(a, b, M) + +#%% Weak OT + +Gweak = ot.weak_optimal_transport(xs, xt, a, b) + + +############################################################################## +# Plot weak OT and exact OT solutions +# -------------------------------------- + +pl.figure(3, (8, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(G0, interpolation='nearest') +pl.title('OT matrix') + +pl.subplot(1, 2, 2) +pl.imshow(Gweak, interpolation='nearest') +pl.title('Weak OT matrix') + +pl.figure(4, (8, 5)) + +pl.subplot(1, 2, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('OT matrix with samples') + +pl.subplot(1, 2, 2) +ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Weak OT matrix with samples') -- cgit v1.2.3