summaryrefslogtreecommitdiff
path: root/examples/unbalanced-partial/plot_UOT_1D.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/unbalanced-partial/plot_UOT_1D.py')
-rw-r--r--examples/unbalanced-partial/plot_UOT_1D.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 183849c..06dd02d 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter
alpha = 1. # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
pl.show()
+
+
+# %%
+# plot the transported mass
+# -------------------------
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source')
+pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target')
+pl.legend(loc='upper right')
+pl.title('Distributions and transported mass for UOT')