path: root/ot/dr.py
Diffstat (limited to 'ot/dr.py')
-rw-r--r--  ot/dr.py  200
1 file changed, 200 insertions(+), 0 deletions(-)
diff --git a/ot/dr.py b/ot/dr.py
new file mode 100644
index 0000000..680dabf
--- /dev/null
+++ b/ot/dr.py
@@ -0,0 +1,200 @@
+# -*- coding: utf-8 -*-
+"""
+Dimension reduction with optimal transport
+
+
+.. warning::
+    Note that by default the module is not imported in :mod:`ot`. In order
+    to use it you need to explicitly import :mod:`ot.dr`, for example::
+
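+        # a minimal usage sketch; X and y stand for your data and labels
+        import ot.dr
+        P, proj = ot.dr.fda(X, y, p=2)
+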
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+from scipy import linalg
+import autograd.numpy as np
+from pymanopt.manifolds import Stiefel
+from pymanopt import Problem
+from pymanopt.solvers import SteepestDescent, TrustRegions
+
+
+def dist(x1, x2):
+ """ Compute squared euclidean distance between samples (autograd)
+ """
+ x1p2 = np.sum(np.square(x1), 1)
+ x2p2 = np.sum(np.square(x2), 1)
+ return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
+
+
+def sinkhorn(w1, w2, M, reg, k):
+ """Sinkhorn algorithm with fixed number of iteration (autograd)
+ """
+ K = np.exp(-M / reg)
+ ui = np.ones((M.shape[0],))
+ vi = np.ones((M.shape[1],))
+ for i in range(k):
+ vi = w2 / (np.dot(K.T, ui))
+ ui = w1 / (np.dot(K, vi))
+ G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
+ return G
+
+
+def split_classes(X, y):
+ """split samples in X by classes in y
+ """
+ lstsclass = np.unique(y)
+ return [X[y == i, :].astype(np.float32) for i in lstsclass]
+
+
+def fda(X, y, p=2, reg=1e-16):
+ """Fisher Discriminant Analysis
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
+ p : int, optional
+ Size of dimensionnality reduction.
+ reg : float, optional
+ Regularization term >0 (ridge regularization)
+
+ Returns
+ -------
+ P : ndarray, shape (d, p)
+ Optimal transportation matrix for the given parameters
+ proj : callable
+ projection function including mean centering
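+
+    Examples
+    --------
+    A minimal sketch on synthetic two-class data (values are illustrative):
+
+    >>> np.random.seed(42)
+    >>> X = np.vstack((np.random.randn(20, 3), np.random.randn(20, 3) + 5))
+    >>> y = np.repeat([0, 1], 20)
+    >>> P, proj = fda(X, y, p=1)
+    >>> proj(X).shape
+    (40, 1)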
+ """
+
+    mx = np.mean(X, 0)
+    X = X - mx.reshape((1, -1))
+
+    # data split between classes
+    d = X.shape[1]
+    xc = split_classes(X, y)
+    nc = len(xc)
+
+    # at most nc - 1 discriminant directions exist
+    p = min(nc - 1, p)
+
+    # average within-class covariance
+    Cw = 0
+    for x in xc:
+        Cw += np.cov(x, rowvar=False)
+    Cw /= nc
+
+    # per-class means and between-class covariance
+    mxc = np.zeros((d, nc))
+
+    for i in range(nc):
+        mxc[:, i] = np.mean(xc[i], 0)
+
+    mx0 = np.mean(mxc, 1)
+    Cb = 0
+    for i in range(nc):
+        Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \
+            (mxc[:, i] - mx0).reshape((1, -1))
+
+    # solve the generalized eigenvalue problem and keep the p
+    # directions with the largest eigenvalues
+    w, V = linalg.eig(Cb, Cw + reg * np.eye(d))
+
+    idx = np.argsort(w.real)
+
+    Popt = V[:, idx[-p:]]
+
+    def proj(X):
+        return (X - mx.reshape((1, -1))).dot(Popt)
+
+    return Popt, proj
+
+
+def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
+ """
+ Wasserstein Discriminant Analysis [11]_
+
+ The function solves the following optimization problem:
+
+ .. math::
+ P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
+
+ where :
+
+ - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
+ - :math:`W` is entropic regularized Wasserstein distances
+ - :math:`X^i` are samples in the dataset corresponding to class i
+
+    Parameters
+    ----------
+    X : ndarray, shape (n, d)
+        Training samples.
+    y : ndarray, shape (n,)
+        Labels for training samples.
+    p : int, optional
+        Size of the dimensionality reduction.
+    reg : float, optional
+        Regularization term >0 (entropic regularization)
+    k : int, optional
+        Number of Sinkhorn iterations used to compute the Wasserstein
+        distances.
+    solver : None | str, optional
+        None for steepest descent, 'TrustRegions' for the trust region
+        algorithm, or an already instantiated pymanopt solver.
+    maxiter : int, optional
+        Maximum number of iterations for the manifold solver.
+    P0 : ndarray, shape (d, p)
+        Initial starting point for the projection.
+    verbose : int, optional
+        Print information along iterations.
+
+    Returns
+    -------
+    P : ndarray, shape (d, p)
+        Optimal projection matrix for the given parameters
+    proj : callable
+        Projection function including mean centering.
+
+    References
+    ----------
+    .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+        Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
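+
+    Examples
+    --------
+    A minimal sketch on synthetic two-class data (requires pymanopt; skipped
+    in doctests because the solver is iterative and not fully deterministic):
+
+    >>> np.random.seed(42)
+    >>> X = np.vstack((np.random.randn(20, 3), np.random.randn(20, 3) + 3))
+    >>> y = np.repeat([0, 1], 20)
+    >>> P, proj = wda(X, y, p=2, reg=1, k=10, maxiter=50)  # doctest: +SKIP
+    >>> proj(X).shape  # doctest: +SKIP
+    (40, 2)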
+ """ # noqa
+
+    mx = np.mean(X, 0)
+    X = X - mx.reshape((1, -1))
+
+    # data split between classes
+    d = X.shape[1]
+    xc = split_classes(X, y)
+    # compute uniform weights
+    wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
+
+    def cost(P):
+        # wda loss: ratio of within-class to between-class
+        # regularized Wasserstein distances in the projected space
+        loss_b = 0
+        loss_w = 0
+
+        for i, xi in enumerate(xc):
+            xi = np.dot(xi, P)
+            for j, xj in enumerate(xc[i:]):
+                xj = np.dot(xj, P)
+                M = dist(xi, xj)
+                G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+                if j == 0:
+                    # j == 0 means xj is the same class as xi (within-class)
+                    loss_w += np.sum(G * M)
+                else:
+                    loss_b += np.sum(G * M)
+
+        # loss inverted because we minimize
+        return loss_w / loss_b
+
+    # declare manifold and problem
+    manifold = Stiefel(d, p)
+    problem = Problem(manifold=manifold, cost=cost, verbosity=verbose)
+
+    # declare solver and solve
+    if solver is None:
+        solver = SteepestDescent(maxiter=maxiter)
+    elif solver in ['tr', 'TrustRegions']:
+        solver = TrustRegions(maxiter=maxiter)
+
+    Popt = solver.solve(problem, x=P0)
+
+    def proj(X):
+        return (X - mx.reshape((1, -1))).dot(Popt)
+
+    return Popt, proj