summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py338
1 files changed, 338 insertions, 0 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index ba5c7ba..c304b5d 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -8,12 +8,14 @@ Bregman projections for regularized OT
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
# Hicham Janati <hicham.janati@inria.fr>
+# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
#
# License: MIT License
import numpy as np
import warnings
from .utils import unif, dist
+from scipy.optimize import fmin_l_bfgs_b
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1787,3 +1789,339 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
+
+
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
+ maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
+ r""""
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport
+
+ The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem:
+
+ ..math::
+ (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b>
+
+ where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and
+
+ s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns}
+
+ e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt}
+
+ The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26]
+
+
+ Parameters
+ ----------
+ a : `numpy.ndarray`, shape=(ns,)
+ samples weights in the source domain
+
+ b : `numpy.ndarray`, shape=(nt,)
+ samples weights in the target domain
+
+ M : `numpy.ndarray`, shape=(ns, nt)
+ Cost matrix
+
+ reg : `float`
+ Level of the entropy regularisation
+
+ ns_budget : `int`, deafult=None
+ Number budget of points to be keeped in the source domain
+ If it is None then 50% of the source sample points will be keeped
+
+ nt_budget : `int`, deafult=None
+ Number budget of points to be keeped in the target domain
+ If it is None then 50% of the target sample points will be keeped
+
+ uniform : `bool`, default=False
+ If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
+
+ restricted : `bool`, default=True
+ If `True`, a warm-start initialization for the L-BFGS-B solver
+ using a restricted Sinkhorn algorithm with at most 5 iterations
+
+ maxiter : `int`, default=10000
+ Maximum number of iterations in LBFGS solver
+
+ maxfun : `int`, default=10000
+ Maximum number of function evaluations in LBFGS solver
+
+ pgtol : `float`, default=1e-09
+ Final objective function accuracy in LBFGS solver
+
+ verbose : `bool`, default=False
+ If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa
+ and epsilon
+
+ Dependency
+ ----------
+ To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
+ in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears:
+ "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
+
+
+ Returns
+ -------
+ gamma : `numpy.ndarray`, shape=(ns, nt)
+ Screened optimal transportation matrix for the given parameters
+
+ log : `dict`, default=False
+ Log dictionary return only if log==True in parameters
+
+
+ References
+ -----------
+ .. [26] Alaya M. Z., BĂ©rar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
+
+ """
+ # check if bottleneck module exists
+ try:
+ import bottleneck
+ except ImportError:
+ print("Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/")
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+ ns, nt = M.shape
+
+ # by default, we keep only 50% of the sample data points
+ if ns_budget is None:
+ ns_budget = int(np.floor(0.5 * ns))
+ if nt_budget is None:
+ nt_budget = int(np.floor(0.5 * nt))
+
+ # calculate the Gibbs kernel
+ K = np.empty_like(M)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ def projection(u, epsilon):
+ u[u <= epsilon] = epsilon
+ return u
+
+ # ----------------------------------------------------------------------------------------------------------------#
+ # Step 1: Screening pre-processing #
+ # ----------------------------------------------------------------------------------------------------------------#
+
+ if ns_budget == ns and nt_budget == nt:
+ # full number of budget points (ns, nt) = (ns_budget, nt_budget)
+ Isel = np.ones(ns, dtype=bool)
+ Jsel = np.ones(nt, dtype=bool)
+ epsilon = 0.0
+ kappa = 1.0
+
+ cst_u = 0.
+ cst_v = 0.
+
+ bounds_u = [(0.0, np.inf)] * ns
+ bounds_v = [(0.0, np.inf)] * nt
+
+ a_I = a
+ b_J = b
+ K_IJ = K
+ K_IJc = []
+ K_IcJ = []
+
+ vec_eps_IJc = np.zeros(nt)
+ vec_eps_IcJ = np.zeros(ns)
+
+ else:
+ # sum of rows and columns of K
+ K_sum_cols = K.sum(axis=1)
+ K_sum_rows = K.sum(axis=0)
+
+ if uniform:
+ if ns / ns_budget < 4:
+ aK_sort = np.sort(K_sum_cols)
+ epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
+ else:
+ aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+ epsilon_u_square = a[0] / aK_sort
+
+ if nt / nt_budget < 4:
+ bK_sort = np.sort(K_sum_rows)
+ epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
+ else:
+ bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+ epsilon_v_square = b[0] / bK_sort
+ else:
+ aK = a / K_sum_cols
+ bK = b / K_sum_rows
+
+ aK_sort = np.sort(aK)[::-1]
+ epsilon_u_square = aK_sort[ns_budget - 1]
+
+ bK_sort = np.sort(bK)[::-1]
+ epsilon_v_square = bK_sort[nt_budget - 1]
+
+ # active sets I and J (see Lemma 1 in [26])
+ Isel = a >= epsilon_u_square * K_sum_cols
+ Jsel = b >= epsilon_v_square * K_sum_rows
+
+ if sum(Isel) != ns_budget:
+ if uniform:
+ aK = a / K_sum_cols
+ aK_sort = np.sort(aK)[::-1]
+ epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+ Isel = a >= epsilon_u_square * K_sum_cols
+ ns_budget = sum(Isel)
+
+ if sum(Jsel) != nt_budget:
+ if uniform:
+ bK = b / K_sum_rows
+ bK_sort = np.sort(bK)[::-1]
+ epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ Jsel = b >= epsilon_v_square * K_sum_rows
+ nt_budget = sum(Jsel)
+
+ epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
+ kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
+
+ if verbose:
+ print("epsilon = %s\n" % epsilon)
+ print("kappa = %s\n" % kappa)
+ print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+
+ # Ic, Jc: complementary of the active sets I and J
+ Ic = ~Isel
+ Jc = ~Jsel
+
+ K_IJ = K[np.ix_(Isel, Jsel)]
+ K_IcJ = K[np.ix_(Ic, Jsel)]
+ K_IJc = K[np.ix_(Isel, Jc)]
+
+ K_min = K_IJ.min()
+ if K_min == 0:
+ K_min = np.finfo(float).tiny
+
+ # a_I, b_J, a_Ic, b_Jc
+ a_I = a[Isel]
+ b_J = b[Jsel]
+ if not uniform:
+ a_I_min = a_I.min()
+ a_I_max = a_I.max()
+ b_J_max = b_J.max()
+ b_J_min = b_J.min()
+ else:
+ a_I_min = a_I[0]
+ a_I_max = a_I[0]
+ b_J_max = b_J[0]
+ b_J_min = b_J[0]
+
+ # box constraints in L-BFGS-B (see Proposition 1 in [26])
+ bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+
+ bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+
+ # pre-calculated constants for the objective
+ vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
+ vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0)
+
+ # initialisation
+ u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
+ v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+
+ # pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26])
+ if restricted:
+ if ns_budget != ns or nt_budget != nt:
+ cst_u = kappa * epsilon * K_IJc.sum(axis=1)
+ cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+
+ cpt = 1
+ while cpt < 5: # 5 iterations
+ K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+ v0 = b_J / (kappa * K_IJ_v)
+ KIJ_u = np.dot(K_IJ, v0) + cst_u
+ u0 = (kappa * a_I) / KIJ_u
+ cpt += 1
+
+ u0 = projection(u0, epsilon / kappa)
+ v0 = projection(v0, epsilon * kappa)
+
+ else:
+ u0 = u0
+ v0 = v0
+
+ def restricted_sinkhorn(usc, vsc, max_iter=5):
+ """
+ Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
+ """
+ cpt = 1
+ while cpt < max_iter:
+ K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+ vsc = b_J / (kappa * K_IJ_v)
+ KIJ_u = np.dot(K_IJ, vsc) + cst_u
+ usc = (kappa * a_I) / KIJ_u
+ cpt += 1
+
+ usc = projection(usc, epsilon / kappa)
+ vsc = projection(vsc, epsilon * kappa)
+
+ return usc, vsc
+
+ def screened_obj(usc, vsc):
+ part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc))
+ part_IJc = np.dot(usc, vec_eps_IJc)
+ part_IcJ = np.dot(vec_eps_IcJ, vsc)
+ psi_epsilon = part_IJ + part_IJc + part_IcJ
+ return psi_epsilon
+
+ def screened_grad(usc, vsc):
+ # gradients of Psi_(kappa,epsilon) w.r.t u and v
+ grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+ grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+ return grad_u, grad_v
+
+ def bfgspost(theta):
+ u = theta[:ns_budget]
+ v = theta[ns_budget:]
+ # objective
+ f = screened_obj(u, v)
+ # gradient
+ g_u, g_v = screened_grad(u, v)
+ g = np.hstack([g_u, g_v])
+ return f, g
+
+ #----------------------------------------------------------------------------------------------------------------#
+ # Step 2: L-BFGS-B solver #
+ #----------------------------------------------------------------------------------------------------------------#
+
+ u0, v0 = restricted_sinkhorn(u0, v0)
+ theta0 = np.hstack([u0, v0])
+
+ bounds = bounds_u + bounds_v # constraint bounds
+
+ def obj(theta):
+ return bfgspost(theta)
+
+ theta, _, _ = fmin_l_bfgs_b(func=obj,
+ x0=theta0,
+ bounds=bounds,
+ maxfun=maxfun,
+ pgtol=pgtol,
+ maxiter=maxiter)
+
+ usc = theta[:ns_budget]
+ vsc = theta[ns_budget:]
+
+ usc_full = np.full(ns, epsilon / kappa)
+ vsc_full = np.full(nt, epsilon * kappa)
+ usc_full[Isel] = usc
+ vsc_full[Jsel] = vsc
+
+ if log:
+ log = {}
+ log['u'] = usc_full
+ log['v'] = vsc_full
+ log['Isel'] = Isel
+ log['Jsel'] = Jsel
+
+ gamma = usc_full[:, None] * K * vsc_full[None, :]
+ gamma = gamma / gamma.sum()
+
+ if log:
+ return gamma, log
+ else:
+ return gamma