diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-12 14:32:02 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2017-07-12 14:32:02 +0200 |
commit | faa4744597f93e7005bd48729441562e092e3ab6 (patch) | |
tree | e37ee8c9296e584dae0655397436258b262c0649 | |
parent | a1ae72e46cbe52443a1b00d3b8ffd2e2adc80077 (diff) |
add init WDA
-rw-r--r-- | ot/dr.py | 8 |
1 files changed, 5 insertions, 3 deletions
@@ -100,7 +100,7 @@ def fda(X,y,p=2,reg=1e-16): return Popt, proj -def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0): +def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0,P0=None): """ Wasserstein Discriminant Analysis [11]_ @@ -127,7 +127,9 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0): Regularization term >0 (entropic regularization) solver : str, optional None for steepest decsent or 'TrustRegions' for trust regions algorithm - else shoudl be a pymanopt.sovers + else shoudl be a pymanopt.solvers + P0 : numpy.ndarray (d,p) + Initial starting point for projection verbose : int, optional Print information along iterations @@ -187,7 +189,7 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0): elif solver in ['tr','TrustRegions']: solver= TrustRegions(maxiter=maxiter,logverbosity=verbose) - Popt = solver.solve(problem) + Popt = solver.solve(problem,x=P0) def proj(X): return (X-mx.reshape((1,-1))).dot(Popt) |