diff options
-rw-r--r-- | examples/plot_screenkhorn_1D.py | 45 | ||||
-rw-r--r-- | ot/bregman.py | 12 |
2 files changed, 26 insertions, 31 deletions
diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py index e0d7bfd..22a9ddc 100644 --- a/examples/plot_screenkhorn_1D.py +++ b/examples/plot_screenkhorn_1D.py @@ -4,16 +4,22 @@ # In[ ]: +from ot.bregman import screenkhorn +from ot.datasets import make_1D_gauss as gauss +import ot.plot +import ot +import matplotlib.pylab as pl +import numpy as np get_ipython().run_line_magic('matplotlib', 'inline') -# +# # # 1D Screened optimal transport -# -# +# +# # This example illustrates the computation of Screenkhorn: Screening Sinkhorn Algorithm for Optimal transport. -# -# +# +# # In[13]: @@ -22,18 +28,11 @@ get_ipython().run_line_magic('matplotlib', 'inline') # # License: MIT License -import numpy as np -import matplotlib.pylab as pl -import ot -import ot.plot -from ot.datasets import make_1D_gauss as gauss -from ot.bregman import screenkhorn - # Generate data # ------------- -# -# +# +# # In[14]: @@ -56,8 +55,8 @@ M /= M.max() # Plot distributions and loss matrix # ---------------------------------- -# -# +# +# # In[15]: @@ -77,17 +76,17 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') # Solve Screened Sinkhorn # -------------- -# -# +# +# # In[21]: # Screenkhorn -lambd = 1e-2 # entropy parameter -ns_budget = 30 # budget number of points to be keeped in the source distribution -nt_budget = 30 # budget number of points to be keeped in the target distribution +lambd = 1e-2 # entropy parameter +ns_budget = 30 # budget number of points to be keeped in the source distribution +nt_budget = 30 # budget number of points to be keeped in the target distribution Gsc = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) pl.figure(4, figsize=(5, 5)) @@ -97,7 +96,3 @@ pl.show() # In[ ]: - - - - diff --git a/ot/bregman.py b/ot/bregman.py index 28377b0..4f24cf4 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1791,7 +1791,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli return max(0, sinkhorn_div) -def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, restricted=True, +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): """" Screening Sinkhorn Algorithm for Regularized Optimal Transport @@ -1824,18 +1824,18 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest reg : `float` Level of the entropy regularisation - ns_budget: `int`, deafult=None + ns_budget : `int`, deafult=None Number budget of points to be keeped in the source domain If it is None then 50% of the source sample points will be keeped - nt_budget: `int`, deafult=None + nt_budget : `int`, deafult=None Number budget of points to be keeped in the target domain If it is None then 50% of the target sample points will be keeped - uniform: `bool`, default=True - If `True`, a_i = 1. / ns and b_j = 1. / nt + uniform : `bool`, default=False + If `True`, the source and target distribution are supposed to be uniform, namely a_i = 1 / ns and b_j = 1 / nt. - restricted: `bool`, default=True + restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver using a restricted Sinkhorn algorithm with at most 5 iterations |