diff options
author | Jakub Zadrożny <jakub.r.zadrozny@gmail.com> | 2022-01-21 08:50:19 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-21 08:50:19 +0100 |
commit | d7c709e2bae3bafec9efad87e758919c8db61933 (patch) | |
tree | 867cc38b6f782207b5045751be2c0f0e2824b6af | |
parent | 263c5842664c1dff4f8e58111d6bddb33927539e (diff) |
[MRG] Implement Sinkhorn in log-domain for WDA (#336)
* [MRG] Implement Sinkhorn in log-domain for WDA
* For small values of the regularization parameter (reg), the current implementation runs into numerical issues (NaNs and Infs).
* This can be resolved by using a log-domain implementation of the Sinkhorn algorithm.
* Add feature to RELEASES and contributor name
* Add a 'sinkhorn_method' parameter to WDA
* use the standard Sinkhorn solver by default (faster)
* use the log-domain Sinkhorn solver if requested by the user
Co-authored-by: Jakub Zadrożny <jz@qed.ai>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | RELEASES.md | 2 | ||||
-rw-r--r-- | ot/dr.py | 44 | ||||
-rw-r--r-- | test/test_dr.py | 22 |
3 files changed, 66 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md index c6ab9c3..a5fcbe1 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,8 @@ #### New features - Better list of related examples in quick start guide with `minigallery` (PR #334) +- Add optional log-domain Sinkhorn implementation in WDA to support smaller values + of the regularization parameter (PR #336) #### Closed issues @@ -11,6 +11,7 @@ Dimension reduction with OT # Author: Remi Flamary <remi.flamary@unice.fr> # Minhui Huang <mhhuang@ucdavis.edu> +# Jakub Zadrozny <jakub.r.zadrozny@gmail.com> # # License: MIT License @@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k): return G +def logsumexp(M, axis): + r"""Log-sum-exp reduction compatible with autograd (no numpy implementation) + """ + amax = np.amax(M, axis=axis, keepdims=True) + return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis) + + +def sinkhorn_log(w1, w2, M, reg, k): + r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd) + """ + Mr = -M / reg + ui = np.zeros((M.shape[0],)) + vi = np.zeros((M.shape[1],)) + log_w1 = np.log(w1) + log_w2 = np.log(w2) + for i in range(k): + vi = log_w2 - logsumexp(Mr + ui[:, None], 0) + ui = log_w1 - logsumexp(Mr + vi[None, :], 1) + G = np.exp(ui[:, None] + Mr + vi[None, :]) + return G + + def split_classes(X, y): r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}` """ @@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): +def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False): r""" Wasserstein Discriminant Analysis :ref:`[11] <references-wda>` @@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no - :math:`W` is entropic regularized Wasserstein distances - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i + 
**Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sparse cost matrices, you should use the + :py:func:`ot.dr.sinkhorn_log` solver that will avoid numerical + errors, but can be slow in practice. + Parameters ---------- X : ndarray, shape (n, d) @@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no solver : None | str, optional None for steepest descent or 'TrustRegions' for trust regions algorithm else should be a pymanopt.solvers + sinkhorn_method : str + method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log' P0 : ndarray, shape (d, p) Initial starting point for projection. normalize : bool, optional @@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. """ # noqa + if sinkhorn_method.lower() == 'sinkhorn': + sinkhorn_solver = sinkhorn + elif sinkhorn_method.lower() == 'sinkhorn_log': + sinkhorn_solver = sinkhorn_log + else: + raise ValueError("Unknown Sinkhorn method '%s'." 
% sinkhorn_method) + mx = np.mean(X) X -= mx.reshape((1, -1)) @@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k) + G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k) if j == 0: loss_w += np.sum(G * M) else: diff --git a/test/test_dr.py b/test/test_dr.py index 741f2ad..6d7fc9a 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -61,6 +61,28 @@ def test_wda(): @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") +def test_wda_low_reg(): + + n_samples = 100 # nb samples in source and target datasets + np.random.seed(0) + + # generate gaussian dataset + xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples) + + n_features_noise = 8 + + xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise))) + + p = 2 + + Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log') + + projwda(xs) + + np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) + + +@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_wda_normalized(): n_samples = 100 # nb samples in source and target datasets |