summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/plot_screenkhorn_1D.py45
-rw-r--r--ot/bregman.py12
2 files changed, 26 insertions, 31 deletions
diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py
index e0d7bfd..22a9ddc 100644
--- a/examples/plot_screenkhorn_1D.py
+++ b/examples/plot_screenkhorn_1D.py
@@ -4,16 +4,22 @@
# In[ ]:
+from ot.bregman import screenkhorn
+from ot.datasets import make_1D_gauss as gauss
+import ot.plot
+import ot
+import matplotlib.pylab as pl
+import numpy as np
get_ipython().run_line_magic('matplotlib', 'inline')
-#
+#
# # 1D Screened optimal transport
-#
-#
+#
+#
# This example illustrates the computation of Screenkhorn: Screening Sinkhorn Algorithm for Optimal transport.
-#
-#
+#
+#
# In[13]:
@@ -22,18 +28,11 @@ get_ipython().run_line_magic('matplotlib', 'inline')
#
# License: MIT License
-import numpy as np
-import matplotlib.pylab as pl
-import ot
-import ot.plot
-from ot.datasets import make_1D_gauss as gauss
-from ot.bregman import screenkhorn
-
# Generate data
# -------------
-#
-#
+#
+#
# In[14]:
@@ -56,8 +55,8 @@ M /= M.max()
# Plot distributions and loss matrix
# ----------------------------------
-#
-#
+#
+#
# In[15]:
@@ -77,17 +76,17 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
# Solve Screened Sinkhorn
# --------------
-#
-#
+#
+#
# In[21]:
# Screenkhorn
-lambd = 1e-2 # entropy parameter
-ns_budget = 30 # budget number of points to be keeped in the source distribution
-nt_budget = 30 # budget number of points to be keeped in the target distribution
+lambd = 1e-2 # entropy parameter
+ns_budget = 30 # budget number of points to be keeped in the source distribution
+nt_budget = 30 # budget number of points to be keeped in the target distribution
Gsc = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True)
pl.figure(4, figsize=(5, 5))
@@ -97,7 +96,3 @@ pl.show()
# In[ ]:
-
-
-
-
diff --git a/ot/bregman.py b/ot/bregman.py
index 28377b0..4f24cf4 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1791,7 +1791,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
return max(0, sinkhorn_div)
-def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, restricted=True,
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
""""
Screening Sinkhorn Algorithm for Regularized Optimal Transport
@@ -1824,18 +1824,18 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=True, rest
reg : `float`
Level of the entropy regularisation
- ns_budget: `int`, deafult=None
+ ns_budget : `int`, deafult=None
Number budget of points to be keeped in the source domain
If it is None then 50% of the source sample points will be keeped
- nt_budget: `int`, deafult=None
+ nt_budget : `int`, deafult=None
Number budget of points to be keeped in the target domain
If it is None then 50% of the target sample points will be keeped
- uniform: `bool`, default=True
- If `True`, a_i = 1. / ns and b_j = 1. / nt
+ uniform : `bool`, default=False
+ If `True`, the source and target distribution are supposed to be uniform, namely a_i = 1 / ns and b_j = 1 / nt.
- restricted: `bool`, default=True
+ restricted : `bool`, default=True
If `True`, a warm-start initialization for the L-BFGS-B solver
using a restricted Sinkhorn algorithm with at most 5 iterations