Diffstat (limited to 'ot/dr.py')
-rw-r--r--   ot/dr.py   201
1 files changed, 102 insertions, 99 deletions
diff --git a/ot/dr.py b/ot/dr.py
index 763ce35..d30ab30 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -3,43 +3,50 @@
Dimension reduction with optimal transport
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+from scipy import linalg
import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
-import scipy.linalg as la
-def dist(x1,x2):
+
+def dist(x1, x2):
""" Compute squared euclidean distance between samples (autograd)
"""
- x1p2=np.sum(np.square(x1),1)
- x2p2=np.sum(np.square(x2),1)
- return x1p2.reshape((-1,1))+x2p2.reshape((1,-1))-2*np.dot(x1,x2.T)
+ x1p2 = np.sum(np.square(x1), 1)
+ x2p2 = np.sum(np.square(x2), 1)
+ return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
+
-def sinkhorn(w1,w2,M,reg,k):
+def sinkhorn(w1, w2, M, reg, k):
"""Sinkhorn algorithm with fixed number of iteration (autograd)
"""
- K=np.exp(-M/reg)
- ui=np.ones((M.shape[0],))
- vi=np.ones((M.shape[1],))
+ K = np.exp(-M / reg)
+ ui = np.ones((M.shape[0],))
+ vi = np.ones((M.shape[1],))
for i in range(k):
- vi=w2/(np.dot(K.T,ui))
- ui=w1/(np.dot(K,vi))
- G=ui.reshape((M.shape[0],1))*K*vi.reshape((1,M.shape[1]))
+ vi = w2 / (np.dot(K.T, ui))
+ ui = w1 / (np.dot(K, vi))
+ G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G
-def split_classes(X,y):
+
+def split_classes(X, y):
"""split samples in X by classes in y
"""
- lstsclass=np.unique(y)
- return [X[y==i,:].astype(np.float32) for i in lstsclass]
+ lstsclass = np.unique(y)
+ return [X[y == i, :].astype(np.float32) for i in lstsclass]
-def fda(X,y,p=2,reg=1e-16):
-    """
-    Fisher Discriminant Analysis
-
-
+
+
+def fda(X, y, p=2, reg=1e-16):
+    """
+    Fisher Discriminant Analysis
Parameters
----------
X : numpy.ndarray (n,d)
@@ -59,62 +66,62 @@ def fda(X,y,p=2,reg=1e-16):
proj : fun
projection function including mean centering
-
- """
-
- mx=np.mean(X)
- X-=mx.reshape((1,-1))
-
+
+ """
+
+ mx = np.mean(X)
+ X -= mx.reshape((1, -1))
+
# data split between classes
- d=X.shape[1]
- xc=split_classes(X,y)
- nc=len(xc)
-
- p=min(nc-1,p)
-
- Cw=0
+ d = X.shape[1]
+ xc = split_classes(X, y)
+ nc = len(xc)
+
+ p = min(nc - 1, p)
+
+ Cw = 0
for x in xc:
- Cw+=np.cov(x,rowvar=False)
- Cw/=nc
-
- mxc=np.zeros((d,nc))
-
+ Cw += np.cov(x, rowvar=False)
+ Cw /= nc
+
+ mxc = np.zeros((d, nc))
+
for i in range(nc):
- mxc[:,i]=np.mean(xc[i])
-
- mx0=np.mean(mxc,1)
- Cb=0
+ mxc[:, i] = np.mean(xc[i])
+
+ mx0 = np.mean(mxc, 1)
+ Cb = 0
for i in range(nc):
- Cb+=(mxc[:,i]-mx0).reshape((-1,1))*(mxc[:,i]-mx0).reshape((1,-1))
-
- w,V=la.eig(Cb,Cw+reg*np.eye(d))
-
- idx=np.argsort(w.real)
-
- Popt=V[:,idx[-p:]]
-
-
-
+ Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \
+ (mxc[:, i] - mx0).reshape((1, -1))
+
+ w, V = linalg.eig(Cb, Cw + reg * np.eye(d))
+
+ idx = np.argsort(w.real)
+
+ Popt = V[:, idx[-p:]]
+
def proj(X):
- return (X-mx.reshape((1,-1))).dot(Popt)
-
+ return (X - mx.reshape((1, -1))).dot(Popt)
+
return Popt, proj
-def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0,P0=None):
- """
+
+def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
+ """
Wasserstein Discriminant Analysis [11]_
-
+
The function solves the following optimization problem:
.. math::
P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
where :
-
+
- :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
- :math:`W` is entropic regularized Wasserstein distances
- - :math:`X^i` are samples in the dataset corresponding to class i
-
+ - :math:`X^i` are samples in the dataset corresponding to class i
+
Parameters
----------
X : numpy.ndarray (n,d)
@@ -147,54 +154,50 @@ def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0,P0=None):
----------
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
- """
-
- mx=np.mean(X)
- X-=mx.reshape((1,-1))
-
+
+ """ # noqa
+
+ mx = np.mean(X)
+ X -= mx.reshape((1, -1))
+
# data split between classes
- d=X.shape[1]
- xc=split_classes(X,y)
+ d = X.shape[1]
+ xc = split_classes(X, y)
# compute uniform weighs
- wc=[np.ones((x.shape[0]),dtype=np.float32)/x.shape[0] for x in xc]
-
+ wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
+
def cost(P):
# wda loss
- loss_b=0
- loss_w=0
-
- for i,xi in enumerate(xc):
- xi=np.dot(xi,P)
- for j,xj in enumerate(xc[i:]):
- xj=np.dot(xj,P)
- M=dist(xi,xj)
- G=sinkhorn(wc[i],wc[j+i],M,reg,k)
- if j==0:
- loss_w+=np.sum(G*M)
+ loss_b = 0
+ loss_w = 0
+
+ for i, xi in enumerate(xc):
+ xi = np.dot(xi, P)
+ for j, xj in enumerate(xc[i:]):
+ xj = np.dot(xj, P)
+ M = dist(xi, xj)
+ G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+ if j == 0:
+ loss_w += np.sum(G * M)
else:
- loss_b+=np.sum(G*M)
-
- # loss inversed because minimization
- return loss_w/loss_b
-
-
+ loss_b += np.sum(G * M)
+
+ # loss inversed because minimization
+ return loss_w / loss_b
+
# declare manifold and problem
- manifold = Stiefel(d, p)
+ manifold = Stiefel(d, p)
problem = Problem(manifold=manifold, cost=cost)
-
+
# declare solver and solve
if solver is None:
- solver= SteepestDescent(maxiter=maxiter,logverbosity=verbose)
- elif solver in ['tr','TrustRegions']:
- solver= TrustRegions(maxiter=maxiter,logverbosity=verbose)
-
- Popt = solver.solve(problem,x=P0)
-
- def proj(X):
- return (X-mx.reshape((1,-1))).dot(Popt)
-
- return Popt, proj
+ solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
+ elif solver in ['tr', 'TrustRegions']:
+ solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
+
+    Popt = solver.solve(problem, x=P0)
+
+    def proj(X):
+        return (X - mx.reshape((1, -1))).dot(Popt)
+
+    return Popt, proj
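
The `sinkhorn` helper in this diff runs a fixed number of Sinkhorn matrix-scaling iterations to build an entropic-regularized transport plan, written against autograd's numpy so the WDA cost stays differentiable. The following standalone sketch is not part of the commit: the `sinkhorn_fixed` name, the toy points and the weights are made up for illustration. It reproduces the same fixed-point iteration in plain NumPy and checks that the returned plan approximately matches the prescribed marginals.

import numpy as np

def sinkhorn_fixed(w1, w2, M, reg, k):
    # Gibbs kernel of the cost matrix
    K = np.exp(-M / reg)
    u = np.ones(M.shape[0])
    v = np.ones(M.shape[1])
    for _ in range(k):
        # alternate scaling so the plan matches the column, then the row marginals
        v = w2 / K.T.dot(u)
        u = w1 / K.dot(v)
    return u.reshape((-1, 1)) * K * v.reshape((1, -1))

# toy problem: 5 source and 4 target points in the plane, uniform weights
rng = np.random.RandomState(0)
xs, xt = rng.randn(5, 2), rng.randn(4, 2) + 1.0
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)   # squared euclidean cost
w1, w2 = np.full(5, 1 / 5), np.full(4, 1 / 4)
G = sinkhorn_fixed(w1, w2, M, reg=1.0, k=50)
print(np.abs(G.sum(axis=1) - w1).max())   # row marginal error, close to 0
print(np.abs(G.sum(axis=0) - w2).max())   # column marginal error, close to 0

With smaller `reg` the plan concentrates towards the unregularized coupling, but the fixed iteration count `k` may then need to be increased for the marginals to be matched as tightly.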
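For completeness, here is a minimal usage sketch of the two public functions touched by this diff. It assumes the package is installed so that the module imports as `ot.dr` (as the file path suggests), together with its `autograd` and `pymanopt` dependencies; the two-class toy dataset below is made up for illustration.

import numpy as np
from ot.dr import fda, wda

# made-up dataset: two classes in 10 dimensions, only the first two features discriminate
rng = np.random.RandomState(42)
n = 100
y = np.concatenate((np.zeros(n), np.ones(n)))
X = 0.5 * rng.randn(2 * n, 10)
X[y == 1, :2] += 2.0

# Fisher Discriminant Analysis: closed-form generalized eigenvalue problem
Pfda, proj_fda = fda(X, y, p=1, reg=1e-16)
Xfda = proj_fda(X)            # samples projected on the FDA direction

# Wasserstein Discriminant Analysis: Riemannian optimization on the Stiefel manifold
Pwda, proj_wda = wda(X, y, p=1, reg=1.0, k=10, maxiter=50)
Xwda = proj_wda(X)            # samples projected on the WDA direction

Both functions return the projection matrix `Popt` together with a `proj` closure that re-applies the mean centering learned from the training data before projecting new samples.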