diff options
author | ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> | 2021-11-03 17:29:16 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-03 17:29:16 +0100 |
commit | 9c6ac880d426b7577918b0c77bd74b3b01930ef6 (patch) | |
tree | 93b0899a0378a6fe8f063800091252d2c6ad9801 /ot/dr.py | |
parent | e1b67c641da3b3e497db6811af2c200022b10302 (diff) |
[MRG] Docs updates (#298)
* bregman docs
* sliced docs
* docs partial
* unbalanced docs
* stochastic docs
* plot docs
* datasets docs
* utils docs
* dr docs
* dr docs corrected
* smooth docs
* docs da
* pep8
* docs gromov
* more space after min and argmin
* docs lp
* bregman docs
* bregman docs mistake corrected
* pep8
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/dr.py')
-rw-r--r-- | ot/dr.py | 40 |
1 files changed, 23 insertions, 17 deletions
@@ -22,7 +22,7 @@ from pymanopt.solvers import SteepestDescent, TrustRegions def dist(x1, x2): - """ Compute squared euclidean distance between samples (autograd) + r""" Compute squared euclidean distance between samples (autograd) """ x1p2 = np.sum(np.square(x1), 1) x2p2 = np.sum(np.square(x2), 1) @@ -30,7 +30,7 @@ def dist(x1, x2): def sinkhorn(w1, w2, M, reg, k): - """Sinkhorn algorithm with fixed number of iteration (autograd) + r"""Sinkhorn algorithm with fixed number of iteration (autograd) """ K = np.exp(-M / reg) ui = np.ones((M.shape[0],)) @@ -43,14 +43,14 @@ def sinkhorn(w1, w2, M, reg, k): def split_classes(X, y): - """split samples in X by classes in y + r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}` """ lstsclass = np.unique(y) return [X[y == i, :].astype(np.float32) for i in lstsclass] def fda(X, y, p=2, reg=1e-16): - """Fisher Discriminant Analysis + r"""Fisher Discriminant Analysis Parameters ---------- @@ -111,18 +111,19 @@ def fda(X, y, p=2, reg=1e-16): def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): r""" - Wasserstein Discriminant Analysis [11]_ + Wasserstein Discriminant Analysis :ref:`[11] <references-wda>` The function solves the following optimization problem: .. math:: - P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)} + \mathbf{P} = \mathop{\arg \min}_\mathbf{P} \quad + \frac{\sum\limits_i W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j \neq i} W(P \mathbf{X}^i, P \mathbf{X}^j)} where : - - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold + - :math:`P` is a linear projection operator in the Stiefel(`p`, `d`) manifold - :math:`W` is entropic regularized Wasserstein distances - - :math:`X^i` are samples in the dataset corresponding to class i + - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i Parameters ---------- @@ -140,7 +141,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no P0 : ndarray, shape (d, p) Initial starting point for projection. normalize : bool, optional - Normalise the Wasserstaiun distane by the average distance on P0 (default : False) + Normalise the Wasserstaiun distance by the average distance on P0 (default : False) verbose : int, optional Print information along iterations. @@ -151,6 +152,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no proj : callable Projection function including mean centering. + + .. _references-wda: References ---------- .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). @@ -217,27 +220,28 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): r""" - Projection Robust Wasserstein Distance [32] + Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>` The function solves the following optimization problem: .. math:: - \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) + \max_{U \in St(d, k)} \ \min_{\pi \in \Pi(\mu,\nu)} \quad \sum_{i,j} \pi_{i,j} + \|U^T(\mathbf{x}_i - \mathbf{y}_j)\|^2 - \mathrm{reg} \cdot H(\pi) - - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold + - :math:`U` is a linear projection operator in the Stiefel(`d`, `k`) manifold - :math:`H(\pi)` is entropy regularizer - - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively + - :math:`\mathbf{x}_i`, :math:`\mathbf{y}_j` are samples of measures :math:`\mu` and :math:`\nu` respectively Parameters ---------- X : ndarray, shape (n, d) - Samples from measure \mu + Samples from measure :math:`\mu` Y : ndarray, shape (n, d) - Samples from measure \nu + Samples from measure :math:`\nu` a : ndarray, shape (n, ) - weights for measure \mu + weights for measure :math:`\mu` b : ndarray, shape (n, ) - weights for measure \nu + weights for measure :math:`\nu` tau : float stepsize for Riemannian Gradient Descent U0 : ndarray, shape (d, p) @@ -258,6 +262,8 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh U : ndarray, shape (d, k) Projection operator. + + .. _references-projection-robust-wasserstein: References ---------- .. [32] Huang, M. , Ma S. & Lai L. (2021). |