diff options
author | Mokhtar Z. Alaya <mzalaya@Mokhtars-MacBook-Pro.local> | 2020-01-08 18:54:07 +0100 |
---|---|---|
committer | Mokhtar Z. Alaya <mzalaya@Mokhtars-MacBook-Pro.local> | 2020-01-08 18:54:07 +0100 |
commit | 3e77515b4f19cf1c37b2f971a54b2fe5efe9daef (patch) | |
tree | 80ed034012902c4e83c00aa4f0469f1ac121c438 /examples/plot_screenkhorn_1D.py | |
parent | e00f46aa2ea11f0e88a5b2005caa7518ca109357 (diff) |
add illustration for screenkhorn
Diffstat (limited to 'examples/plot_screenkhorn_1D.py')
-rw-r--r-- | examples/plot_screenkhorn_1D.py | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py new file mode 100644 index 0000000..e0d7bfd --- /dev/null +++ b/examples/plot_screenkhorn_1D.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[ ]: + + +get_ipython().run_line_magic('matplotlib', 'inline') + + +# +# # 1D Screened optimal transport +# +# +# This example illustrates the computation of Screenkhorn: Screening Sinkhorn Algorithm for Optimal transport. +# +# + +# In[13]: + + +# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss +from ot.bregman import screenkhorn + + +# Generate data +# ------------- +# +# + +# In[14]: + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +# Plot distributions and loss matrix +# ---------------------------------- +# +# + +# In[15]: + + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +# Solve Screened Sinkhorn +# -------------- +# +# + +# In[21]: + + +# Screenkhorn + +lambd = 1e-2 # entropy parameter +ns_budget = 30 # budget number of points to be keeped in the source distribution +nt_budget = 30 # budget number of points to be keeped in the target distribution + +Gsc = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Screenkhorn') + +pl.show() + + +# In[ ]: + + + + |