summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-12 14:32:02 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-12 14:32:02 +0200
commitfaa4744597f93e7005bd48729441562e092e3ab6 (patch)
treee37ee8c9296e584dae0655397436258b262c0649 /ot/dr.py
parenta1ae72e46cbe52443a1b00d3b8ffd2e2adc80077 (diff)
add init WDA
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/ot/dr.py b/ot/dr.py
index fdb4daa..763ce35 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -100,7 +100,7 @@ def fda(X,y,p=2,reg=1e-16):
return Popt, proj
-def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
+def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0,P0=None):
"""
Wasserstein Discriminant Analysis [11]_
@@ -127,7 +127,9 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
Regularization term >0 (entropic regularization)
solver : str, optional
None for steepest decsent or 'TrustRegions' for trust regions algorithm
- else shoudl be a pymanopt.sovers
+ else shoudl be a pymanopt.solvers
+ P0 : numpy.ndarray (d,p)
+ Initial starting point for projection
verbose : int, optional
Print information along iterations
@@ -187,7 +189,7 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
elif solver in ['tr','TrustRegions']:
solver= TrustRegions(maxiter=maxiter,logverbosity=verbose)
- Popt = solver.solve(problem)
+ Popt = solver.solve(problem,x=P0)
def proj(X):
return (X-mx.reshape((1,-1))).dot(Popt)