summaryrefslogtreecommitdiff
path: root/ot/gpu
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-06-09 13:11:41 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-06-09 13:11:41 +0200
commit37977810ab538fe461ef8f3ec16434c59f6f7c5a (patch)
tree873664806122c6b8554c9b5f055adba2d3014591 /ot/gpu
parent55ff888aab4e3a5f0a4a3180887d40e96fbea3d8 (diff)
add doc ot.gpu.bregman
Diffstat (limited to 'ot/gpu')
-rw-r--r--ot/gpu/bregman.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index cc610b7..2c3e317 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -9,6 +9,80 @@ import cudamat
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
log=False, returnAsGPU=False):
+ """
+ Solve the entropic regularization optimal transport problem on GPU
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
+ samples in the target domain
+ M_GPU : cudamat.CUDAMatrix (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ returnAsGPU : bool, optional
+ return the OT matrix as a cudamat.CUDAMatrix
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
# init data
Nini = len(a)
Nfin = len(b)