diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2019-07-03 14:34:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-07-03 14:34:13 +0200 |
commit | 952503e02b1fc9bdf0811b937baacca57e4a98f1 (patch) | |
tree | c5b1b8f10eaee4c2ecaa12f629255489c2481590 /docs/source/auto_examples/plot_UOT_1D.py | |
parent | 8b3927bb5e8935c3dbddf054f054dc0c036fbdfe (diff) | |
parent | 7402d344240ce94e33c53daff419d4356278d48f (diff) |
Merge pull request #88 from rflamary/doc_modules
[MRG] Update documentation and add quick start guide
Diffstat (limited to 'docs/source/auto_examples/plot_UOT_1D.py')
-rw-r--r-- | docs/source/auto_examples/plot_UOT_1D.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_UOT_1D.py b/docs/source/auto_examples/plot_UOT_1D.py new file mode 100644 index 0000000..2ea8b05 --- /dev/null +++ b/docs/source/auto_examples/plot_UOT_1D.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +=============================== +1D Unbalanced optimal transport +=============================== + +This example illustrates the computation of Unbalanced Optimal transport +using a Kullback-Leibler relaxation. +""" + +# Author: Hicham Janati <hicham.janati@inria.fr> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# make distributions unbalanced +b *= 5. + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +############################################################################## +# Solve Unbalanced Sinkhorn +# -------------- + + +# Sinkhorn + +epsilon = 0.1 # entropy parameter +alpha = 1. # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') + +pl.show() |