author     Rémi Flamary <remi.flamary@gmail.com>   2017-07-05 17:11:17 +0200
committer  Rémi Flamary <remi.flamary@gmail.com>   2017-07-05 17:11:17 +0200
commit     05765e26dec18697d9eb60ac5f3a7610af6ae8d2
tree       9e429f1a984171f3b899b60ce740a419d9811bef
parent     315d812b18fcb1170b4b84907b87aeaadaa9a196
add FDA for comparison
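
The new fda implements the classical Fisher criterion: find a p-dimensional projection for which the between-class scatter C_b is large relative to the within-class scatter C_w, with a small ridge term reg added for numerical stability. In the standard formulation, which the code below follows via la.eig(Cb, Cw+reg*np.eye(d)), the projection directions are generalized eigenvectors of

    C_b v = \mu \, (C_w + \mathrm{reg}\, I) \, v

and the eigenvectors with the p largest eigenvalues \mu form the columns of the projection matrix Popt. FDA can recover at most nc-1 directions for nc classes, hence the p=min(nc-1,p) clamp in the code.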
 examples/plot_WDA.py | 13
 ot/dr.py             | 73
 2 files changed, 78 insertions(+), 8 deletions(-)
diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py
index bbe3888..d2eaf6d 100644
--- a/examples/plot_WDA.py
+++ b/examples/plot_WDA.py
@@ -11,7 +11,7 @@
 import numpy as np
 import matplotlib.pylab as pl
 import ot
 from ot.datasets import get_1D_gauss as gauss
-from ot.dr import wda
+from ot.dr import wda, fda

 #%% parameters
@@ -36,7 +36,12 @@
 pl.legend(loc=0)
 pl.title('Discriminant dimensions')

-#%% plot distributions and loss matrix
+#%% Compute FDA
+p=2
+
+Pfda,projfda = fda(xs,ys,p)
+
+#%% Compute WDA
 p=2
 reg=1
 k=10
@@ -46,8 +51,8 @@
 P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)

 #%% plot samples
-xsp=proj(xs)
-xtp=proj(xt)
+xsp=projfda(xs)
+xtp=projfda(xt)


 pl.figure(1,(10,5))
diff --git a/ot/dr.py b/ot/dr.py
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -7,6 +7,7 @@ import autograd.numpy as np
 from pymanopt.manifolds import Stiefel
 from pymanopt import Problem
 from pymanopt.solvers import SteepestDescent, TrustRegions
+import scipy.linalg as la

 def dist(x1,x2):
     """ Compute squared euclidean distance between samples (autograd)
@@ -32,9 +33,73 @@ def split_classes(X,y):
     """
     lstsclass=np.unique(y)
     return [X[y==i,:].astype(np.float32) for i in lstsclass]
+
+
+def fda(X,y,p=2,reg=1e-16):
+    """
+    Fisher Discriminant Analysis
+
+    Parameters
+    ----------
+    X : numpy.ndarray (n,d)
+        Training samples
+    y : np.ndarray (n,)
+        labels for training samples
+    p : int, optional
+        size of dimensionality reduction
+    reg : float, optional
+        Regularization term >0 (ridge regularization)
+
+    Returns
+    -------
+    P : (d x p) ndarray
+        Optimal projection matrix for the given parameters
+    proj : fun
+        projection function including mean centering
+
+    """
+
+    mx=np.mean(X,0)  # per-feature mean
+    X=X-mx.reshape((1,-1))  # center data without modifying the caller's array
+
+    # data split between classes
+    d=X.shape[1]
+    xc=split_classes(X,y)
+    nc=len(xc)
+
+    p=min(nc-1,p)  # FDA yields at most nc-1 discriminant directions
+
+    Cw=0
+    for x in xc:
+        Cw+=np.cov(x,rowvar=False)
+    Cw/=nc  # average within-class covariance
+
+    mxc=np.zeros((d,nc))
+
+    for i in range(nc):
+        mxc[:,i]=np.mean(xc[i],0)  # per-feature class means
+
+    mx0=np.mean(mxc,1)
+    Cb=0
+    for i in range(nc):
+        Cb+=(mxc[:,i]-mx0).reshape((-1,1))*(mxc[:,i]-mx0).reshape((1,-1))
+
+    w,V=la.eig(Cb,Cw+reg*np.eye(d))  # generalized eigenproblem Cb v = w (Cw+reg*I) v
+
+    idx=np.argsort(w.real)
+
+    Popt=V[:,idx[-p:]].real  # eigenvectors of the p largest eigenvalues
+
+
+
+    def proj(X):
+        return (X-mx.reshape((1,-1))).dot(Popt)
+
+    return Popt, proj
+

 def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
     """
     Wasserstein Discriminant Analysis [11]_
@@ -73,16 +138,13 @@
     Returns
     -------
     P : (d x p) ndarray
         Optimal transportation matrix for the given parameters
     proj : fun
-        projectiuon function including mean centering
+        projection function including mean centering

     References
     ----------
     .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
-
-

     """
@@ -131,3 +193,6 @@
         return (X-mx.reshape((1,-1))).dot(Popt)

     return Popt, proj
+
+
+
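
A minimal usage sketch of the new function (not part of the commit; the three-class Gaussian toy data below is an assumption for illustration — plot_WDA.py builds its own xs, ys):

import numpy as np
from ot.dr import fda, wda  # wda additionally requires pymanopt and autograd

# hypothetical toy data: three Gaussian classes in 3D, separated along the first two axes
n = 100
rng = np.random.RandomState(0)
means = np.array([[0., 0., 0.], [3., 0., 0.], [0., 3., 0.]])
xs = np.vstack([rng.randn(n, 3) + m for m in means])
ys = np.repeat(np.arange(3), n)

# closed-form FDA: generalized eigenvectors of (Cb, Cw + reg*I)
Pfda, projfda = fda(xs, ys, p=2)

# WDA: gradient-based optimization on the Stiefel manifold
P, proj = wda(xs, ys, p=2, reg=1, k=10, maxiter=100)

print(projfda(xs).shape)  # (300, 2): samples projected onto the FDA subspace

Unlike wda, fda needs no iterative solver, which is what makes it a cheap baseline for the comparison in plot_WDA.py.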