author    ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>  2021-11-03 17:29:16 +0100
committer GitHub <noreply@github.com>  2021-11-03 17:29:16 +0100
commit    9c6ac880d426b7577918b0c77bd74b3b01930ef6 (patch)
tree      93b0899a0378a6fe8f063800091252d2c6ad9801 /ot/dr.py
parent    e1b67c641da3b3e497db6811af2c200022b10302 (diff)
[MRG] Docs updates (#298)
* bregman docs
* sliced docs
* docs partial
* unbalanced docs
* stochastic docs
* plot docs
* datasets docs
* utils docs
* dr docs
* dr docs corrected
* smooth docs
* docs da
* pep8
* docs gromov
* more space after min and argmin
* docs lp
* bregman docs
* bregman docs mistake corrected
* pep8

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/dr.py')
-rw-r--r--  ot/dr.py  40
1 file changed, 23 insertions(+), 17 deletions(-)
diff --git a/ot/dr.py b/ot/dr.py
index 7469270..c2f51f8 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -22,7 +22,7 @@ from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1, x2):
- """ Compute squared euclidean distance between samples (autograd)
+ r""" Compute squared euclidean distance between samples (autograd)
"""
x1p2 = np.sum(np.square(x1), 1)
x2p2 = np.sum(np.square(x2), 1)
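The hunk cuts off before the return statement. For context, a minimal sketch of the full computation (assuming standard NumPy broadcasting; this completes the two partial lines above via the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, and is not necessarily the file's exact code):

    import numpy as np

    def dist(x1, x2):
        # squared norms of each row of x1 and x2
        x1p2 = np.sum(np.square(x1), 1)
        x2p2 = np.sum(np.square(x2), 1)
        # pairwise squared Euclidean distances by broadcasting:
        # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
        return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * x1.dot(x2.T)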
@@ -30,7 +30,7 @@ def dist(x1, x2):
def sinkhorn(w1, w2, M, reg, k):
- """Sinkhorn algorithm with fixed number of iteration (autograd)
+ r"""Sinkhorn algorithm with fixed number of iteration (autograd)
"""
K = np.exp(-M / reg)
ui = np.ones((M.shape[0],))
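The diff shows only the Gibbs kernel and the first scaling vector. A sketch of the fixed-iteration scheme the docstring names (uniform initialization assumed, as above; the loop body is the standard Sinkhorn scaling update, not necessarily the file's exact code):

    def sinkhorn(w1, w2, M, reg, k):
        # Gibbs kernel from the cost matrix M and regularization reg
        K = np.exp(-M / reg)
        ui = np.ones((M.shape[0],))
        vi = np.ones((M.shape[1],))
        for _ in range(k):  # fixed number of iterations, no stopping criterion
            vi = w2 / np.dot(K.T, ui)
            ui = w1 / np.dot(K, vi)
        # transport plan G = diag(ui) K diag(vi)
        return ui.reshape((-1, 1)) * K * vi.reshape((1, -1))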
@@ -43,14 +43,14 @@ def sinkhorn(w1, w2, M, reg, k):
def split_classes(X, y):
- """split samples in X by classes in y
+ r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
"""
lstsclass = np.unique(y)
return [X[y == i, :].astype(np.float32) for i in lstsclass]
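A quick illustration of the behaviour on made-up data:

    X = np.array([[0., 0.], [1., 1.], [2., 2.]])
    y = np.array([0, 1, 0])
    parts = split_classes(X, y)
    # parts[0] -> samples of class 0: [[0., 0.], [2., 2.]]
    # parts[1] -> samples of class 1: [[1., 1.]]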
def fda(X, y, p=2, reg=1e-16):
- """Fisher Discriminant Analysis
+ r"""Fisher Discriminant Analysis
Parameters
----------
@@ -111,18 +111,19 @@ def fda(X, y, p=2, reg=1e-16):
def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
r"""
- Wasserstein Discriminant Analysis [11]_
+ Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
The function solves the following optimization problem:
.. math::
- P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
+ \mathbf{P} = \mathop{\arg \min}_\mathbf{P} \quad
+ \frac{\sum\limits_i W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j \neq i} W(P \mathbf{X}^i, P \mathbf{X}^j)}
where :
- - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
+ - :math:`P` is a linear projection operator in the Stiefel(`p`, `d`) manifold
- :math:`W` is the entropic regularized Wasserstein distance
- - :math:`X^i` are samples in the dataset corresponding to class i
+ - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
Parameters
----------
@@ -140,7 +141,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
P0 : ndarray, shape (d, p)
Initial starting point for projection.
normalize : bool, optional
- Normalise the Wasserstaiun distane by the average distance on P0 (default : False)
+ Normalise the Wasserstein distance by the average distance on P0 (default: False)
verbose : int, optional
Print information along iterations.
@@ -151,6 +152,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
proj : callable
Projection function including mean centering.
+
+ .. _references-wda:
References
----------
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
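A hypothetical usage sketch of wda on random data (the projection-matrix return value alongside proj is an assumption, since only proj appears in the visible Returns section; the optional pymanopt/autograd dependencies imported at the top of the module are required):

    import numpy as np
    from ot.dr import wda

    rng = np.random.RandomState(0)
    X = rng.randn(100, 10)          # 100 samples in dimension d = 10
    y = rng.randint(0, 2, 100)      # two classes
    Pwda, projwda = wda(X, y, p=2, reg=1.0, k=10, maxiter=50)
    Xp = projwda(X)                 # mean-centered samples projected to p dimensions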
@@ -217,27 +220,28 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
r"""
- Projection Robust Wasserstein Distance [32]
+ Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`
The function solves the following optimization problem:
.. math::
- \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi)
+ \max_{U \in St(d, k)} \ \min_{\pi \in \Pi(\mu,\nu)} \quad \sum_{i,j} \pi_{i,j}
+ \|U^T(\mathbf{x}_i - \mathbf{y}_j)\|^2 - \mathrm{reg} \cdot H(\pi)
- - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
+ - :math:`U` is a linear projection operator in the Stiefel(`d`, `k`) manifold
- :math:`H(\pi)` is the entropy regularizer
- - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively
+ - :math:`\mathbf{x}_i`, :math:`\mathbf{y}_j` are samples of measures :math:`\mu` and :math:`\nu` respectively
Parameters
----------
X : ndarray, shape (n, d)
- Samples from measure \mu
+ Samples from measure :math:`\mu`
Y : ndarray, shape (n, d)
- Samples from measure \nu
+ Samples from measure :math:`\nu`
a : ndarray, shape (n, )
- weights for measure \mu
+ weights for measure :math:`\mu`
b : ndarray, shape (n, )
- weights for measure \nu
+ weights for measure :math:`\nu`
tau : float
stepsize for Riemannian Gradient Descent
U0 : ndarray, shape (d, p)
@@ -258,6 +262,8 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
U : ndarray, shape (d, k)
Projection operator.
+
+ .. _references-projection-robust-wasserstein:
References
----------
.. [32] Huang, M., Ma, S. & Lai, L. (2021).
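A hedged usage sketch on synthetic measures (returning the optimal plan pi before U is an assumption, since only U appears in the visible Returns section):

    import numpy as np
    from ot.dr import projection_robust_wasserstein

    rng = np.random.RandomState(0)
    n, d = 50, 10
    X = rng.randn(n, d)             # samples from mu
    Y = rng.randn(n, d)             # samples from nu
    a = np.ones(n) / n              # uniform weights on mu
    b = np.ones(n) / n              # uniform weights on nu
    pi, U = projection_robust_wasserstein(X, Y, a, b, tau=0.002, reg=0.2, k=2)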