From 7b9f4e9d4b198183929c60eceec1a419b5212e2d Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 25 Oct 2016 10:24:55 +0200 Subject: set numpy doc format --- ot/bregman.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 81804a7..17ec06f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -8,7 +8,9 @@ import numpy as np def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): """ - Solve the optimal transport problem (OT) + Solve the entropic regularization optimal transport problem and return the OT matrix + + The function solves the following optimization problem: .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) @@ -20,17 +22,20 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): \gamma\geq 0 where : - - M is the metric cost matrix - - Omega is the entropic regularization term - - a and b are the sample weights + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]_ + Parameters ---------- - a : (ns,) ndarray - samples in the source domain - b : (nt,) ndarray + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) samples in the target domain - M : (ns,nt) ndarray + M : np.ndarray (ns,nt) loss matrix reg: float() Regularization term >0 @@ -41,6 +46,16 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + References + ---------- + + .. [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + + See Also + -------- + ot.emd.emd : Unregularized optimal ransport + """ # init data Nini = len(a) -- cgit v1.2.3