Diffstat (limited to 'ot/dr.py')
-rw-r--r--  ot/dr.py  156
1 file changed, 145 insertions, 11 deletions
diff --git a/ot/dr.py b/ot/dr.py
index 11d2e10..c2f51f8 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -10,6 +10,7 @@ Dimension reduction with OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -21,7 +22,7 @@ from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1, x2):
- """ Compute squared euclidean distance between samples (autograd)
+ r""" Compute squared euclidean distance between samples (autograd)
"""
x1p2 = np.sum(np.square(x1), 1)
x2p2 = np.sum(np.square(x2), 1)
@@ -29,7 +30,7 @@ def dist(x1, x2):
def sinkhorn(w1, w2, M, reg, k):
- """Sinkhorn algorithm with fixed number of iteration (autograd)
+ r"""Sinkhorn algorithm with fixed number of iteration (autograd)
"""
K = np.exp(-M / reg)
ui = np.ones((M.shape[0],))
@@ -42,14 +43,14 @@ def sinkhorn(w1, w2, M, reg, k):
def split_classes(X, y):
- """split samples in X by classes in y
+ r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
"""
lstsclass = np.unique(y)
return [X[y == i, :].astype(np.float32) for i in lstsclass]
def fda(X, y, p=2, reg=1e-16):
- """Fisher Discriminant Analysis
+ r"""Fisher Discriminant Analysis
Parameters
----------
@@ -108,20 +109,21 @@ def fda(X, y, p=2, reg=1e-16):
return Popt, proj
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
- """
- Wasserstein Discriminant Analysis [11]_
+def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+ r"""
+ Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
The function solves the following optimization problem:
.. math::
- P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
+ \mathbf{P} = \mathop{\arg \min}_\mathbf{P} \quad
+ \frac{\sum\limits_i W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j \neq i} W(P \mathbf{X}^i, P \mathbf{X}^j)}
where :
- - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
+ - :math:`P` is a linear projection operator in the Stiefel(`p`, `d`) manifold
- :math:`W` is entropic regularized Wasserstein distances
- - :math:`X^i` are samples in the dataset corresponding to class i
+ - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
Parameters
----------
@@ -138,6 +140,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
else should be a pymanopt.solvers
P0 : ndarray, shape (d, p)
Initial starting point for projection.
+ normalize : bool, optional
+ Normalize the Wasserstein distance by the average distance on P0 (default: False)
verbose : int, optional
Print information along iterations.
@@ -148,6 +152,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
proj : callable
Projection function including mean centering.
+
+ .. _references-wda:
References
----------
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
@@ -163,6 +169,18 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
# compute uniform weights
wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
+ # pre-compute reg_c,c'
+ if P0 is not None and normalize:
+ regmean = np.zeros((len(xc), len(xc)))
+ for i, xi in enumerate(xc):
+ xi = np.dot(xi, P0)
+ for j, xj in enumerate(xc[i:]):
+ xj = np.dot(xj, P0)
+ M = dist(xi, xj)
+ regmean[i, j] = np.sum(M) / (len(xi) * len(xj))
+ else:
+ regmean = np.ones((len(xc), len(xc)))
+
def cost(P):
# wda loss
loss_b = 0
@@ -173,7 +191,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
for j, xj in enumerate(xc[i:]):
xj = np.dot(xj, P)
M = dist(xi, xj)
- G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+ G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
if j == 0:
loss_w += np.sum(G * M)
else:
@@ -198,3 +216,119 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
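
As a reading aid (not part of this patch), here is a minimal usage sketch of wda with the new normalize option; the toy data, dimensions and parameter values are illustrative assumptions, and pymanopt must be installed as before.

    import numpy as np
    from ot.dr import wda

    # illustrative toy data: two Gaussian classes in d=10 (assumed, not from the patch)
    rng = np.random.RandomState(0)
    n = 100
    X = np.vstack([rng.randn(n, 10), rng.randn(n, 10) + 2])
    y = np.hstack([np.zeros(n), np.ones(n)])

    # initial projection P0 on the Stiefel manifold; with normalize=True the entropic
    # regularization reg is rescaled per class pair by the average squared distance
    # computed on P0 (the regmean factors above)
    P0, _ = np.linalg.qr(rng.randn(10, 2))
    Popt, proj = wda(X, y, p=2, reg=1, k=10, maxiter=100, P0=P0, normalize=True)
    Xp = proj(X)  # mean-centered samples projected to the p=2 subspace
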
+
+
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+ r"""
+ Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \max_{U \in St(d, k)} \ \min_{\pi \in \Pi(\mu,\nu)} \quad \sum_{i,j} \pi_{i,j}
+ \|U^T(\mathbf{x}_i - \mathbf{y}_j)\|^2 - \mathrm{reg} \cdot H(\pi)
+
+ - :math:`U` is a linear projection operator in the Stiefel(`d`, `k`) manifold
+ - :math:`H(\pi)` is the entropic regularizer
+ - :math:`\mathbf{x}_i`, :math:`\mathbf{y}_j` are samples of measures :math:`\mu` and :math:`\nu` respectively
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, d)
+ Samples from measure :math:`\mu`
+ Y : ndarray, shape (m, d)
+ Samples from measure :math:`\nu`
+ a : ndarray, shape (n, )
+ weights for measure :math:`\mu`
+ b : ndarray, shape (m, )
+ weights for measure :math:`\nu`
+ tau : float
+ Stepsize for the Riemannian gradient descent
+ U0 : ndarray, shape (d, k)
+ Initial starting point for projection.
+ reg : float, optional
+ Regularization term >0 (entropic regularization)
+ k : int
+ Subspace dimension
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : int, optional
+ Print information along iterations.
+
+ Returns
+ -------
+ pi : ndarray, shape (n, m)
+ Optimal transportation matrix for the given parameters
+ U : ndarray, shape (d, k)
+ Projection operator.
+
+
+ .. _references-projection-robust-wasserstein:
+ References
+ ----------
+ .. [32] Huang, M., Ma, S. & Lai, L. (2021).
+ A Riemannian Block Coordinate Descent Method for Computing
+ the Projection Robust Wasserstein Distance, ICML.
+ """ # noqa
+
+ # initialization
+ n, d = X.shape
+ m, d = Y.shape
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ u = np.ones(n) / n
+ v = np.ones(m) / m
+ ones = np.ones((n, m))
+
+ assert d > k
+
+ if U0 is None:
+ U = np.random.randn(d, k)
+ U, _ = np.linalg.qr(U)
+ else:
+ U = U0
+
+ def Vpi(X, Y, a, b, pi):
+ # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }.
+ A = X.T.dot(pi).dot(Y)
+ return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T
+
+ err = 1
+ iter = 0
+
+ while err > stopThr and iter < maxiter:
+
+ # Projected cost matrix
+ UUT = U.dot(U.T)
+ M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot(
+ np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T))
+
+ A = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=A)
+ np.exp(A, out=A)
+
+ # Sinkhorn update
+ Ap = (1 / a).reshape(-1, 1) * A
+ AtransposeU = np.dot(A.T, u)
+ v = np.divide(b, AtransposeU)
+ u = 1. / np.dot(Ap, v)
+ pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))
+
+ V = Vpi(X, Y, a, b, pi)
+
+ # Riemannian gradient descent
+ G = 2 / reg * V.dot(U)
+ GTU = G.T.dot(U)
+ xi = G - U.dot(GTU + GTU.T) / 2 # Riemannian gradient
+ U, _ = np.linalg.qr(U + tau * xi) # Retraction by QR decomposition
+
+ grad_norm = np.linalg.norm(xi)
+ err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))
+
+ f_val = np.trace(U.T.dot(V.dot(U)))
+ if verbose:
+ print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val)
+
+ iter = iter + 1
+
+ return pi, U
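
Similarly, an illustrative sketch (not from the patch itself) of calling projection_robust_wasserstein; the toy point clouds and the stepsize/regularization values below are assumptions chosen only for demonstration.

    import numpy as np
    from ot.dr import projection_robust_wasserstein

    # illustrative toy data: two point clouds of n=50 samples in d=5 (assumed)
    rng = np.random.RandomState(0)
    n, d = 50, 5
    X = rng.randn(n, d)
    Y = rng.randn(n, d) + 1
    a = np.ones(n) / n  # uniform weights on X
    b = np.ones(n) / n  # uniform weights on Y

    # k=2 dimensional subspace, entropic regularization reg=0.1,
    # Riemannian stepsize tau=2e-3 (values chosen for illustration only)
    pi, U = projection_robust_wasserstein(X, Y, a, b, tau=2e-3, reg=0.1, k=2,
                                          stopThr=1e-3, maxiter=100, verbose=0)

    # evaluate the projected transport cost sum_ij pi_ij ||U^T (x_i - y_j)||^2
    dU = X.dot(U)[:, None, :] - Y.dot(U)[None, :, :]
    prw_value = np.sum(pi * np.sum(dU ** 2, axis=-1))
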