author | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100
---|---|---
committer | Gard Spreemann <gspr@nonempty.org> | 2020-01-20 14:07:53 +0100
commit | bdfb24ff37ea777d6e266b145047cd4e281ebac3 (patch) |
tree | 00cbac5f3dc25a4ee76164828abd72c1cbab37cc /ot/dr.py |
parent | abc441b00f0fe2fa4ef0efc4e1aa67b27cca9a13 (diff) |
parent | 5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff) |
Merge tag '0.6.0' into debian/sid
Diffstat (limited to 'ot/dr.py')
-rw-r--r-- | ot/dr.py | 200 |
1 file changed, 200 insertions, 0 deletions
```diff
diff --git a/ot/dr.py b/ot/dr.py
new file mode 100644
index 0000000..680dabf
--- /dev/null
+++ b/ot/dr.py
@@ -0,0 +1,200 @@
+# -*- coding: utf-8 -*-
+"""
+Dimension reduction with optimal transport
+
+
+.. warning::
+    Note that by default the module is not imported in :mod:`ot`. In order
+    to use it you need to explicitly import :mod:`ot.dr`
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+from scipy import linalg
+import autograd.numpy as np
+from pymanopt.manifolds import Stiefel
+from pymanopt import Problem
+from pymanopt.solvers import SteepestDescent, TrustRegions
+
+
+def dist(x1, x2):
+    """Compute the matrix of squared Euclidean distances between the rows
+    of x1 and x2 (autograd-compatible).
+    """
+    x1p2 = np.sum(np.square(x1), 1)
+    x2p2 = np.sum(np.square(x2), 1)
+    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
+    return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
+
+
+def sinkhorn(w1, w2, M, reg, k):
+    """Sinkhorn algorithm with a fixed number of iterations (autograd-compatible).
+    """
+    K = np.exp(-M / reg)
+    ui = np.ones((M.shape[0],))
+    vi = np.ones((M.shape[1],))
+    for i in range(k):
+        # alternate scalings so the marginals of G approach w1 and w2
+        vi = w2 / (np.dot(K.T, ui))
+        ui = w1 / (np.dot(K, vi))
+    G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
+    return G
+
+
+def split_classes(X, y):
+    """Split the samples in X by the classes in y.
+    """
+    lstsclass = np.unique(y)
+    return [X[y == i, :].astype(np.float32) for i in lstsclass]
+
+
+def fda(X, y, p=2, reg=1e-16):
+    """Fisher Discriminant Analysis
+
+    Parameters
+    ----------
+    X : ndarray, shape (n, d)
+        Training samples.
+    y : ndarray, shape (n,)
+        Labels for training samples.
+    p : int, optional
+        Size of the dimensionality reduction.
+    reg : float, optional
+        Regularization term >0 (ridge regularization).
+
+    Returns
+    -------
+    P : ndarray, shape (d, p)
+        Optimal projection matrix for the given parameters.
+    proj : callable
+        Projection function including mean centering.
+    """
+
+    mx = np.mean(X, 0)  # per-feature mean for centering
+    X -= mx.reshape((1, -1))
+
+    # data split between classes
+    d = X.shape[1]
+    xc = split_classes(X, y)
+    nc = len(xc)
+
+    p = min(nc - 1, p)  # FDA yields at most nc - 1 discriminant directions
+
+    Cw = 0
+    for x in xc:
+        Cw += np.cov(x, rowvar=False)
+    Cw /= nc
+
+    mxc = np.zeros((d, nc))
+
+    for i in range(nc):
+        mxc[:, i] = np.mean(xc[i], 0)  # per-feature class mean
+
+    mx0 = np.mean(mxc, 1)
+    Cb = 0
+    for i in range(nc):
+        Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \
+            (mxc[:, i] - mx0).reshape((1, -1))
+
+    w, V = linalg.eig(Cb, Cw + reg * np.eye(d))
+
+    idx = np.argsort(w.real)
+
+    Popt = V[:, idx[-p:]]
+
+    def proj(X):
+        return (X - mx.reshape((1, -1))).dot(Popt)
+
+    return Popt, proj
+
+
+def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
+    """
+    Wasserstein Discriminant Analysis [11]_
+
+    The function solves the following optimization problem:
+
+    .. math::
+        P = \\text{arg}\\min_P \\frac{\\sum_i W(PX^i, PX^i)}{\\sum_{i, j \\neq i} W(PX^i, PX^j)}
+
+    where:
+
+    - :math:`P` is a linear projection operator in the Stiefel(d, p) manifold
+    - :math:`W` is the entropic regularized Wasserstein distance
+    - :math:`X^i` are the samples in the dataset corresponding to class i
+
+    Parameters
+    ----------
+    X : ndarray, shape (n, d)
+        Training samples.
+    y : ndarray, shape (n,)
+        Labels for training samples.
+    p : int, optional
+        Size of the dimensionality reduction.
+    reg : float, optional
+        Regularization term >0 (entropic regularization).
+    k : int, optional
+        Number of Sinkhorn iterations used to approximate each Wasserstein
+        distance.
+    solver : None | str, optional
+        None for steepest descent, 'TrustRegions' for the trust regions
+        algorithm; otherwise a pymanopt solver instance.
+    maxiter : int, optional
+        Maximum number of iterations for the manifold solver.
+    verbose : int, optional
+        Print information along iterations.
+    P0 : ndarray, shape (d, p), optional
+        Initial starting point for the projection.
+
+    Returns
+    -------
+    P : ndarray, shape (d, p)
+        Optimal projection matrix for the given parameters.
+    proj : callable
+        Projection function including mean centering.
+
+    References
+    ----------
+    .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+        Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
+    """  # noqa
+
+    mx = np.mean(X, 0)  # per-feature mean for centering
+    X -= mx.reshape((1, -1))
+
+    # data split between classes
+    d = X.shape[1]
+    xc = split_classes(X, y)
+    # compute uniform weights
+    wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
+
+    def cost(P):
+        # wda loss: within-class over between-class transport cost
+        loss_b = 0
+        loss_w = 0
+
+        for i, xi in enumerate(xc):
+            xi = np.dot(xi, P)
+            for j, xj in enumerate(xc[i:]):
+                xj = np.dot(xj, P)
+                M = dist(xi, xj)
+                G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+                if j == 0:
+                    # j == 0 compares class i with itself
+                    loss_w += np.sum(G * M)
+                else:
+                    loss_b += np.sum(G * M)
+
+        # the ratio is within/between because we minimize
+        return loss_w / loss_b
+
+    # declare manifold and problem
+    manifold = Stiefel(d, p)
+    problem = Problem(manifold=manifold, cost=cost)
+
+    # declare solver and solve
+    if solver is None:
+        solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
+    elif solver in ['tr', 'TrustRegions']:
+        solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
+
+    Popt = solver.solve(problem, x=P0)
+
+    def proj(X):
+        return (X - mx.reshape((1, -1))).dot(Popt)
+
+    return Popt, proj
```
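The `dist` helper avoids any explicit pairwise loop by using the expansion of the squared norm, which keeps the whole distance matrix differentiable for autograd. A minimal sketch (not part of the commit, assuming only that the module imports as `ot.dr`) checking it against the direct computation:

```python
# Sanity check for dist: it expands ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>
# to build the full pairwise matrix in one vectorized expression.
import numpy as np
from ot.dr import dist

rng = np.random.RandomState(0)
a = rng.randn(4, 3)
b = rng.randn(6, 3)

# direct computation of all pairwise squared distances via broadcasting
direct = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
print(np.allclose(dist(a, b), direct))  # True
```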
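Similarly, the fixed-iteration `sinkhorn` returns a coupling `G` whose marginals approach the input weights as `k` grows; since the loop ends with a `ui` update, the row marginal is matched exactly. A minimal sketch under the same import assumption:

```python
import numpy as np
from ot.dr import dist, sinkhorn

rng = np.random.RandomState(0)
x1, x2 = rng.randn(5, 3), rng.randn(7, 3)
w1 = np.ones(5) / 5  # uniform weights on each sample set
w2 = np.ones(7) / 7

M = dist(x1, x2)
G = sinkhorn(w1, w2, M, reg=1.0, k=50)

print(np.allclose(G.sum(axis=1), w1))              # row marginal: exact
print(np.allclose(G.sum(axis=0), w2, atol=1e-5))   # column marginal: converges with k
```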
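Finally, an end-to-end sketch of how `fda` and `wda` might be called. This assumes POT 0.6.0 is installed together with its optional `pymanopt` and `autograd` dependencies; the toy data and variable names are illustrative, not from the commit. Note that `fda` caps `p` at `n_classes - 1`, that both functions center `X` in place, and that passing `solver='tr'` (or `'TrustRegions'`) switches `wda` to the trust-regions solver:

```python
import numpy as np
from ot.dr import fda, wda

rng = np.random.RandomState(42)
n, d = 60, 10
# three Gaussian classes separated in the first few coordinates
xs = rng.randn(3 * n, d)
ys = np.repeat(np.arange(3), n)
xs[ys == 1, 0:2] += 4.0
xs[ys == 2, 2:4] += 4.0

Pfda, proj_fda = fda(xs, ys, p=2)      # closed-form generalized eigenproblem
Pwda, proj_wda = wda(xs, ys, p=2, reg=1.0, k=10,
                     maxiter=50)       # manifold optimization on Stiefel(d, p)

print(proj_fda(xs).shape)  # (180, 2)
print(proj_wda(xs).shape)  # (180, 2)
```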