diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2022-03-24 10:53:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-24 10:53:47 +0100 |
commit | 767171593f2a98a26b9a39bf110a45085e3b982e (patch) | |
tree | 4eb4bcc657efc53a65c3fb4439bd0e0e106b6745 /ot/unbalanced.py | |
parent | 9b9d2221d257f40ea3eb58b279b30d69162d62bb (diff) |
[MRG] Domain adaptation and unbalanced solvers with backend support (#343)
* First draft
* Add matrix inverse and square root to backend
* Eigen decomposition for older versions of pytorch (1.8.1 and older)
* Corrected eigen decomposition for pytorch 1.8.1 and older
* Spectral theorem is a thing
* Optimization
* small optimization
* More functions converted
* pep8
* remove a warning and prepare torch meshgrid for future torch release (which will change default indexing)
* dots and pep8
* Meshgrid corrected for older version and prepared for future versions changes
* New backend functions
* Base transport
* LinearTransport
* All transport classes + pep8
* PR added to release file
* Jcpot barycenter test
* unbalanced with backend
* pep8
* bug solve
* test of domain adaptation with backends
* solve bug for tic toc & macos
* solving scipy deprecation warning
* solving scipy deprecation warning attempt2
* solving scipy deprecation warning attempt3
* A warning is triggered when a float->int conversion is detected
* bug solve
* docs
* release file updated
* Better handling of float->int conversion in EMD
* Corrected test for is_floating_point
* docs
* release file updated
* cupy does not allow implicit cast
* fromnumpy
* added test
* test da tf jax
* test unbalanced with no provided histogram
* using type_as argument in unif function correctly
* pep8
* transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 302 |
1 files changed, 153 insertions, 149 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 15e180b..503cc1e 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -8,9 +8,9 @@ Regularized Unbalanced OT solvers from __future__ import division import warnings -import numpy as np -from scipy.special import logsumexp +from .backend import get_backend +from .utils import list_to_array # from .utils import unif, dist @@ -43,12 +43,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -70,12 +70,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -172,12 +172,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. 
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -198,7 +198,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Returns ------- - ot_distance : (n_hists,) ndarray + ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` @@ -239,9 +239,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>` """ - b = np.asarray(b, dtype=np.float64) + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] + if method.lower() == 'sinkhorn': return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, @@ -291,12 +292,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b` If many, compute all the OT distances (a, b_i) - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -315,12 +316,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log 
dictionary returned only if `log` is `True` @@ -354,17 +355,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -377,17 +376,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, 1)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, 1), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(M / (-reg)) fi = reg_m / (reg_m + reg) @@ -397,14 +393,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, uprev = u vprev = v - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (a / Kv) ** fi - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = (b / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) 
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -412,8 +408,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, v = vprev break - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err_u = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) + err_v = nx.max(nx.abs(v - vprev)) / max( + nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. + ) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -426,11 +426,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, break if log: - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -475,12 +475,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. 
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -501,12 +501,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -538,17 +538,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -561,56 +559,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = 
np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim_a) - beta = np.zeros(dim_b) + alpha = nx.zeros(dim_a, type_as=M) + beta = nx.zeros(dim_b, type_as=M) while (err > stopThr and cpt < numItermax): uprev = u vprev = v - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) if n_hists: f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((a / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = ((b / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True if n_hists: - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) else: - alpha = alpha + reg * np.log(np.max(u)) - beta = beta + reg * np.log(np.max(v)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u)) + beta = beta + reg * nx.log(nx.max(v)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + + if (nx.any(Ktu == 0.) 
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -620,8 +614,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 if (cpt % 10 == 0 and not absorbing) or cpt == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), - 1.) + err = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -636,25 +631,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 "Try a larger entropy `reg` or a lower mass `reg_m`." + "Or a larger absorption threshold `tau`.") if n_hists: - logu = alpha[:, None] / reg + np.log(u) - logv = beta[:, None] / reg + np.log(v) + logu = alpha[:, None] / reg + nx.log(u) + logv = beta[:, None] / reg + nx.log(v) else: - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) if log: log['logu'] = logu log['logv'] = logv if n_hists: # return only loss - res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + - logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) - res = np.exp(res) + res = nx.logsumexp( + nx.log(M + 1e-100)[:, :, None] + + logu[:, None, :] + + logv[None, :, :] + - M[:, :, None] / reg, + axis=(0, 1) + ) + res = nx.exp(res) if log: return res, log else: return res else: # return OT matrix - ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) if log: return ot_matrix, log else: @@ -683,9 +683,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) 
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 @@ -693,7 +693,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Marginal relaxation term > 0 tau : float Stabilization threshold for log domain absorption. - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -708,7 +708,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -726,9 +726,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) @@ -737,47 +740,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, fi = reg_m / (reg_m + reg) - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=A) / dim + v = nx.ones((dim, n_hists), type_as=A) / dim # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. 
- alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim + alpha = nx.zeros(dim, type_as=A) + beta = nx.zeros(dim, type_as=A) + q = nx.ones(dim, type_as=A) / dim for i in range(numItermax): - qprev = q.copy() - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + qprev = nx.copy(q) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((A / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) q = (Ktu ** (1 - fi)) * f_beta - q = q.dot(weights) ** (1 / (1 - fi)) + q = nx.dot(q, weights) ** (1 / (1 - fi)) Q = q[:, None] v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -786,8 +785,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(q - qprev).max() / max(abs(q).max(), - abs(qprev).max(), 1.) 
+ err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -804,8 +804,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, "Or a larger absorption threshold `tau`.") if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -833,15 +833,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -856,7 +856,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -874,40 +874,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) if log: log = {'err': []} - K = np.exp(- M / reg) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) - u = np.ones((dim, 1)) - q = np.ones(dim) + v = nx.ones((dim, n_hists), type_as=A) + u = nx.ones((dim, 1), type_as=A) + q = nx.ones(dim, type_as=A) err = 1. 
for i in range(numItermax): - uprev = u.copy() - vprev = v.copy() - qprev = q.copy() + uprev = nx.copy(u) + vprev = nx.copy(v) + qprev = nx.copy(q) - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (A / Kv) ** fi - Ktu = K.T.dot(u) - q = ((Ktu ** (1 - fi)).dot(weights)) + Ktu = nx.dot(K.T, u) + q = nx.dot(Ktu ** (1 - fi), weights) q = q ** (1 / (1 - fi)) Q = q[:, None] v = (Q / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -916,8 +919,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, q = qprev break # compute change in barycenter - err = abs(q - qprev).max() - err /= max(abs(q).max(), abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0 + ) if log: log['err'].append(err) # if barycenter did not change + at least 10 iterations - stop @@ -932,8 +936,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -961,15 +965,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. 
reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -984,7 +988,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters |