# -*- coding: utf-8 -*- """ Bregman projections for regularized OT with GPU """ # Author: Remi Flamary # Leo Gautheron # # License: MIT License import cupy as np # np used for matrix computation import cupy as cp # cp used for cupy specific operations from . import utils def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, to_numpy=True, **kwargs): """ Solve the entropic regularization optimal transport on GPU If the input matrix are in numpy format, they will be uploaded to the GPU first which can incur significant time overhead. The function solves the following optimization problem: .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) s.t. \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - M is the (ns,nt) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ Parameters ---------- a : np.ndarray (ns,) samples weights in the source domain b : np.ndarray (nt,) or np.ndarray (nt,nbb) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) M : np.ndarray (ns,nt) loss matrix reg : float Regularization term >0 numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshol on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True to_numpy : boolean, optional (default True) If true convert back the GPU array result to numpy format. Returns ------- gamma : (ns x nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters References ---------- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT """ a = cp.asarray(a) b = cp.asarray(b) M = cp.asarray(M) if len(a) == 0: a = np.ones((M.shape[0],)) / M.shape[0] if len(b) == 0: b = np.ones((M.shape[1],)) / M.shape[1] # init data Nini = len(a) Nfin = len(b) if len(b.shape) > 1: nbb = b.shape[1] else: nbb = 0 if log: log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances if nbb: u = np.ones((Nini, nbb)) / Nini v = np.ones((Nfin, nbb)) / Nfin else: u = np.ones(Nini) / Nini v = np.ones(Nfin) / Nfin # print(reg) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) np.exp(K, out=K) # print(np.min(K)) tmp2 = np.empty(b.shape, dtype=M.dtype) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 err = 1 while (err > stopThr and cpt < numItermax): uprev = u vprev = v KtransposeU = np.dot(K.T, u) v = np.divide(b, KtransposeU) u = 1. / np.dot(Kp, v) if (np.any(KtransposeU == 0) or np.any(np.isnan(u)) or np.any(np.isnan(v)) or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) u = uprev v = vprev break if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations if nbb: err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ np.sum((v - vprev)**2) / np.sum((v)**2) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 tmp2 = np.sum(u[:, None] * K * v[None, :], 0) #tmp2=np.einsum('i,ij,j->j', u, K, v) err = np.linalg.norm(tmp2 - b)**2 # violation of marginal if log: log['err'].append(err) if verbose: if cpt % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) cpt = cpt + 1 if log: log['u'] = u log['v'] = v if nbb: # return only loss #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory) res = np.empty(nbb) for i in range(nbb): res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i]) if to_numpy: res = utils.to_np(res) if log: return res, log else: return res else: # return OT matrix res = u.reshape((-1, 1)) * K * v.reshape((1, -1)) if to_numpy: res = utils.to_np(res) if log: return res, log else: return res # define sinkhorn as sinkhorn_knopp sinkhorn = sinkhorn_knopp