author     Jakub Zadrożny <jakub.r.zadrozny@gmail.com>  2022-01-21 08:50:19 +0100
committer  GitHub <noreply@github.com>  2022-01-21 08:50:19 +0100
commit     d7c709e2bae3bafec9efad87e758919c8db61933 (patch)
tree       867cc38b6f782207b5045751be2c0f0e2824b6af
parent     263c5842664c1dff4f8e58111d6bddb33927539e (diff)
[MRG] Implement Sinkhorn in log-domain for WDA (#336)
* [MRG] Implement Sinkhorn in log-domain for WDA

  * for small values of the regularization parameter (reg) the current
    implementation runs into numerical issues (nans and infs)
  * this can be resolved by using a log-domain implementation of the
    Sinkhorn algorithm

* Add feature to RELEASES and contributor name

* Add 'sinkhorn_method' parameter to WDA

  * use the standard Sinkhorn solver by default (faster)
  * use log-domain Sinkhorn if asked by the user

Co-authored-by: Jakub Zadrożny <jz@qed.ai>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
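Editor's note (not part of the patch): the failure mode described above is easy to
reproduce with the two module-level solvers this diff touches. A minimal sketch,
assuming autograd and pymanopt are installed (both are required to import ot.dr);
the exact reg threshold depends on the scale of the cost matrix:

    import numpy as np
    from ot.dr import sinkhorn, sinkhorn_log  # sinkhorn_log is added by this patch

    rng = np.random.RandomState(0)
    n = 50
    M = rng.rand(n, n)   # toy cost matrix
    w = np.ones(n) / n   # uniform marginals

    # For very small reg, K = exp(-M / reg) underflows to all zeros, so the
    # standard scaling iterations divide by zero and propagate nans/infs.
    G_std = sinkhorn(w, w, M, reg=1e-7, k=10)

    # The log-domain solver keeps everything as log-potentials and reduces
    # with log-sum-exp, so it stays finite on the same inputs.
    G_log = sinkhorn_log(w, w, M, reg=1e-7, k=10)

    print(np.isnan(G_std).any(), np.isnan(G_log).any())  # expected: True False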
-rw-r--r--  RELEASES.md      |  2
-rw-r--r--  ot/dr.py         | 44
-rw-r--r--  test/test_dr.py  | 22
3 files changed, 66 insertions(+), 2 deletions(-)
diff --git a/RELEASES.md b/RELEASES.md
index c6ab9c3..a5fcbe1 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,8 @@
#### New features
- Better list of related examples in quick start guide with `minigallery` (PR #334)
+- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
+ of the regularization parameter (PR #336)
#### Closed issues
diff --git a/ot/dr.py b/ot/dr.py
index 1671ca0..0955c55 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -11,6 +11,7 @@ Dimension reduction with OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Minhui Huang <mhhuang@ucdavis.edu>
+# Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
#
# License: MIT License
@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
    return G
+def logsumexp(M, axis):
+    r"""Log-sum-exp reduction compatible with autograd (there is no autograd-compatible NumPy implementation)
+    """
+    amax = np.amax(M, axis=axis, keepdims=True)
+    return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
+
+
+def sinkhorn_log(w1, w2, M, reg, k):
+    r"""Sinkhorn algorithm in log-domain with fixed number of iterations (autograd)
+    """
+    Mr = -M / reg
+    ui = np.zeros((M.shape[0],))
+    vi = np.zeros((M.shape[1],))
+    log_w1 = np.log(w1)
+    log_w2 = np.log(w2)
+    for i in range(k):
+        vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
+        ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
+    G = np.exp(ui[:, None] + Mr + vi[None, :])
+    return G
+
+
def split_classes(X, y):
r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
"""
@@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16):
    return Popt, proj
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
r"""
Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
    - :math:`W` is the entropic regularized Wasserstein distance
    - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
+    **Choosing a Sinkhorn solver**
+
+    By default, and when using a regularization parameter that is not too
+    small, the default Sinkhorn solver should be enough. If you need a
+    smaller regularization (which yields sparser transport plans), use the
+    :py:func:`ot.dr.sinkhorn_log` solver, which avoids numerical errors
+    but can be slower in practice.
+
    Parameters
    ----------
    X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
    solver : None | str, optional
        None for steepest descent or 'TrustRegions' for trust regions algorithm
        else should be a pymanopt.solvers
+    sinkhorn_method : str
+        method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
    P0 : ndarray, shape (d, p)
        Initial starting point for projection.
    normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
        Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
    """  # noqa
+    if sinkhorn_method.lower() == 'sinkhorn':
+        sinkhorn_solver = sinkhorn
+    elif sinkhorn_method.lower() == 'sinkhorn_log':
+        sinkhorn_solver = sinkhorn_log
+    else:
+        raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)
+
    mx = np.mean(X)
    X -= mx.reshape((1, -1))
@@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
            for j, xj in enumerate(xc[i:]):
                xj = np.dot(xj, P)
                M = dist(xi, xj)
-                G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
+                G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
                if j == 0:
                    loss_w += np.sum(G * M)
                else:
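Editor's note: from the user's side the change is a single keyword argument. A
usage sketch mirroring test_wda_low_reg below (dataset helper and values taken
from that test); any other value of sinkhorn_method raises the ValueError added
above:

    import numpy as np
    import ot

    np.random.seed(0)
    xs, ys = ot.datasets.make_data_classif('gaussrot', 100)

    # With reg this small the default solver may produce nans; the log-domain
    # solver trades some speed for numerical stability.
    Pwda, projwda = ot.dr.wda(xs, ys, p=2, reg=0.01, maxiter=10,
                              sinkhorn_method='sinkhorn_log')

    xs_proj = projwda(xs)  # samples projected on the learned p-dim subspace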
diff --git a/test/test_dr.py b/test/test_dr.py
index 741f2ad..6d7fc9a 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -61,6 +61,28 @@ def test_wda():
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_low_reg():
+
+    n_samples = 100  # nb samples in source and target datasets
+    np.random.seed(0)
+
+    # generate gaussian dataset
+    xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+    n_features_noise = 8
+
+    xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+    p = 2
+
+    Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log')
+
+    projwda(xs)
+
+    np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_normalized():
    n_samples = 100  # nb samples in source and target datasets