summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py31
1 files changed, 23 insertions, 8 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 81804a7..17ec06f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -8,7 +8,9 @@ import numpy as np
def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
"""
- Solve the optimal transport problem (OT)
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The function solves the following optimization problem:
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
@@ -20,17 +22,20 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
\gamma\geq 0
where :
- - M is the metric cost matrix
- - Omega is the entropic regularization term
- - a and b are the sample weights
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]_
+
Parameters
----------
- a : (ns,) ndarray
- samples in the source domain
- b : (nt,) ndarray
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,)
samples in the target domain
- M : (ns,nt) ndarray
+ M : np.ndarray (ns,nt)
loss matrix
reg: float()
Regularization term >0
@@ -41,6 +46,16 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
gamma: (ns x nt) ndarray
Optimal transportation matrix for the given parameters
+ References
+ ----------
+
+ .. [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+
+ See Also
+ --------
+ ot.emd.emd : Unregularized optimal ransport
+
"""
# init data
Nini = len(a)