path: root/ot/dr.py
author    Gard Spreemann <gspr@nonempty.org>    2022-04-27 11:49:23 +0200
committer Gard Spreemann <gspr@nonempty.org>    2022-04-27 11:49:23 +0200
commit    35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree      6bc637624004713808d3097b95acdccbb9608e52 /ot/dr.py
parent    c4753bd3f74139af8380127b66b484bc09b50661 (diff)
parent    eccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'ot/dr.py')
-rw-r--r--  ot/dr.py  44
1 file changed, 42 insertions(+), 2 deletions(-)
diff --git a/ot/dr.py b/ot/dr.py
index 1671ca0..0955c55 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -11,6 +11,7 @@ Dimension reduction with OT
 # Author: Remi Flamary <remi.flamary@unice.fr>
 #         Minhui Huang <mhhuang@ucdavis.edu>
+#         Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
 #
 # License: MIT License
@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
     return G
+def logsumexp(M, axis):
+    r"""Log-sum-exp reduction compatible with autograd (numpy has no native implementation)
+    """
+    amax = np.amax(M, axis=axis, keepdims=True)
+    return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
+
+
+def sinkhorn_log(w1, w2, M, reg, k):
+    r"""Sinkhorn algorithm in log-domain with a fixed number of iterations (autograd)
+    """
+    Mr = -M / reg
+    ui = np.zeros((M.shape[0],))
+    vi = np.zeros((M.shape[1],))
+    log_w1 = np.log(w1)
+    log_w2 = np.log(w2)
+    for i in range(k):
+        vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
+        ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
+    G = np.exp(ui[:, None] + Mr + vi[None, :])
+    return G
+
+
 def split_classes(X, y):
     r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
     """
@@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16):
     return Popt, proj
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
r"""
Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     - :math:`W` is the entropic regularized Wasserstein distance
     - :math:`\mathbf{X}^i` are the samples in the dataset corresponding to class i
+    **Choosing a Sinkhorn solver**
+
+    By default, and when using a regularization parameter that is not too
+    small, the default sinkhorn solver should be enough. If you need a small
+    regularization to get sharper, sparser transport plans, you should use the
+    :py:func:`ot.dr.sinkhorn_log` solver, which avoids numerical errors but
+    can be slow in practice.
+
     Parameters
     ----------
     X : ndarray, shape (n, d)
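As a usage sketch of the new keyword (assuming autograd and pymanopt are installed, since ot.dr depends on both; the data here is synthetic), selecting the log-domain solver is just a matter of passing sinkhorn_method:

import numpy as np
from ot.dr import wda

rng = np.random.RandomState(42)
n = 100
X = np.vstack([rng.randn(n, 10) + 2, rng.randn(n, 10) - 2])  # two synthetic classes
y = np.hstack([np.zeros(n), np.ones(n)])

# a small reg benefits from the log-domain solver's stability
Popt, proj = wda(X, y, p=2, reg=1e-2, k=10, sinkhorn_method='sinkhorn_log')
Xp = proj(X)                 # project samples onto the learned p-dimensional subspace
print(Popt.shape, Xp.shape)  # (10, 2) (200, 2)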
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
     solver : None | str, optional
         None for steepest descent, 'TrustRegions' for the trust-regions
         algorithm, or a pymanopt solver object
+    sinkhorn_method : str
+        method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
     P0 : ndarray, shape (d, p)
         Initial starting point for projection.
     normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
         Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
     """  # noqa
+    if sinkhorn_method.lower() == 'sinkhorn':
+        sinkhorn_solver = sinkhorn
+    elif sinkhorn_method.lower() == 'sinkhorn_log':
+        sinkhorn_solver = sinkhorn_log
+    else:
+        raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)
+
     mx = np.mean(X)
     X -= mx.reshape((1, -1))
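Because both solvers share the signature (w1, w2, M, reg, k), they are interchangeable at the call site below, and for a moderate regularization they should agree to floating-point accuracy. A quick check with synthetic inputs (again assuming autograd and pymanopt are installed, as importing ot.dr requires them):

import numpy as np
from ot.dr import sinkhorn, sinkhorn_log

rng = np.random.RandomState(0)
M = rng.rand(8, 8)
w = np.full(8, 1 / 8)

G_std = sinkhorn(w, w, M, 1e-1, 50)      # exp-domain updates
G_log = sinkhorn_log(w, w, M, 1e-1, 50)  # log-domain updates
print(np.abs(G_std - G_log).max())       # tiny for moderate reg

# An unrecognized method name fails fast with the ValueError above:
# wda(X, y, sinkhorn_method='bogus')  ->  ValueError: Unknown Sinkhorn method 'bogus'.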
@@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
             for j, xj in enumerate(xc[i:]):
                 xj = np.dot(xj, P)
                 M = dist(xi, xj)
-                G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
+                G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
                 if j == 0:
                     loss_w += np.sum(G * M)
                 else:
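For intuition, this loop accumulates entropic Wasserstein costs between every ordered pair of classes after projection: the j == 0 term compares a class with itself (within-class spread, loss_w) and the remaining terms compare distinct classes (between-class separation, loss_b). A standalone sketch of the same accumulation, with made-up inputs, a fixed reg in place of reg * regmean[i, j], and a plain squared-Euclidean helper standing in for the library's dist:

import numpy as np
from ot.dr import sinkhorn

def pairwise_sqeuclidean(a, b):
    # squared Euclidean distances, standing in for ot.dist
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)

rng = np.random.RandomState(1)
xc = [rng.randn(30, 2), rng.randn(25, 2) + 3]    # projected samples per class
wc = [np.full(30, 1 / 30), np.full(25, 1 / 25)]  # uniform weights per class

loss_w, loss_b = 0.0, 0.0
for i, xi in enumerate(xc):
    for j, xj in enumerate(xc[i:]):
        M = pairwise_sqeuclidean(xi, xj)
        G = sinkhorn(wc[i], wc[j + i], M, 1e-1, 10)
        if j == 0:
            loss_w += np.sum(G * M)  # class compared with itself
        else:
            loss_b += np.sum(G * M)  # distinct classes
print(loss_w, loss_b)  # WDA favors small within-class cost relative to between-class cost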