From d7c709e2bae3bafec9efad87e758919c8db61933 Mon Sep 17 00:00:00 2001
From: Jakub Zadrożny
Date: Fri, 21 Jan 2022 08:50:19 +0100
Subject: [MRG] Implement Sinkhorn in log-domain for WDA (#336)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [MRG] Implement Sinkhorn in log-domain for WDA

* for small values of the regularization parameter (reg), the current
  implementation runs into numerical issues (nans and infs)

* this can be resolved by using a log-domain implementation of the
  Sinkhorn algorithm

* Add feature to RELEASES and contributor name

* Add 'sinkhorn_method' parameter to WDA

* use the standard Sinkhorn solver by default (faster)

* use the log-domain Sinkhorn if asked by the user

Co-authored-by: Jakub Zadrożny
Co-authored-by: Rémi Flamary
---
 ot/dr.py | 44 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/ot/dr.py b/ot/dr.py
index 1671ca0..0955c55 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -11,6 +11,7 @@ Dimension reduction with OT
 
 # Author: Remi Flamary
 #         Minhui Huang
+#         Jakub Zadrozny
 #
 # License: MIT License
@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
 
     return G
 
 
+def logsumexp(M, axis):
+    r"""Log-sum-exp reduction compatible with autograd (numpy provides no autograd-compatible implementation)
+    """
+    amax = np.amax(M, axis=axis, keepdims=True)
+    return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
+
+
+def sinkhorn_log(w1, w2, M, reg, k):
+    r"""Sinkhorn algorithm in log-domain with a fixed number of iterations (autograd)
+    """
+    Mr = -M / reg
+    ui = np.zeros((M.shape[0],))
+    vi = np.zeros((M.shape[1],))
+    log_w1 = np.log(w1)
+    log_w2 = np.log(w2)
+    for i in range(k):
+        vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
+        ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
+    G = np.exp(ui[:, None] + Mr + vi[None, :])
+    return G
+
+
 def split_classes(X, y):
     r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
     """
@@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16):
 
     return Popt, proj
 
 
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
     r"""
     Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     - :math:`W` is entropic regularized Wasserstein distances
     - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
 
+    **Choosing a Sinkhorn solver**
+
+    By default, and when the regularization parameter is not too small, the
+    standard Sinkhorn solver should be enough. If you need a small
+    regularization to get a sparser transport plan, use the
+    :py:func:`ot.dr.sinkhorn_log` solver, which avoids the numerical
+    errors but can be slow in practice.
+
     Parameters
     ----------
     X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     solver : None | str, optional
         None for steepest descent or 'TrustRegions' for trust regions algorithm
         else should be a pymanopt.solvers
+    sinkhorn_method : str
+        method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
     P0 : ndarray, shape (d, p)
         Initial starting point for projection.
     normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
         Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
""" # noqa + if sinkhorn_method.lower() == 'sinkhorn': + sinkhorn_solver = sinkhorn + elif sinkhorn_method.lower() == 'sinkhorn_log': + sinkhorn_solver = sinkhorn_log + else: + raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method) + mx = np.mean(X) X -= mx.reshape((1, -1)) @@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k) + G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k) if j == 0: loss_w += np.sum(G * M) else: -- cgit v1.2.3