summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-05 17:11:17 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-05 17:11:17 +0200
commit05765e26dec18697d9eb60ac5f3a7610af6ae8d2 (patch)
tree9e429f1a984171f3b899b60ce740a419d9811bef /ot/dr.py
parent315d812b18fcb1170b4b84907b87aeaadaa9a196 (diff)
add FDA for comparison
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py73
1 file changed, 69 insertions, 4 deletions
diff --git a/ot/dr.py b/ot/dr.py
index 9187b57..fdb4daa 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -7,6 +7,7 @@ import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
+import scipy.linalg as la
def dist(x1,x2):
""" Compute squared euclidean distance between samples (autograd)
@@ -32,9 +33,73 @@ def split_classes(X,y):
"""
lstsclass=np.unique(y)
return [X[y==i,:].astype(np.float32) for i in lstsclass]
+
+
def fda(X, y, p=2, reg=1e-16):
    """
    Fisher Discriminant Analysis

    Computes the p generalized eigenvectors of the between-class scatter
    with respect to the (ridge-regularized) within-class covariance that
    have the largest eigenvalues, and returns them as a projection matrix
    together with a mean-centering projection function.

    Parameters
    ----------
    X : numpy.ndarray (n,d)
        Training samples
    y : np.ndarray (n,)
        labels for training samples
    p : int, optional
        size of dimensionality reduction (capped at nclasses-1)
    reg : float, optional
        Regularization term >0 (ridge regularization)

    Returns
    -------
    P : (d x p) ndarray
        Optimal projection matrix for the given parameters
    proj : fun
        projection function including mean centering


    """

    # center feature-wise: mean over samples (axis 0), not the global
    # scalar mean of the whole matrix
    mx = np.mean(X, 0)
    X -= mx.reshape((1, -1))

    # data split between classes
    d = X.shape[1]
    xc = split_classes(X, y)
    nc = len(xc)

    # FDA yields at most nclasses-1 discriminant directions
    p = min(nc - 1, p)

    # average within-class covariance
    Cw = 0
    for x in xc:
        Cw += np.cov(x, rowvar=False)
    Cw /= nc

    # per-class means, again feature-wise (axis 0)
    mxc = np.zeros((d, nc))
    for i in range(nc):
        mxc[:, i] = np.mean(xc[i], 0)

    # between-class scatter built from the class means
    mx0 = np.mean(mxc, 1)
    Cb = 0
    for i in range(nc):
        Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * (mxc[:, i] - mx0).reshape((1, -1))

    # generalized eigenproblem  Cb v = w (Cw + reg*I) v
    w, V = la.eig(Cb, Cw + reg * np.eye(d))

    # keep the p eigenvectors with largest eigenvalues; the pencil is
    # symmetric so eigenvalues/vectors are real — drop the zero
    # imaginary part la.eig attaches to avoid a complex projection
    idx = np.argsort(w.real)
    Popt = V[:, idx[-p:]].real

    def proj(X):
        # apply the same centering used at fit time, then project
        return (X - mx.reshape((1, -1))).dot(Popt)

    return Popt, proj
+
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
"""
Wasserstein Discriminant Analysis [11]_
@@ -73,16 +138,13 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
P : (d x p) ndarray
Optimal transportation matrix for the given parameters
proj : fun
- projectiuon function including mean centering
+ projection function including mean centering
References
----------
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
-
-
"""
@@ -131,3 +193,6 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
return (X-mx.reshape((1,-1))).dot(Popt)
return Popt, proj
+
+
+