summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2019-07-09 16:48:36 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2019-07-09 16:48:36 +0200
commit1b00740a39f90f1e0bc7dc3a35723560c9ab4e97 (patch)
tree87f43f8f13cf82497999b7f7ea6b6e25fbddcb68 /ot/dr.py
parent952503e02b1fc9bdf0811b937baacca57e4a98f1 (diff)
first pass with adding pydocstyle in makefile
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py57
1 files changed, 24 insertions, 33 deletions
diff --git a/ot/dr.py b/ot/dr.py
index d2bf6e2..680dabf 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -49,30 +49,25 @@ def split_classes(X, y):
def fda(X, y, p=2, reg=1e-16):
- """
- Fisher Discriminant Analysis
-
+ """Fisher Discriminant Analysis
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of dimensionnality reduction.
reg : float, optional
Regularization term >0 (ridge regularization)
-
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal transportation matrix for the given parameters
- proj : fun
+ proj : callable
projection function including mean centering
-
-
"""
mx = np.mean(X)
@@ -130,37 +125,33 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of dimensionnality reduction.
reg : float, optional
Regularization term >0 (entropic regularization)
- solver : str, optional
- None for steepest decsent or 'TrustRegions' for trust regions algorithm
- else shoudl be a pymanopt.solvers
- P0 : numpy.ndarray (d,p)
- Initial starting point for projection
+ solver : None | str, optional
+ None for steepest descent or 'TrustRegions' for trust regions algorithm
+ else should be a pymanopt.solvers
+ P0 : ndarray, shape (d, p)
+ Initial starting point for projection.
verbose : int, optional
- Print information along iterations
-
-
+ Print information along iterations.
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal transportation matrix for the given parameters
- proj : fun
- projection function including mean centering
-
+ proj : callable
+ Projection function including mean centering.
References
----------
-
- .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
+ .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+ Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
""" # noqa
mx = np.mean(X)