author     Hicham Janati <hicham.janati@inria.fr>    2019-07-22 14:53:45 +0200
committer  Hicham Janati <hicham.janati@inria.fr>    2019-07-22 14:53:45 +0200
commit     10accb13c2f22c946b65b249d7aae6e4f6af7579 (patch)
tree       a8b9411120499eb3ce38f88b72a0f1fc8ee1f4b0 /ot
parent     0d23718409b1f0ac41b9302d98ca3d1ab9577855 (diff)
add unbalanced with stabilization
Diffstat (limited to 'ot')
-rw-r--r--  ot/unbalanced.py  279
1 file changed, 245 insertions, 34 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index f6c2d5f..ca24e8b 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -9,10 +9,12 @@ Regularized Unbalanced OT
from __future__ import division
import warnings
import numpy as np
+from scipy.special import logsumexp
+
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
+def sinkhorn_unbalanced(a, b, M, reg, mu, method='sinkhorn', numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
r"""
Solve the unbalanced entropic regularization optimal transport problem and return the loss
@@ -20,7 +22,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
@@ -45,11 +47,11 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ mu : float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -95,22 +97,29 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
--------
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling_unbalanced: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, mu,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ elif method.lower() == 'sinkhorn_stabilized':
+ def sink():
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, mu,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
@@ -120,7 +129,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
return sink()
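For reference, a minimal usage sketch of the dispatcher above, exercising the new 'sinkhorn_stabilized' branch; the inputs are illustrative, and the call assumes the module is importable as ot.unbalanced, as in the doctest further down:

>>> import numpy as np
>>> import ot
>>> a = np.array([.5, .5])                # source weights
>>> b = np.array([.5, .5])                # target weights
>>> M = np.array([[0., 1.], [1., 0.]])    # cost matrix
>>> G = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=1., mu=1.,
...                                       method='sinkhorn_stabilized')
>>> G.shape
(2, 2)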
-def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
+def sinkhorn_unbalanced2(a, b, M, reg, mu, method='sinkhorn',
numItermax=1000, stopThr=1e-9, verbose=False,
log=False, **kwargs):
r"""
@@ -129,7 +138,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
@@ -154,11 +163,11 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ mu : float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -203,22 +212,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
--------
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, mu,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ elif method.lower() == 'sinkhorn_stabilized':
+ def sink():
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, mu,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, mu,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
@@ -232,7 +248,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
return sink()
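The *2 variant above returns the transport cost rather than the plan. A hedged sketch with illustrative inputs, passing several target histograms at once as the columns of b (as the docstring allows), so one loss per column comes back:

>>> import numpy as np
>>> import ot
>>> a = np.array([.5, .5])
>>> b = np.array([[.5, .4], [.5, .6]])    # two target histograms as columns
>>> M = np.array([[0., 1.], [1., 0.]])
>>> losses = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=1., mu=1.)
>>> losses.shape
(2,)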
-def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
+def sinkhorn_knopp_unbalanced(a, b, M, reg, mu, numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -240,7 +256,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
@@ -265,7 +281,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ mu : float
Marginal relaxation term > 0
numItermax : int, optional
Max number of iterations
@@ -338,14 +354,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
u = np.ones(n_a) / n_a
v = np.ones(n_b) / n_b
- # print(reg)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- # print(np.min(K))
- fi = alpha / (alpha + reg)
+ fi = mu / (mu + reg)
cpt = 0
err = 1.
@@ -371,8 +385,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
- err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
@@ -383,8 +397,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
print('{:5d}|{:8e}|'.format(cpt, err))
cpt = cpt + 1
if log:
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -401,7 +415,204 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :]
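The exponent fi = mu / (mu + reg) is what turns the balanced Sinkhorn update into the unbalanced one: each scaling step is damped by fi, and fi -> 1 (mu -> infinity) recovers the classic balanced iteration. A condensed, self-contained sketch of the same loop (illustrative data, no stopping test or log handling):

import numpy as np

a = np.array([.5, .5])
b = np.array([.5, .5])
M = np.array([[0., 1.], [1., 0.]])
reg, mu = 1., 1.

K = np.exp(-M / reg)
fi = mu / (mu + reg)               # damping exponent of the unbalanced update
u = np.ones(2) / 2
v = np.ones(2) / 2
for _ in range(1000):
    u = (a / (K.dot(v) + 1e-16)) ** fi
    v = (b / (K.T.dot(u) + 1e-16)) ** fi
G = u[:, None] * K * v[None, :]    # transport plan, as returned above
# for these inputs G is approximately [[0.511, 0.188], [0.188, 0.511]],
# matching the doctest of the stabilized solver below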
-def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, mu, tau=1e5, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False,
+ **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport problem and return the loss
+
+ The function solves the following optimization problem using log-domain
+ stabilization as proposed in [10]:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\mu KL(\gamma 1, a) + \\mu KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (ns, nt) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target weights
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 25]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed M if b is a matrix (return OT loss + dual variables in log)
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ mu : float
+ Marginal relaxation term > 0
+ tau : float
+ threshold for max value in u or v for log scaling
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ n_a, n_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(n_a, dtype=np.float64) / n_a
+ if len(b) == 0:
+ b = np.ones(n_b, dtype=np.float64) / n_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((n_a, n_hists)) / n_a
+ v = np.ones((n_b, n_hists)) / n_b
+ a = a.reshape(n_a, 1)
+ else:
+ u = np.ones(n_a) / n_a
+ v = np.ones(n_b) / n_b
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = mu / (mu + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(n_a)
+ beta = np.zeros(n_b)
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + mu))
+ f_beta = np.exp(- beta / (reg + mu))
+
+ if n_hists:
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+ if (u > tau).any() or (v > tau).any():
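+ # absorption step: fold the large scaling factors into the duals alpha/beta,
+ # then recompute K with the updated duals and reset v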
+ if n_hists:
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ else:
+ alpha = alpha + reg * np.log(np.max(u))
+ beta = beta + reg * np.log(np.max(v))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %d' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+ 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if n_hists:
+ logu = alpha[:, None] / reg + np.log(u)
+ logv = beta[:, None] / reg + np.log(v)
+ else:
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ if log:
+ log['logu'] = logu
+ log['logv'] = logv
+ if n_hists: # return only loss
+ res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+ logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+ res = np.exp(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ if log:
+ return ot_matrix, log
+ else:
+ return ot_matrix
+
+
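A short usage sketch for the new stabilized solver (illustrative inputs): with a small reg the scaling factors can grow large, which is what the alpha/beta absorption above guards against; here the stabilized call returns a finite plan:

>>> import numpy as np
>>> import ot
>>> a = np.array([.5, .5])
>>> b = np.array([.5, .5])
>>> M = np.array([[0., 1.], [1., 0.]])
>>> G = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg=0.05, mu=1.)
>>> bool(np.isfinite(G).all())
True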
+def barycenter_unbalanced(A, M, reg, mu, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False):
r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
@@ -415,7 +626,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
- :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- - alpha is the marginal relaxation hyperparameter
+ - mu is the marginal relaxation hyperparameter
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
Parameters
@@ -426,7 +637,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
loss matrix for OT
reg : float
Entropy regularization term > 0
- alpha : float
+ mu : float
Marginal relaxation term > 0
weights : np.ndarray (n,)
Weights of each histogram a_i on the simplex (barycentric coordinates)
@@ -467,7 +678,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
K = np.exp(- M / reg)
- fi = alpha / (alpha + reg)
+ fi = mu / (mu + reg)
v = np.ones((p, n_hists)) / p
u = np.ones((p, 1)) / p
@@ -499,8 +710,8 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err_u = abs(u - uprev).max() / max(abs(u), abs(uprev), 1.)
- err_v = abs(v - vprev).max() / max(abs(v), abs(vprev), 1.)
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
@@ -513,8 +724,8 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
cpt += 1
if log:
log['niter'] = cpt
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
return q, log
else:
return q
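Finally, a hedged sketch of the barycenter routine with illustrative inputs: two histograms stacked as the columns of A over a common 2-point support, sharing the cost matrix M:

>>> import numpy as np
>>> import ot
>>> A = np.array([[.75, .25], [.25, .75]])   # two histograms as columns
>>> M = np.array([[0., 1.], [1., 0.]])
>>> q = ot.unbalanced.barycenter_unbalanced(A, M, reg=1., mu=1.)
>>> q.shape
(2,)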