summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-09-09 14:55:04 +0200
committerRémi Flamary <remi.flamary@gmail.com>2019-09-09 14:55:04 +0200
commitb2a7afb848a78570d01f35f9b239be8838520edc (patch)
treefc243208d24f5488d5ce06298b2ebb39b76be9bb /ot/dr.py
parentc698e0aa20d28e36d25f87082855a490283f3c88 (diff)
parentf251b4d080a577c2cee890ca43d8ec3658332021 (diff)
merge new unbalanced
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py57
1 files changed, 24 insertions, 33 deletions
diff --git a/ot/dr.py b/ot/dr.py
index d2bf6e2..680dabf 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -49,30 +49,25 @@ def split_classes(X, y):
def fda(X, y, p=2, reg=1e-16):
- """
- Fisher Discriminant Analysis
-
+ """Fisher Discriminant Analysis
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of dimensionnality reduction.
reg : float, optional
Regularization term >0 (ridge regularization)
-
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal transportation matrix for the given parameters
- proj : fun
+ proj : callable
projection function including mean centering
-
-
"""
mx = np.mean(X)
@@ -130,37 +125,33 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of dimensionnality reduction.
reg : float, optional
Regularization term >0 (entropic regularization)
- solver : str, optional
- None for steepest decsent or 'TrustRegions' for trust regions algorithm
- else shoudl be a pymanopt.solvers
- P0 : numpy.ndarray (d,p)
- Initial starting point for projection
+ solver : None | str, optional
+ None for steepest descent or 'TrustRegions' for trust regions algorithm
+ else should be a pymanopt.solvers
+ P0 : ndarray, shape (d, p)
+ Initial starting point for projection.
verbose : int, optional
- Print information along iterations
-
-
+ Print information along iterations.
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal transportation matrix for the given parameters
- proj : fun
- projection function including mean centering
-
+ proj : callable
+ Projection function including mean centering.
References
----------
-
- .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
+ .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+ Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
""" # noqa
mx = np.mean(X)