summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-04-07 14:24:08 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-04-07 14:24:08 +0200
commit461d269538196a679e65aa3804c7ab88dce6dfaa (patch)
treed155cd2b2053711d67ebc2bdcd4f3e0f4e124ab0 /ot/dr.py
parent3cc99e6590fa87ae8705fc93315590b27bf84efc (diff)
doc wda
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py45
1 files changed, 16 insertions, 29 deletions
diff --git a/ot/dr.py b/ot/dr.py
index 3965149..14a92c1 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -54,41 +54,28 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,)
- samples in the target domain
- M : np.ndarray (ns,nt)
- loss matrix
- reg : float
- Regularization term >0
- numItermax : int, optional
- Max number of iterations
- stopThr : float, optional
- Stop threshol on error (>0)
- verbose : bool, optional
+ X : numpy.ndarray (n,d)
+ Training samples
+ y : np.ndarray (n,)
+ labels for training samples
+ p : int, optional
+ size of dimensionnality reduction
+ reg : float, optional
+ Regularization term >0 (entropic regularization)
+ solver : str, optional
+ None for steepest decsent or 'TrustRegions' for trust regions algorithm
+ else shoudl be a pymanopt.sovers
+ verbose : int, optional
Print information along iterations
- log : bool, optional
- record log if True
+
Returns
-------
- gamma : (ns x nt) ndarray
+ P : (d x p) ndarray
Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
- Examples
- --------
-
- >>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.sinkhorn(a,b,M,1)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
+ proj : fun
+ projectiuon function including mean centering
References