-rw-r--r--  examples/plot_WDA.py  13
-rw-r--r--  ot/dr.py              73
2 files changed, 78 insertions, 8 deletions
diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py
index bbe3888..d2eaf6d 100644
--- a/examples/plot_WDA.py
+++ b/examples/plot_WDA.py
@@ -11,7 +11,7 @@ import numpy as np
import matplotlib.pylab as pl
import ot
from ot.datasets import get_1D_gauss as gauss
-from ot.dr import wda
+from ot.dr import wda, fda
#%% parameters
@@ -36,7 +36,12 @@ pl.legend(loc=0)
pl.title('Discriminant dimensions')
-#%% plot distributions and loss matrix
+#%% Compute FDA
+p=2
+
+Pfda,projfda = fda(xs,ys,p)
+
+#%% Compute WDA
p=2
reg=1
k=10
@@ -46,8 +51,8 @@ P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
#%% plot samples
-xsp=proj(xs)
-xtp=proj(xt)
+xsp=projfda(xs)
+xtp=projfda(xt)
pl.figure(1,(10,5))
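Aside (not part of the patch): the new fda helper can also be exercised standalone. A minimal sketch, assuming two synthetic Gaussian classes built with plain numpy; only fda's signature is taken from the diff below, every other name is illustrative.

import numpy as np
from ot.dr import fda

rng = np.random.RandomState(0)
n, d = 100, 10
xs = np.vstack([rng.randn(n, d) + 2, rng.randn(n, d) - 2])  # two shifted Gaussian classes
ys = np.hstack([np.zeros(n), np.ones(n)])

Pfda, projfda = fda(xs, ys, p=2)
print(Pfda.shape)         # (10, 1): p is capped at n_classes - 1 = 1
print(projfda(xs).shape)  # (200, 1): projected samples

Note how p is silently capped at n_classes - 1, per the p=min(nc-1,p) line in the implementation below.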
diff --git a/ot/dr.py b/ot/dr.py
index 9187b57..fdb4daa 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -7,6 +7,7 @@ import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
+import scipy.linalg as la
def dist(x1,x2):
""" Compute squared euclidean distance between samples (autograd)
@@ -32,9 +33,73 @@ def split_classes(X,y):
"""
lstsclass=np.unique(y)
return [X[y==i,:].astype(np.float32) for i in lstsclass]
+
+
+def fda(X,y,p=2,reg=1e-16):
+    """
+    Fisher Discriminant Analysis
+
+    Parameters
+    ----------
+    X : numpy.ndarray (n,d)
+        Training samples
+    y : np.ndarray (n,)
+        Labels for training samples
+    p : int, optional
+        Size of the dimensionality reduction
+    reg : float, optional
+        Regularization term >0 (ridge regularization)
+
+    Returns
+    -------
+    P : (d x p) ndarray
+        Optimal projection matrix
+    proj : fun
+        Projection function including mean centering
+
+    """
+
+    mx=np.mean(X,0)
+    X=X-mx.reshape((1,-1)) # center data (without modifying X in place)
+
+    # data split between classes
+    d=X.shape[1]
+    xc=split_classes(X,y)
+    nc=len(xc)
+
+    p=min(nc-1,p) # at most nc-1 discriminant directions
+
+    Cw=0 # average within-class covariance
+    for x in xc:
+        Cw+=np.cov(x,rowvar=False)
+    Cw/=nc
+
+    mxc=np.zeros((d,nc)) # per-class means
+
+    for i in range(nc):
+        mxc[:,i]=np.mean(xc[i],0)
+
+    mx0=np.mean(mxc,1)
+    Cb=0 # between-class covariance of the class means
+    for i in range(nc):
+        Cb+=(mxc[:,i]-mx0).reshape((-1,1))*(mxc[:,i]-mx0).reshape((1,-1))
+
+    # generalized eigenproblem Cb V = w (Cw + reg*I) V
+    w,V=la.eig(Cb,Cw+reg*np.eye(d))
+
+    idx=np.argsort(w.real)
+
+    Popt=V[:,idx[-p:]].real # keep the p leading eigenvectors
+
+    def proj(X):
+        # center with the training mean, then project on the FDA basis
+        return (X-mx.reshape((1,-1))).dot(Popt)
+
+    return Popt, proj
+
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
"""
Wasserstein Discriminant Analysis [11]_
@@ -73,16 +138,13 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
    P : (d x p) ndarray
        Optimal transportation matrix for the given parameters
    proj : fun
-        projectiuon function including mean centering
+        projection function including mean centering
    References
    ----------
    .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
-
-
"""
@@ -131,3 +193,6 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
        return (X-mx.reshape((1,-1))).dot(Popt)
    return Popt, proj
+
+
+
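For context on the linear algebra in fda: it solves the generalized eigenproblem Cb v = w (Cw + reg*I) v and keeps the eigenvectors with the largest w. The sketch below (illustrative random matrices, not from the patch) checks that scipy.linalg.eig(A, B) indeed returns pairs satisfying A v = w B v.

import numpy as np
import scipy.linalg as la

rng = np.random.RandomState(0)
d = 5
A = rng.randn(d, d); A = A.dot(A.T)              # stands in for Cb (symmetric PSD)
B = rng.randn(d, d); B = B.dot(B.T) + np.eye(d)  # stands in for Cw + reg*I (symmetric PD)

w, V = la.eig(A, B)
i = np.argmax(w.real)
v = V[:, i].real
print(np.linalg.norm(A.dot(v) - w[i].real * B.dot(v)))  # ~0: a valid generalized eigenpair

For a symmetric definite pair like (Cb, Cw + reg*I) the eigenvalues are real, which is why the implementation can sort on w.real and take the real part of the eigenvectors.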