diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-04-05 11:57:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-05 11:57:10 +0200 |
commit | ad02112d4288f3efdd5bc6fc6e45444313bba871 (patch) | |
tree | f6cd539450c2ed36cf5d7014debfd82e8b9fddfb /examples/unbalanced-partial/plot_UOT_1D.py | |
parent | 0afd84d744a472903d427e3c7ae32e55fdd7b9a7 (diff) |
[MRG] Update examples in the doc (#359)
* add transparent color logo
* add transparent color logo
* move screenkhorn
* move stochastic and install ffmpeg on circleci
* try something
* add sudo
* install ffmpeg before python
* cleanup examples
* test svg scrapper
* add animation for reg path
* better example OT sivergence
* update ttles and add plots
* update free support
* proper figure indexes
* have less frame sin animation
* update readme and release file
* add tests for python 3.10
Diffstat (limited to 'examples/unbalanced-partial/plot_UOT_1D.py')
-rw-r--r-- | examples/unbalanced-partial/plot_UOT_1D.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 183849c..06dd02d 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot @@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter alpha = 1. # Unbalanced KL relaxation parameter Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') pl.show() + + +# %% +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source') +pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target') +pl.legend(loc='upper right') +pl.title('Distributions and transported mass for UOT') |