diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 338 |
1 files changed, 338 insertions, 0 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index ba5c7ba..c304b5d 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -8,12 +8,14 @@ Bregman projections for regularized OT # Kilian Fatras <kilian.fatras@irisa.fr> # Titouan Vayer <titouan.vayer@irisa.fr> # Hicham Janati <hicham.janati@inria.fr> +# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> # # License: MIT License import numpy as np import warnings from .utils import unif, dist +from scipy.optimize import fmin_l_bfgs_b def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -1787,3 +1789,339 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) + + +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, + maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): + r"""" + Screening Sinkhorn Algorithm for Regularized Optimal Transport + + The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem: + + ..math:: + (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b> + + where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and + + s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns} + + e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt} + + The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26] + + + Parameters + ---------- + a : `numpy.ndarray`, shape=(ns,) + samples weights in the source domain + + b : `numpy.ndarray`, shape=(nt,) + samples weights in the target domain + + M : `numpy.ndarray`, shape=(ns, nt) + Cost matrix + + reg : `float` + Level of the entropy regularisation + + ns_budget : `int`, deafult=None + Number budget of points to be keeped in the source domain + If it is None then 50% of the source sample points will be keeped + + nt_budget : `int`, deafult=None + Number budget of points to be keeped in the target domain + If it is None then 50% of the target sample points will be keeped + + uniform : `bool`, default=False + If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt + + restricted : `bool`, default=True + If `True`, a warm-start initialization for the L-BFGS-B solver + using a restricted Sinkhorn algorithm with at most 5 iterations + + maxiter : `int`, default=10000 + Maximum number of iterations in LBFGS solver + + maxfun : `int`, default=10000 + Maximum number of function evaluations in LBFGS solver + + pgtol : `float`, default=1e-09 + Final objective function accuracy in LBFGS solver + + verbose : `bool`, default=False + If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa + and epsilon + + Dependency + ---------- + To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) + in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears: + "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" + + + Returns + ------- + gamma : `numpy.ndarray`, shape=(ns, nt) + Screened optimal transportation matrix for the given parameters + + log : `dict`, default=False + Log dictionary return only if log==True in parameters + + + References + ----------- + .. [26] Alaya M. Z., BĂ©rar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 + + """ + # check if bottleneck module exists + try: + import bottleneck + except ImportError: + print("Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/") + + a = np.asarray(a, dtype=np.float64) + b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) + ns, nt = M.shape + + # by default, we keep only 50% of the sample data points + if ns_budget is None: + ns_budget = int(np.floor(0.5 * ns)) + if nt_budget is None: + nt_budget = int(np.floor(0.5 * nt)) + + # calculate the Gibbs kernel + K = np.empty_like(M) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + + def projection(u, epsilon): + u[u <= epsilon] = epsilon + return u + + # ----------------------------------------------------------------------------------------------------------------# + # Step 1: Screening pre-processing # + # ----------------------------------------------------------------------------------------------------------------# + + if ns_budget == ns and nt_budget == nt: + # full number of budget points (ns, nt) = (ns_budget, nt_budget) + Isel = np.ones(ns, dtype=bool) + Jsel = np.ones(nt, dtype=bool) + epsilon = 0.0 + kappa = 1.0 + + cst_u = 0. + cst_v = 0. + + bounds_u = [(0.0, np.inf)] * ns + bounds_v = [(0.0, np.inf)] * nt + + a_I = a + b_J = b + K_IJ = K + K_IJc = [] + K_IcJ = [] + + vec_eps_IJc = np.zeros(nt) + vec_eps_IcJ = np.zeros(ns) + + else: + # sum of rows and columns of K + K_sum_cols = K.sum(axis=1) + K_sum_rows = K.sum(axis=0) + + if uniform: + if ns / ns_budget < 4: + aK_sort = np.sort(K_sum_cols) + epsilon_u_square = a[0] / aK_sort[ns_budget - 1] + else: + aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1] + epsilon_u_square = a[0] / aK_sort + + if nt / nt_budget < 4: + bK_sort = np.sort(K_sum_rows) + epsilon_v_square = b[0] / bK_sort[nt_budget - 1] + else: + bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1] + epsilon_v_square = b[0] / bK_sort + else: + aK = a / K_sum_cols + bK = b / K_sum_rows + + aK_sort = np.sort(aK)[::-1] + epsilon_u_square = aK_sort[ns_budget - 1] + + bK_sort = np.sort(bK)[::-1] + epsilon_v_square = bK_sort[nt_budget - 1] + + # active sets I and J (see Lemma 1 in [26]) + Isel = a >= epsilon_u_square * K_sum_cols + Jsel = b >= epsilon_v_square * K_sum_rows + + if sum(Isel) != ns_budget: + if uniform: + aK = a / K_sum_cols + aK_sort = np.sort(aK)[::-1] + epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean() + Isel = a >= epsilon_u_square * K_sum_cols + ns_budget = sum(Isel) + + if sum(Jsel) != nt_budget: + if uniform: + bK = b / K_sum_rows + bK_sort = np.sort(bK)[::-1] + epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean() + Jsel = b >= epsilon_v_square * K_sum_rows + nt_budget = sum(Jsel) + + epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4) + kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2) + + if verbose: + print("epsilon = %s\n" % epsilon) + print("kappa = %s\n" % kappa) + print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel))) + + # Ic, Jc: complementary of the active sets I and J + Ic = ~Isel + Jc = ~Jsel + + K_IJ = K[np.ix_(Isel, Jsel)] + K_IcJ = K[np.ix_(Ic, Jsel)] + K_IJc = K[np.ix_(Isel, Jc)] + + K_min = K_IJ.min() + if K_min == 0: + K_min = np.finfo(float).tiny + + # a_I, b_J, a_Ic, b_Jc + a_I = a[Isel] + b_J = b[Jsel] + if not uniform: + a_I_min = a_I.min() + a_I_max = a_I.max() + b_J_max = b_J.max() + b_J_min = b_J.min() + else: + a_I_min = a_I[0] + a_I_max = a_I[0] + b_J_max = b_J[0] + b_J_min = b_J[0] + + # box constraints in L-BFGS-B (see Proposition 1 in [26]) + bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + + bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + + # pre-calculated constants for the objective + vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) + vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0) + + # initialisation + u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa) + v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa) + + # pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26]) + if restricted: + if ns_budget != ns or nt_budget != nt: + cst_u = kappa * epsilon * K_IJc.sum(axis=1) + cst_v = epsilon * K_IcJ.sum(axis=0) / kappa + + cpt = 1 + while cpt < 5: # 5 iterations + K_IJ_v = np.dot(K_IJ.T, u0) + cst_v + v0 = b_J / (kappa * K_IJ_v) + KIJ_u = np.dot(K_IJ, v0) + cst_u + u0 = (kappa * a_I) / KIJ_u + cpt += 1 + + u0 = projection(u0, epsilon / kappa) + v0 = projection(v0, epsilon * kappa) + + else: + u0 = u0 + v0 = v0 + + def restricted_sinkhorn(usc, vsc, max_iter=5): + """ + Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26]) + """ + cpt = 1 + while cpt < max_iter: + K_IJ_v = np.dot(K_IJ.T, usc) + cst_v + vsc = b_J / (kappa * K_IJ_v) + KIJ_u = np.dot(K_IJ, vsc) + cst_u + usc = (kappa * a_I) / KIJ_u + cpt += 1 + + usc = projection(usc, epsilon / kappa) + vsc = projection(vsc, epsilon * kappa) + + return usc, vsc + + def screened_obj(usc, vsc): + part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc)) + part_IJc = np.dot(usc, vec_eps_IJc) + part_IcJ = np.dot(vec_eps_IcJ, vsc) + psi_epsilon = part_IJ + part_IJc + part_IcJ + return psi_epsilon + + def screened_grad(usc, vsc): + # gradients of Psi_(kappa,epsilon) w.r.t u and v + grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc + grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc + return grad_u, grad_v + + def bfgspost(theta): + u = theta[:ns_budget] + v = theta[ns_budget:] + # objective + f = screened_obj(u, v) + # gradient + g_u, g_v = screened_grad(u, v) + g = np.hstack([g_u, g_v]) + return f, g + + #----------------------------------------------------------------------------------------------------------------# + # Step 2: L-BFGS-B solver # + #----------------------------------------------------------------------------------------------------------------# + + u0, v0 = restricted_sinkhorn(u0, v0) + theta0 = np.hstack([u0, v0]) + + bounds = bounds_u + bounds_v # constraint bounds + + def obj(theta): + return bfgspost(theta) + + theta, _, _ = fmin_l_bfgs_b(func=obj, + x0=theta0, + bounds=bounds, + maxfun=maxfun, + pgtol=pgtol, + maxiter=maxiter) + + usc = theta[:ns_budget] + vsc = theta[ns_budget:] + + usc_full = np.full(ns, epsilon / kappa) + vsc_full = np.full(nt, epsilon * kappa) + usc_full[Isel] = usc + vsc_full[Jsel] = vsc + + if log: + log = {} + log['u'] = usc_full + log['v'] = vsc_full + log['Isel'] = Isel + log['Jsel'] = Jsel + + gamma = usc_full[:, None] * K * vsc_full[None, :] + gamma = gamma / gamma.sum() + + if log: + return gamma, log + else: + return gamma |