From d7c709e2bae3bafec9efad87e758919c8db61933 Mon Sep 17 00:00:00 2001
From: Jakub Zadrożny
Date: Fri, 21 Jan 2022 08:50:19 +0100
Subject: [MRG] Implement Sinkhorn in log-domain for WDA (#336)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [MRG] Implement Sinkhorn in log-domain for WDA

* For small values of the regularization parameter (reg), the current
  implementation runs into numerical issues (nans and infs).
* This can be resolved by using a log-domain implementation of the
  Sinkhorn algorithm.

* Add feature to RELEASES and contributor name

* Add 'sinkhorn_method' parameter to WDA

* Use the standard Sinkhorn solver by default (faster).
* Use log-domain Sinkhorn if asked by the user.

Co-authored-by: Jakub Zadrożny
Co-authored-by: Rémi Flamary
---
 RELEASES.md     |  2 ++
 ot/dr.py        | 44 ++++++++++++++++++++++++++++++++++++++++++--
 test/test_dr.py | 22 ++++++++++++++++++++++
 3 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/RELEASES.md b/RELEASES.md
index c6ab9c3..a5fcbe1 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,8 @@
 #### New features
 
 - Better list of related examples in quick start guide with `minigallery` (PR #334)
+- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
+  of the regularization parameter (PR #336)
 
 #### Closed issues
 
diff --git a/ot/dr.py b/ot/dr.py
index 1671ca0..0955c55 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -11,6 +11,7 @@ Dimension reduction with OT
 
 # Author: Remi Flamary
 #         Minhui Huang
+#         Jakub Zadrozny
 #
 # License: MIT License
 
@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
     return G
 
 
+def logsumexp(M, axis):
+    r"""Log-sum-exp reduction compatible with autograd (no autograd-aware numpy implementation)
+    """
+    amax = np.amax(M, axis=axis, keepdims=True)  # shift by the max for numerical stability
+    return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
+
+
+def sinkhorn_log(w1, w2, M, reg, k):
+    r"""Sinkhorn algorithm in log-domain with a fixed number of iterations (autograd)
+    """
+    Mr = -M / reg  # log of the Gibbs kernel exp(-M / reg)
+    ui = np.zeros((M.shape[0],))
+    vi = np.zeros((M.shape[1],))
+    log_w1 = np.log(w1)
+    log_w2 = np.log(w2)
+    for i in range(k):  # alternate updates of the log-domain potentials
+        vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
+        ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
+    G = np.exp(ui[:, None] + Mr + vi[None, :])  # recover the transport plan
+    return G
+
+
 def split_classes(X, y):
     r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
     """
@@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16):
     return Popt, proj
 
 
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
     r"""
     Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
 
@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
     - :math:`W` is entropic regularized Wasserstein distances
     - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
 
+    **Choosing a Sinkhorn solver**
+
+    When the regularization parameter is not too small, the default
+    Sinkhorn solver should be sufficient. If you need a small
+    regularization to obtain sparser transport plans, use the
+    :py:func:`ot.dr.sinkhorn_log` solver instead, which avoids
+    numerical errors but can be slower in practice.
+
     Parameters
     ----------
     X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
     solver : None | str, optional
         None for steepest descent or 'TrustRegions' for trust regions algorithm
         else should be a pymanopt.solvers
+    sinkhorn_method : str
+        Method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
     P0 : ndarray, shape (d, p)
         Initial starting point for projection.
     normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
         Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
     """  # noqa
 
+    if sinkhorn_method.lower() == 'sinkhorn':
+        sinkhorn_solver = sinkhorn
+    elif sinkhorn_method.lower() == 'sinkhorn_log':
+        sinkhorn_solver = sinkhorn_log
+    else:
+        raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)
+
     mx = np.mean(X)
     X -= mx.reshape((1, -1))
 
@@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
             for j, xj in enumerate(xc[i:]):
                 xj = np.dot(xj, P)
                 M = dist(xi, xj)
-                G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
+                G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
                 if j == 0:
                     loss_w += np.sum(G * M)
                 else:
diff --git a/test/test_dr.py b/test/test_dr.py
index 741f2ad..6d7fc9a 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -60,6 +60,28 @@ def test_wda():
     np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
 
 
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_low_reg():
+
+    n_samples = 100  # nb samples in source and target datasets
+    np.random.seed(0)
+
+    # generate Gaussian dataset
+    xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+    n_features_noise = 8
+
+    xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+    p = 2
+
+    Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log')
+
+    projwda(xs)
+
+    np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
 @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
 def test_wda_normalized():
-- 
cgit v1.2.3
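
A minimal usage sketch of the new option, mirroring test_wda_low_reg above; it
assumes POT with this patch applied plus the optional autograd and pymanopt
dependencies that ot.dr requires:

    import numpy as np
    import ot

    # Why the log-domain solver helps: the standard solver exponentiates
    # -M / reg directly, so a small reg underflows the kernel to 0 and the
    # scaling updates divide by zero (infs, then nans).
    M = np.array([[0.0, 1.0], [1.0, 0.0]])
    print(np.exp(-M / 0.001))  # off-diagonal entries underflow to 0.0

    # sinkhorn_log runs the same fixed-point iteration on the log-domain
    # potentials via log-sum-exp, so nothing is exponentiated until the
    # final coupling is formed.
    np.random.seed(0)
    xs, ys = ot.datasets.make_data_classif('gaussrot', 100)
    Pwda, projwda = ot.dr.wda(xs, ys, p=2, reg=0.01, k=10, maxiter=10,
                              sinkhorn_method='sinkhorn_log')
    xs_proj = projwda(xs)  # samples projected onto the learned 2-d subspace

Since sinkhorn_method defaults to 'sinkhorn', existing callers keep the
previous (faster) behavior unchanged.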