# -*- coding: utf-8 -*-
"""
Bregman projections for regularized OT with GPU
"""

# Author: Remi Flamary
#         Leo Gautheron
#
# License: MIT License

import cupy as np  # np used for matrix computation
import cupy as cp  # cp used for cupy specific operations
from . import utils


def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
                   verbose=False, log=False, to_numpy=True, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0

    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_

    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain. If empty, uniform weights
        over the rows of M are used.
    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed M if b is a matrix (return OT loss + dual variables in log).
        If empty, uniform weights over the columns of M are used.
    M : np.ndarray (ns,nt)
        loss matrix. Integer or boolean matrices are promoted to float64.
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    to_numpy : bool, optional
        If True (default), the returned OT matrix / loss vector is passed
        through ``utils.to_np`` before being returned (presumably a
        device-to-host numpy conversion). Arrays stored in the log are not
        converted.
    **kwargs
        Unused; accepted and ignored.

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters (when b is a
        vector)
    loss : (nbb,) ndarray
        OT loss for each column of b (when b is a matrix)
    log : dict
        log dictionary return only if log==True in parameters. Contains
        'err' (error values recorded every 10 iterations) and the scaling
        vectors 'u' and 'v' (cupy arrays).

    Examples
    --------

    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.sinkhorn(a,b,M,1)
    array([[ 0.36552929,  0.13447071],
           [ 0.13447071,  0.36552929]])


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013


    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a = cp.asarray(a)
    b = cp.asarray(b)
    M = cp.asarray(M)

    # K is filled in place below with M's dtype; an integer/bool M cannot
    # hold the float values of -M/reg and exp(.), so promote it first.
    if M.dtype.kind in 'biu':
        M = M.astype(np.float64)

    if len(a) == 0:
        a = np.ones((M.shape[0],)) / M.shape[0]
    if len(b) == 0:
        b = np.ones((M.shape[1],)) / M.shape[1]

    # init data
    Nini = len(a)
    Nfin = len(b)

    # a 2-D b means several target histograms (one per column)
    if len(b.shape) > 1:
        nbb = b.shape[1]
    else:
        nbb = 0

    if log:
        log = {'err': []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if nbb:
        u = np.ones((Nini, nbb)) / Nini
        v = np.ones((Nfin, nbb)) / Nfin
    else:
        u = np.ones(Nini) / Nini
        v = np.ones(Nfin) / Nfin

    # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
    K = np.empty(M.shape, dtype=M.dtype)
    np.divide(M, -reg, out=K)
    np.exp(K, out=K)

    # Kp = diag(1/a) K, so the u update below is u = a / (K v)
    Kp = (1 / a).reshape(-1, 1) * K

    cpt = 0
    err = 1
    while (err > stopThr and cpt < numItermax):
        uprev = u
        vprev = v

        KtransposeU = np.dot(K.T, u)
        v = np.divide(b, KtransposeU)
        u = 1. / np.dot(Kp, v)

        if (np.any(KtransposeU == 0)
                or np.any(np.isnan(u)) or np.any(np.isnan(v))
                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errors at iteration', cpt)
            u = uprev
            v = vprev
            break
        if cpt % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            if nbb:
                # relative change of the scaling vectors
                err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
                    np.sum((v - vprev)**2) / np.sum((v)**2)
            else:
                # right marginal of diag(u) K diag(v), i.e. v * (K^T u);
                # avoids materializing the full (ns, nt) product matrix
                tmp2 = v * np.dot(K.T, u)
                err = np.linalg.norm(tmp2 - b)**2  # violation of marginal
            if log:
                log['err'].append(err)

            if verbose:
                if cpt % 200 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))
        cpt = cpt + 1
    if log:
        log['u'] = u
        log['v'] = v

    if nbb:  # return only loss
        # loss_k = sum_ij u_ik K_ij M_ij v_jk = sum_i u_ik ((K*M) v)_ik,
        # so only (ns, nt) and (ns, nbb) temporaries are needed (the
        # equivalent einsum explodes cupy memory).
        res = np.sum(u * np.dot(K * M, v), axis=0)
    else:  # return OT matrix
        res = u.reshape((-1, 1)) * K * v.reshape((1, -1))

    if to_numpy:
        res = utils.to_np(res)
    if log:
        return res, log
    else:
        return res


# define sinkhorn as sinkhorn_knopp
sinkhorn = sinkhorn_knopp