From 37977810ab538fe461ef8f3ec16434c59f6f7c5a Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 9 Jun 2017 13:11:41 +0200 Subject: add doc ot.gpu.bregman --- ot/gpu/bregman.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) (limited to 'ot/gpu/bregman.py') diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py index cc610b7..2c3e317 100644 --- a/ot/gpu/bregman.py +++ b/ot/gpu/bregman.py @@ -9,6 +9,80 @@ import cudamat def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, returnAsGPU=False): + """ + Solve the entropic regularization optimal transport problem on GPU + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples in the target domain + M_GPU : cudamat.CUDAMatrix (ns,nt) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + returnAsGPU : bool, optional + return the OT matrix as a cudamat.CUDAMatrix + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.sinkhorn(a,b,M,1) + array([[ 0.36552929, 0.13447071], + [ 0.13447071, 0.36552929]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ # init data Nini = len(a) Nfin = len(b) -- cgit v1.2.3