Diffstat (limited to 'ot/bregman.py')
-rw-r--r--	ot/bregman.py	570
1 file changed, 542 insertions(+), 28 deletions(-)
diff --git a/ot/bregman.py b/ot/bregman.py
index 2cd832b..f1f8437 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Bregman projections for regularized OT
+Bregman projections solvers for entropic regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -8,12 +8,16 @@ Bregman projections for regularized OT
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
# Hicham Janati <hicham.janati@inria.fr>
+# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+# Alexander Tong <alexander.tong@yale.edu>
+# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
#
# License: MIT License
import numpy as np
import warnings
from .utils import unif, dist
+from scipy.optimize import fmin_l_bfgs_b
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -536,12 +540,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
old_v = v[i_2]
v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
G[:, i_2] = u * K[:, i_2] * v[i_2]
- #aviol = (G@one_m - a)
- #aviol_2 = (G.T@one_n - b)
+ # aviol = (G@one_m - a)
+ # aviol_2 = (G.T@one_n - b)
viol += (-old_v + v[i_2]) * K[:, i_2] * u
viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
- #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+ # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
if stopThr_val <= stopThr:
break
@@ -905,11 +909,6 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
else:
alpha, beta = warmstart
- def get_K(alpha, beta):
- """log space computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1))
- - beta.reshape((1, dim_b))) / reg)
-
# print(np.min(K))
def get_reg(n): # exponential decreasing
return (epsilon0 - reg) * np.exp(-n) + reg
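A quick illustration (not part of the diff) of the epsilon-scaling schedule kept above: with the default epsilon0=1e4 and reg=1e-1, the regularization decays exponentially toward reg.

    import numpy as np

    epsilon0, reg = 1e4, 1e-1

    def get_reg(n):  # exponential decrease from epsilon0 toward reg
        return (epsilon0 - reg) * np.exp(-n) + reg

    print([round(get_reg(n)) for n in range(5)])
    # approximately [10000, 3679, 1353, 498, 183]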
@@ -937,7 +936,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# the 10th iterations
transp = G
err = np.linalg.norm(
- (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2
+ (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
if log:
log['err'].append(err)
@@ -963,7 +962,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def geometricBar(weights, alldistribT):
"""return the weighted geometric mean of distributions"""
- assert(len(weights) == alldistribT.shape[1])
+ assert (len(weights) == alldistribT.shape[1])
return np.exp(np.dot(np.log(alldistribT), weights.T))
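A small numeric check (illustrative, not part of the diff) of the weighted geometric mean computed by geometricBar:

    import numpy as np

    alldistribT = np.array([[0.2, 0.5], [0.8, 0.5]])  # two distributions stored as columns
    weights = np.array([0.5, 0.5])
    gbar = np.exp(np.dot(np.log(alldistribT), weights.T))
    # entrywise sqrt(0.2 * 0.5) and sqrt(0.8 * 0.5)
    print(gbar)  # approximately [0.3162, 0.6325]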
@@ -1037,11 +1036,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
"""
if method.lower() == 'sinkhorn':
- return barycenter_sinkhorn(A, M, reg, numItermax=numItermax,
+ return barycenter_sinkhorn(A, M, reg, weights=weights,
+ numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- return barycenter_stabilized(A, M, reg, numItermax=numItermax,
+ return barycenter_stabilized(A, M, reg, weights=weights,
+ numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
else:
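A usage sketch (illustrative, not part of the diff) for the fix above, which now forwards user-supplied weights to both barycenter solvers; it assumes POT is installed with these functions exposed under ot.bregman and ot.utils:

    import numpy as np
    import ot

    A = np.array([[0.6, 0.1], [0.3, 0.3], [0.1, 0.6]])  # two histograms on 3 bins, as columns
    M = ot.utils.dist0(3)  # squared euclidean cost on a regular grid
    M /= M.max()
    bary = ot.bregman.barycenter(A, M, reg=1e-2, weights=np.array([0.3, 0.7]))
    print(bary.sum())  # approximately 1.0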
@@ -1103,7 +1104,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
if weights is None:
weights = np.ones(A.shape[1]) / A.shape[1]
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1201,7 +1202,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
if weights is None:
weights = np.ones(n_hists) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1329,7 +1330,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if weights is None:
weights = np.ones(A.shape[0]) / A.shape[0]
else:
- assert(len(weights) == A.shape[0])
+ assert (len(weights) == A.shape[0])
if log:
log = {'err': []}
@@ -1342,12 +1343,17 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
err = 1
# build the convolution operator
+    # this is equivalent to blurring along the horizontal then the vertical direction
t = np.linspace(0, 1, A.shape[1])
[Y, X] = np.meshgrid(t, t)
- xi1 = np.exp(-(X - Y)**2 / reg)
+ xi1 = np.exp(-(X - Y) ** 2 / reg)
+
+ t = np.linspace(0, 1, A.shape[2])
+ [Y, X] = np.meshgrid(t, t)
+ xi2 = np.exp(-(X - Y) ** 2 / reg)
def K(x):
- return np.dot(np.dot(xi1, x), xi1)
+ return np.dot(np.dot(xi1, x), xi2)
while (err > stopThr and cpt < numItermax):
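An illustrative check (not part of the diff): with the second kernel xi2 introduced above, K(x) applies a separable Gaussian blur and now handles non-square images of shape (w, h).

    import numpy as np

    w, h, reg = 32, 48, 1e-2
    t = np.linspace(0, 1, w)
    [Y, X] = np.meshgrid(t, t)
    xi1 = np.exp(-(X - Y) ** 2 / reg)  # (w, w) blur along rows

    t = np.linspace(0, 1, h)
    [Y, X] = np.meshgrid(t, t)
    xi2 = np.exp(-(X - Y) ** 2 / reg)  # (h, h) blur along columns

    x = np.random.rand(w, h)
    print(np.dot(np.dot(xi1, x), xi2).shape)  # (32, 48)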
@@ -1492,6 +1498,164 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1)
+def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
+ r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
+
+ The function solves the following optimization problem:
+
+ .. math::
+
+       \mathbf{h} = \arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+                    W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
+
+       s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n = \mathbf{h}
+
+ where :
+
+ - :math:`\lambda_k` is the weight of k-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
+ - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
+
+    The problem consists in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+
+ The algorithm used for solving the problem is the Iterative Bregman projections algorithm
+ with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
+
+ Parameters
+ ----------
+ Xs : list of K np.ndarray(nsk,d)
+ features of all source domains' samples
+ Ys : list of K np.ndarray(nsk,)
+ labels of all source domains' samples
+ Xt : np.ndarray (nt,d)
+ samples in the target domain
+ reg : float
+ Regularization term > 0
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on relative change in the barycenter (>0)
+ log : bool, optional
+ record log if True
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+
+ Returns
+ -------
+ h : (C,) ndarray
+ proportion estimation in the target domain
+ log : dict
+        log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
+
+ '''
+ nbclasses = len(np.unique(Ys[0]))
+ nbdomains = len(Xs)
+
+ # log dictionary
+ if log:
+ log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []}
+
+ K = []
+ M = []
+ D1 = []
+ D2 = []
+
+ # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
+ for d in range(nbdomains):
+ dom = {}
+ nsk = Xs[d].shape[0] # get number of elements for this domain
+ dom['nbelem'] = nsk
+        classes = np.unique(Ys[d])  # classes present in this domain
+
+ # format classes to start from 0 for convenience
+ if np.min(classes) != 0:
+ Ys[d] = Ys[d] - np.min(classes)
+ classes = np.unique(Ys[d])
+
+ # build the corresponding D_1 and D_2 matrices
+ Dtmp1 = np.zeros((nbclasses, nsk))
+ Dtmp2 = np.zeros((nbclasses, nsk))
+
+ for c in classes:
+ nbelemperclass = np.sum(Ys[d] == c)
+ if nbelemperclass != 0:
+ Dtmp1[int(c), Ys[d] == c] = 1.
+ Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
+ D1.append(Dtmp1)
+ D2.append(Dtmp2)
+
+ # build the cost matrix and the Gibbs kernel
+ Mtmp = dist(Xs[d], Xt, metric=metric)
+ M.append(Mtmp)
+
+ Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
+ np.divide(Mtmp, -reg, out=Ktmp)
+ np.exp(Ktmp, out=Ktmp)
+ K.append(Ktmp)
+
+ # uniform target distribution
+ a = unif(np.shape(Xt)[0])
+
+ cpt = 0 # iterations count
+ err = 1
+ old_bary = np.ones((nbclasses))
+
+ while (err > stopThr and cpt < numItermax):
+
+ bary = np.zeros((nbclasses))
+
+ # update coupling matrices for marginal constraints w.r.t. uniform target distribution
+ for d in range(nbdomains):
+ K[d] = projC(K[d], a)
+ other = np.sum(K[d], axis=1)
+ bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
+
+ bary = np.exp(bary)
+
+ # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
+ for d in range(nbdomains):
+ new = np.dot(D2[d].T, bary)
+ K[d] = projR(K[d], new)
+
+ err = np.linalg.norm(bary - old_bary)
+ cpt = cpt + 1
+ old_bary = bary
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ bary = bary / np.sum(bary)
+
+ if log:
+ log['niter'] = cpt
+ log['M'] = M
+ log['D1'] = D1
+ log['D2'] = D2
+ log['gamma'] = K
+ return bary, log
+ else:
+ return bary
+
+
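A usage sketch for the jcpot_barycenter function added above (illustrative, not part of the diff), assuming it is exposed as ot.bregman.jcpot_barycenter:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    # two labeled source domains and one unlabeled target domain
    Xs = [rng.randn(20, 2), rng.randn(30, 2) + 1]
    Ys = [rng.randint(0, 2, 20), rng.randint(0, 2, 30)]
    Xt = rng.randn(25, 2)

    h = ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg=1.0)
    print(h)  # estimated class proportions in the target domain, sums to 1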
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, verbose=False,
log=False, **kwargs):
@@ -1583,7 +1747,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return pi
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1665,14 +1830,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
M = dist(X_s, X_t, metric=metric)
if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
return sinkhorn_loss
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -1758,11 +1926,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
'''
if log:
- sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose, log=log, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -1777,11 +1948,354 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
return max(0, sinkhorn_div), log
else:
- sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
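A usage sketch (illustrative, not part of the diff) for the divergence computed above, assuming POT exposes it as ot.bregman.empirical_sinkhorn_divergence:

    import numpy as np
    import ot

    rng = np.random.RandomState(1)
    X_s = rng.randn(50, 2)
    X_t = rng.randn(60, 2) + 0.5
    d = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)
    print(d)  # nonnegative by construction (clipped at 0)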
+
+
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
+ maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
+    r"""
+    Screening Sinkhorn Algorithm for Regularized Optimal Transport
+
+    The function solves an approximate dual of the Sinkhorn divergence [2], written as the following optimization problem:
+
+    .. math::
+
+        (u, v) = \arg\min_{u, v}\ \mathbf{1}_{ns}^T B(u, v) \mathbf{1}_{nt} - \langle \kappa u, a \rangle - \langle v / \kappa, b \rangle
+
+    where :math:`B(u, v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v)`, with :math:`K = e^{-M / reg}`,
+
+    s.t. :math:`e^{u_i} \geq \epsilon / \kappa` for all :math:`i \in \{1, \ldots, ns\}`
+
+    and :math:`e^{v_j} \geq \epsilon \kappa` for all :math:`j \in \{1, \ldots, nt\}`
+
+    The parameters :math:`\kappa` and :math:`\epsilon` are determined w.r.t. the couple of budget numbers of points (ns_budget, nt_budget); see Equation (5) in [26]
+
+
+ Parameters
+ ----------
+ a : `numpy.ndarray`, shape=(ns,)
+ samples weights in the source domain
+
+ b : `numpy.ndarray`, shape=(nt,)
+ samples weights in the target domain
+
+ M : `numpy.ndarray`, shape=(ns, nt)
+ Cost matrix
+
+ reg : `float`
+ Level of the entropy regularisation
+
+    ns_budget : `int`, default=None
+        Budget number of points to be kept in the source domain
+        If it is None then 50% of the source sample points will be kept
+
+    nt_budget : `int`, default=None
+        Budget number of points to be kept in the target domain
+        If it is None then 50% of the target sample points will be kept
+
+ uniform : `bool`, default=False
+        If `True`, the source and target distributions are assumed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
+
+ restricted : `bool`, default=True
+        If `True`, a warm-start initialization for the L-BFGS-B solver is computed
+        using a restricted Sinkhorn algorithm with at most 5 iterations
+
+ maxiter : `int`, default=10000
+ Maximum number of iterations in LBFGS solver
+
+ maxfun : `int`, default=10000
+ Maximum number of function evaluations in LBFGS solver
+
+ pgtol : `float`, default=1e-09
+ Final objective function accuracy in LBFGS solver
+
+ verbose : `bool`, default=False
+        If `True`, display information about the cardinalities of the active sets and the parameters kappa
+ and epsilon
+
+ Dependency
+ ----------
+    To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
+    in the screening pre-processing step. If Bottleneck isn't installed, a warning is raised and NumPy is used
+    as a (slower) fallback.
+
+
+ Returns
+ -------
+ gamma : `numpy.ndarray`, shape=(ns, nt)
+ Screened optimal transportation matrix for the given parameters
+
+ log : `dict`, default=False
+        Log dictionary returned only if log==True in parameters
+
+
+ References
+ -----------
+    .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A., "Screening Sinkhorn Algorithm for Regularized Optimal Transport", Advances in Neural Information Processing Systems (NeurIPS), 2019
+
+ """
+ # check if bottleneck module exists
+ try:
+ import bottleneck
+ except ImportError:
+ warnings.warn(
+ "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+ bottleneck = np
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+ ns, nt = M.shape
+
+ # by default, we keep only 50% of the sample data points
+ if ns_budget is None:
+ ns_budget = int(np.floor(0.5 * ns))
+ if nt_budget is None:
+ nt_budget = int(np.floor(0.5 * nt))
+
+ # calculate the Gibbs kernel
+ K = np.empty_like(M)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ def projection(u, epsilon):
+ u[u <= epsilon] = epsilon
+ return u
+
+ # ----------------------------------------------------------------------------------------------------------------#
+ # Step 1: Screening pre-processing #
+ # ----------------------------------------------------------------------------------------------------------------#
+
+ if ns_budget == ns and nt_budget == nt:
+ # full number of budget points (ns, nt) = (ns_budget, nt_budget)
+ Isel = np.ones(ns, dtype=bool)
+ Jsel = np.ones(nt, dtype=bool)
+ epsilon = 0.0
+ kappa = 1.0
+
+ cst_u = 0.
+ cst_v = 0.
+
+ bounds_u = [(0.0, np.inf)] * ns
+ bounds_v = [(0.0, np.inf)] * nt
+
+ a_I = a
+ b_J = b
+ K_IJ = K
+ K_IJc = []
+ K_IcJ = []
+
+ vec_eps_IJc = np.zeros(nt)
+ vec_eps_IcJ = np.zeros(ns)
+
+ else:
+ # sum of rows and columns of K
+ K_sum_cols = K.sum(axis=1)
+ K_sum_rows = K.sum(axis=0)
+
+ if uniform:
+ if ns / ns_budget < 4:
+ aK_sort = np.sort(K_sum_cols)
+ epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
+ else:
+ aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+ epsilon_u_square = a[0] / aK_sort
+
+ if nt / nt_budget < 4:
+ bK_sort = np.sort(K_sum_rows)
+ epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
+ else:
+ bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+ epsilon_v_square = b[0] / bK_sort
+ else:
+ aK = a / K_sum_cols
+ bK = b / K_sum_rows
+
+ aK_sort = np.sort(aK)[::-1]
+ epsilon_u_square = aK_sort[ns_budget - 1]
+
+ bK_sort = np.sort(bK)[::-1]
+ epsilon_v_square = bK_sort[nt_budget - 1]
+
+ # active sets I and J (see Lemma 1 in [26])
+ Isel = a >= epsilon_u_square * K_sum_cols
+ Jsel = b >= epsilon_v_square * K_sum_rows
+
+ if sum(Isel) != ns_budget:
+ if uniform:
+ aK = a / K_sum_cols
+ aK_sort = np.sort(aK)[::-1]
+ epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+ Isel = a >= epsilon_u_square * K_sum_cols
+ ns_budget = sum(Isel)
+
+ if sum(Jsel) != nt_budget:
+ if uniform:
+ bK = b / K_sum_rows
+ bK_sort = np.sort(bK)[::-1]
+ epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ Jsel = b >= epsilon_v_square * K_sum_rows
+ nt_budget = sum(Jsel)
+
+ epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
+ kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
+
+ if verbose:
+ print("epsilon = %s\n" % epsilon)
+ print("kappa = %s\n" % kappa)
+ print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+
+ # Ic, Jc: complementary of the active sets I and J
+ Ic = ~Isel
+ Jc = ~Jsel
+
+ K_IJ = K[np.ix_(Isel, Jsel)]
+ K_IcJ = K[np.ix_(Ic, Jsel)]
+ K_IJc = K[np.ix_(Isel, Jc)]
+
+ K_min = K_IJ.min()
+ if K_min == 0:
+ K_min = np.finfo(float).tiny
+
+ # a_I, b_J, a_Ic, b_Jc
+ a_I = a[Isel]
+ b_J = b[Jsel]
+ if not uniform:
+ a_I_min = a_I.min()
+ a_I_max = a_I.max()
+ b_J_max = b_J.max()
+ b_J_min = b_J.min()
+ else:
+ a_I_min = a_I[0]
+ a_I_max = a_I[0]
+ b_J_max = b_J[0]
+ b_J_min = b_J[0]
+
+ # box constraints in L-BFGS-B (see Proposition 1 in [26])
+ bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+
+ bounds_v = [(
+ max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+
+ # pre-calculated constants for the objective
+ vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
+ vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0)
+
+ # initialisation
+ u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
+ v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+
+    # pre-calculated constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26])
+ if restricted:
+ if ns_budget != ns or nt_budget != nt:
+ cst_u = kappa * epsilon * K_IJc.sum(axis=1)
+ cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+
+ cpt = 1
+ while cpt < 5: # 5 iterations
+ K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+ v0 = b_J / (kappa * K_IJ_v)
+ KIJ_u = np.dot(K_IJ, v0) + cst_u
+ u0 = (kappa * a_I) / KIJ_u
+ cpt += 1
+
+ u0 = projection(u0, epsilon / kappa)
+ v0 = projection(v0, epsilon * kappa)
+
+    # else: u0 and v0 keep their default initialisation
+
+ def restricted_sinkhorn(usc, vsc, max_iter=5):
+ """
+ Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
+ """
+ cpt = 1
+ while cpt < max_iter:
+ K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+ vsc = b_J / (kappa * K_IJ_v)
+ KIJ_u = np.dot(K_IJ, vsc) + cst_u
+ usc = (kappa * a_I) / KIJ_u
+ cpt += 1
+
+ usc = projection(usc, epsilon / kappa)
+ vsc = projection(vsc, epsilon * kappa)
+
+ return usc, vsc
+
+ def screened_obj(usc, vsc):
+ part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
+ np.log(vsc))
+ part_IJc = np.dot(usc, vec_eps_IJc)
+ part_IcJ = np.dot(vec_eps_IcJ, vsc)
+ psi_epsilon = part_IJ + part_IJc + part_IcJ
+ return psi_epsilon
+
+ def screened_grad(usc, vsc):
+ # gradients of Psi_(kappa,epsilon) w.r.t u and v
+ grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+ grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+ return grad_u, grad_v
+
+ def bfgspost(theta):
+ u = theta[:ns_budget]
+ v = theta[ns_budget:]
+ # objective
+ f = screened_obj(u, v)
+ # gradient
+ g_u, g_v = screened_grad(u, v)
+ g = np.hstack([g_u, g_v])
+ return f, g
+
+ # ----------------------------------------------------------------------------------------------------------------#
+ # Step 2: L-BFGS-B solver #
+ # ----------------------------------------------------------------------------------------------------------------#
+
+ u0, v0 = restricted_sinkhorn(u0, v0)
+ theta0 = np.hstack([u0, v0])
+
+ bounds = bounds_u + bounds_v # constraint bounds
+
+    theta, _, _ = fmin_l_bfgs_b(func=bfgspost,
+ x0=theta0,
+ bounds=bounds,
+ maxfun=maxfun,
+ pgtol=pgtol,
+ maxiter=maxiter)
+
+ usc = theta[:ns_budget]
+ vsc = theta[ns_budget:]
+
+ usc_full = np.full(ns, epsilon / kappa)
+ vsc_full = np.full(nt, epsilon * kappa)
+ usc_full[Isel] = usc
+ vsc_full[Jsel] = vsc
+
+ if log:
+ log = {}
+ log['u'] = usc_full
+ log['v'] = vsc_full
+ log['Isel'] = Isel
+ log['Jsel'] = Jsel
+
+ gamma = usc_full[:, None] * K * vsc_full[None, :]
+ gamma = gamma / gamma.sum()
+
+ if log:
+ return gamma, log
+ else:
+ return gamma
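A usage sketch for screenkhorn (illustrative, not part of the diff); the returned plan lives on the full (ns, nt) support even though only the budgeted points are active. It assumes POT exposes the function as ot.bregman.screenkhorn:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    ns, nt = 100, 120
    a, b = ot.unif(ns), ot.unif(nt)
    M = ot.dist(rng.randn(ns, 2), rng.randn(nt, 2))
    M /= M.max()

    # keep roughly half of the points on each side
    G = ot.bregman.screenkhorn(a, b, M, reg=1e-1, ns_budget=50, nt_budget=60)
    print(G.shape)  # (100, 120)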