diff options
-rw-r--r-- | docs/source/conf.py | 2 | ||||
-rw-r--r-- | docs/source/index.rst | 25 | ||||
-rw-r--r-- | ot/bregman.py | 31 |
3 files changed, 49 insertions, 9 deletions
diff --git a/docs/source/conf.py b/docs/source/conf.py index 0caefae..e114e2c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -38,7 +38,7 @@ extensions = [ 'sphinx.ext.coverage', 'sphinx.ext.mathjax', 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', + 'sphinx.ext.viewcode','sphinx.ext.autodoc', 'sphinxcontrib.napoleon' ] # Add any paths that contain templates here, relative to this directory. diff --git a/docs/source/index.rst b/docs/source/index.rst index e96f544..8613bfd 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,16 +11,41 @@ Contents: .. toctree:: :maxdepth: 2 + +Module ot +========= + +This module provide easy access to solvers for the most common OT problems + .. automodule:: ot :members: + +Module ot.emd +========= .. automodule:: ot.emd :members: + +Module ot.bregman +========= + .. automodule:: ot.bregman :members: + +Module ot.utils +========= + .. automodule:: ot.utils :members: + +Module ot.datasets +========= + .. automodule:: ot.datasets :members: + +Module ot.plot +========= + .. automodule:: ot.plot :members: diff --git a/ot/bregman.py b/ot/bregman.py index 81804a7..17ec06f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -8,7 +8,9 @@ import numpy as np def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): """ - Solve the optimal transport problem (OT) + Solve the entropic regularization optimal transport problem and return the OT matrix + + The function solves the following optimization problem: .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) @@ -20,17 +22,20 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): \gamma\geq 0 where : - - M is the metric cost matrix - - Omega is the entropic regularization term - - a and b are the sample weights + - M is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - a and b are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]_ + Parameters ---------- - a : (ns,) ndarray - samples in the source domain - b : (nt,) ndarray + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) samples in the target domain - M : (ns,nt) ndarray + M : np.ndarray (ns,nt) loss matrix reg: float() Regularization term >0 @@ -41,6 +46,16 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9): gamma: (ns x nt) ndarray Optimal transportation matrix for the given parameters + References + ---------- + + .. [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + + See Also + -------- + ot.emd.emd : Unregularized optimal ransport + """ # init data Nini = len(a) |