diff options
author | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200 |
commit | 35bd2c98b642df78638d7d733bc1a89d873db1de (patch) | |
tree | 6bc637624004713808d3097b95acdccbb9608e52 /ot/unbalanced.py | |
parent | c4753bd3f74139af8380127b66b484bc09b50661 (diff) | |
parent | eccb1386eea52b94b82456d126bd20cbe3198e05 (diff) |
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 525 |
1 files changed, 376 insertions, 149 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 15e180b..90c920c 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -4,13 +4,14 @@ Regularized Unbalanced OT solvers """ # Author: Hicham Janati <hicham.janati@inria.fr> +# Laetitia Chapel <laetitia.chapel@univ-ubs.fr> # License: MIT License from __future__ import division import warnings -import numpy as np -from scipy.special import logsumexp +from .backend import get_backend +from .utils import list_to_array # from .utils import unif, dist @@ -43,12 +44,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -70,12 +71,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -172,12 +173,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -198,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Returns ------- - ot_distance : (n_hists,) ndarray + ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` @@ -239,9 +240,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>` """ - b = np.asarray(b, dtype=np.float64) + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] + if method.lower() == 'sinkhorn': return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, @@ -291,12 +293,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b` If many, compute all the OT distances (a, b_i) - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -315,12 +317,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -354,17 +356,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -377,17 +377,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, 1)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, 1), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(M / (-reg)) fi = reg_m / (reg_m + reg) @@ -397,14 +394,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, uprev = u vprev = v - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (a / Kv) ** fi - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = (b / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -412,8 +409,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, v = vprev break - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err_u = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) + err_v = nx.max(nx.abs(v - vprev)) / max( + nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. + ) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -426,11 +427,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, break if log: - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -475,12 +476,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -501,12 +502,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -538,17 +539,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -561,56 +560,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim_a) - beta = np.zeros(dim_b) + alpha = nx.zeros(dim_a, type_as=M) + beta = nx.zeros(dim_b, type_as=M) while (err > stopThr and cpt < numItermax): uprev = u vprev = v - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) if n_hists: f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((a / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = ((b / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True if n_hists: - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) else: - alpha = alpha + reg * np.log(np.max(u)) - beta = beta + reg * np.log(np.max(v)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u)) + beta = beta + reg * nx.log(nx.max(v)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -620,8 +615,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 if (cpt % 10 == 0 and not absorbing) or cpt == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), - 1.) + err = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -636,25 +632,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 "Try a larger entropy `reg` or a lower mass `reg_m`." + "Or a larger absorption threshold `tau`.") if n_hists: - logu = alpha[:, None] / reg + np.log(u) - logv = beta[:, None] / reg + np.log(v) + logu = alpha[:, None] / reg + nx.log(u) + logv = beta[:, None] / reg + nx.log(v) else: - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) if log: log['logu'] = logu log['logv'] = logv if n_hists: # return only loss - res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + - logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) - res = np.exp(res) + res = nx.logsumexp( + nx.log(M + 1e-100)[:, :, None] + + logu[:, None, :] + + logv[None, :, :] + - M[:, :, None] / reg, + axis=(0, 1) + ) + res = nx.exp(res) if log: return res, log else: return res else: # return OT matrix - ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) if log: return ot_matrix, log else: @@ -683,9 +684,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 @@ -693,7 +694,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Marginal relaxation term > 0 tau : float Stabilization threshold for log domain absorption. - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -708,7 +709,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -726,9 +727,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) @@ -737,47 +741,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, fi = reg_m / (reg_m + reg) - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=A) / dim + v = nx.ones((dim, n_hists), type_as=A) / dim # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim + alpha = nx.zeros(dim, type_as=A) + beta = nx.zeros(dim, type_as=A) + q = nx.ones(dim, type_as=A) / dim for i in range(numItermax): - qprev = q.copy() - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + qprev = nx.copy(q) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((A / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) q = (Ktu ** (1 - fi)) * f_beta - q = q.dot(weights) ** (1 / (1 - fi)) + q = nx.dot(q, weights) ** (1 / (1 - fi)) Q = q[:, None] v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -786,8 +786,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(q - qprev).max() / max(abs(q).max(), - abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -804,8 +805,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, "Or a larger absorption threshold `tau`.") if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -833,15 +834,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -856,7 +857,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -874,40 +875,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) if log: log = {'err': []} - K = np.exp(- M / reg) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) - u = np.ones((dim, 1)) - q = np.ones(dim) + v = nx.ones((dim, n_hists), type_as=A) + u = nx.ones((dim, 1), type_as=A) + q = nx.ones(dim, type_as=A) err = 1. for i in range(numItermax): - uprev = u.copy() - vprev = v.copy() - qprev = q.copy() + uprev = nx.copy(u) + vprev = nx.copy(v) + qprev = nx.copy(q) - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (A / Kv) ** fi - Ktu = K.T.dot(u) - q = ((Ktu ** (1 - fi)).dot(weights)) + Ktu = nx.dot(K.T, u) + q = nx.dot(Ktu ** (1 - fi), weights) q = q ** (1 / (1 - fi)) Q = q[:, None] v = (Q / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -916,8 +920,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, q = qprev break # compute change in barycenter - err = abs(q - qprev).max() - err /= max(abs(q).max(), abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0 + ) if log: log['err'].append(err) # if barycenter did not change + at least 10 iterations - stop @@ -932,8 +937,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -961,15 +966,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -984,7 +989,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -1025,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) + + +def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, + stopThr=1e-15, verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence + + The algorithm used for solving the problem is a maximization- + minimization algorithm as proposed in :ref:`[41] <references-regpath>` + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg_m: float + Marginal relaxation term > 0 + div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + Returns + ------- + gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2) + array([[0.3 , 0. ], + [0. , 0.07]]) + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2) + array([[0.25, 0. ], + [0. , 0. ]]) + + + .. _references-regpath: + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd : Unregularized OT + ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT + """ + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = nx.ones(dim_a, type_as=M) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=M) / dim_b + + if G0 is None: + G = a[:, None] * b[None, :] + else: + G = G0 + + if log: + log = {'err': [], 'G': []} + + if div == 'kl': + K = nx.exp(M / - reg_m / 2) + elif div == 'l2': + K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2, + nx.zeros((dim_a, dim_b), type_as=M)) + else: + warnings.warn("The div parameter should be either equal to 'kl' or \ + 'l2': it has been set to 'kl'.") + div = 'kl' + K = nx.exp(M / - reg_m / 2) + + for i in range(numItermax): + Gprev = G + + if div == 'kl': + u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16)) + v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16)) + G = G * K * u[:, None] * v[None, :] + elif div == 'l2': + Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16 + G = G * K / Gd + + err = nx.sqrt(nx.sum((G - Gprev) ** 2)) + if log: + log['err'].append(err) + log['G'].append(G) + if verbose: + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break + + if log: + log['cost'] = nx.sum(G * M) + return G, log + else: + return G + + +def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, + stopThr=1e-15, verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence + + The algorithm used for solving the problem is a maximization- + minimization algorithm as proposed in :ref:`[41] <references-regpath>` + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg_m: float + Marginal relaxation term > 0 + div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + ot_distance : array-like + the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}` + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2) + 0.25 + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2) + 0.57 + + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd2 : Unregularized OT loss + ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss + """ + _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0, + numItermax=numItermax, stopThr=stopThr, + verbose=verbose, log=True) + + if log: + return log_mm['cost'], log_mm + else: + return log_mm['cost'] |