summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorMokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>2020-01-07 15:29:23 +0100
committerGitHub <noreply@github.com>2020-01-07 15:29:23 +0100
commit27b6740ea95b609ecdb103fbff7c1bbc62071ddc (patch)
tree2a2692da7cccc77c447a18f3f03b8fd36d5eb17a /ot/bregman.py
parente821872581e5e62d984883d8b8f881e35160be56 (diff)
improve documentation of screenkhorn
add Exception at the beginning to check the installation of bottleneck module
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py62
1 files changed, 45 insertions, 17 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 58c76d0..456b61f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1789,52 +1789,82 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
-def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, verbose=False):
+def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=True, verbose=False, log=False):
""""
- Screening Sinkhorn Algorithm for Regularized Optimal Transport.
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport
+
+ The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem:
+
+ .. math::
+ (u, v) = \arg\min_{u, v} 1_{ns}^T B(u,v) 1_{nt} - \langle \kappa u, a \rangle - \langle v / \kappa, b \rangle
+
+ where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and
+
+ s.t. e^{u_i} >= \epsilon / \kappa, for all i in {1, ..., ns}
+
+ e^{v_j} >= \epsilon \kappa, for all j in {1, ..., nt}
+
+ The parameters \kappa and \epsilon are determined from the pair of point budgets (ns_budget, nt_budget); see Equation (5) in [26]
+
Parameters
----------
a : `numpy.ndarray`, shape=(ns,)
- samples weights in the source domain.
+ samples weights in the source domain
b : `numpy.ndarray`, shape=(nt,)
- samples weights in the target domain.
+ samples weights in the target domain
M : `numpy.ndarray`, shape=(ns, nt)
Cost matrix.
reg : `float`
- Level of the entropy regularisation.
+ Level of the entropy regularisation
ns_budget: `int`
- Number budget of points to be keeped in the source domain.
+ Number budget of points to be kept in the source domain
nt_budget: `int`
- Number budget of points to be keeped in the target domain.
+ Number budget of points to be kept in the target domain
uniform: `bool`, default=True
If `True`, a_i = 1. / ns and b_j = 1. / nt
restricted: `bool`, default=True
If `True`, a warm-start initialization for the L-BFGS-B solver
- using a restricted Sinkhorn algorithm with at most 5 iterations.
+ using a restricted Sinkhorn algorithm with at most 5 iterations
verbose: `bool`, default=False
- If `True`, dispaly informations along iterations.
-
+ If `True`, display information along iterations
+
+ Dependency
+ ----------
+ To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step.
+ If Bottleneck isn't installed, the following error message appears:
+ "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
+
+
Returns
-------
- Gsc : `numpy.ndarray`, shape=(ns, nt)
- Screened optimal transportation matrix for the given parameters.
+ gamma : `numpy.ndarray`, shape=(ns, nt)
+ Screened optimal transportation matrix for the given parameters
+
+ log : `dict`, default=False
+ Log dictionary, returned only if log==True in parameters
- References:
+
+ References
-----------
- .. [1] M. Z. Alaya, Maxime Bérar, Gilles Gasso, Alain Rakotomamonjy. Screening Sinkhorn Algorithm for Regularized
- Optimal Transport, NeurIPS 2019.
+ .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport. Advances in Neural Information Processing Systems (NeurIPS) 32, 2019
"""
+ # check if bottleneck module exists
+ try:
+ import bottleneck
+ except ImportError as e:
+ print("Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/")
+
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -1892,7 +1922,6 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru
aK_sort = np.sort(K_sum_cols)
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
- import bottleneck
aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
epsilon_u_square = a[0] / aK_sort
@@ -1900,7 +1929,6 @@ def screenkhorn(a, b, M, reg, ns_budget, nt_budget, uniform=True, restricted=Tru
bK_sort = np.sort(K_sum_rows)
epsilon_v_square = b[0] / bK_sort[ns_budget - 1]
else:
- import bottleneck
bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
epsilon_v_square = b[0] / bK_sort
else: