diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 301 | ||||
-rw-r--r-- | ot/stochastic.py | 3 |
2 files changed, 304 insertions, 0 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 013bc33..dc43834 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -5,10 +5,12 @@ Bregman projections for regularized OT # Author: Remi Flamary <remi.flamary@unice.fr> # Nicolas Courty <ncourty@irisa.fr> +# Kilian Fatras <kilian.fatras@irisa.fr> # # License: MIT License import numpy as np +from .utils import unif, dist def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -1296,3 +1298,302 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, return np.sum(K0, axis=1), log else: return np.sum(K0, axis=1) + + +def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): + ''' + Solve the entropic regularization optimal transport problem and return the + OT matrix from empirical data + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - :math:`M` is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : np.ndarray (ns, d) + samples in the source domain + X_t : np.ndarray (nt, d) + samples in the target domain + reg : float + Regularization term >0 + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_s = 2 + >>> n_t = 2 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_s), (n_s, 1)) + >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1)) + >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False) + >>> print(emp_sinkhorn) + >>> [[4.99977301e-01 2.26989344e-05] + [2.26989344e-05 4.99977301e-01]] + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + ''' + + if a is None: + a = unif(np.shape(X_s)[0]) + if b is None: + b = unif(np.shape(X_t)[0]) + + M = dist(X_s, X_t, metric=metric) + + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + return pi + + +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): + ''' + Solve the entropic regularization optimal transport problem from empirical + data and return the OT loss + + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + where : + + - :math:`M` is the (ns,nt) metric cost matrix + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : np.ndarray (ns, d) + samples in the source domain + X_t : np.ndarray (nt, d) + samples in the target domain + reg : float + Regularization term >0 + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_s = 2 + >>> n_t = 2 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_s), (n_s, 1)) + >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1)) + >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False) + >>> print(loss_sinkhorn) + >>> [4.53978687e-05] + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + ''' + + if a is None: + a = unif(np.shape(X_s)[0]) + if b is None: + b = unif(np.shape(X_t)[0]) + + M = dist(X_s, X_t, metric=metric) + + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + return sinkhorn_loss + + +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): + ''' + Compute the sinkhorn divergence loss from empirical data + + The function solves the following optimization problems and return the + sinkhorn divergence :math:`S`: + + .. math:: + + W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + + W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) + + W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + + S &= W - 1/2 * (W_a + W_b) + + .. math:: + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + + \gamma_a 1 = a + + \gamma_a^T 1= a + + \gamma_a\geq 0 + + \gamma_b 1 = b + + \gamma_b^T 1= b + + \gamma_b\geq 0 + where : + + - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt)) + - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : np.ndarray (ns, d) + samples in the source domain + X_t : np.ndarray (nt, d) + samples in the target domain + reg : float + Regularization term >0 + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshol on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_s = 2 + >>> n_t = 4 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_s), (n_s, 1)) + >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1)) + >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg) + >>> print(emp_sinkhorn_div) + >>> [2.99977435] + + + References + ---------- + + .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 + ''' + if log: + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + + log = {} + log['sinkhorn_loss_ab'] = sinkhorn_loss_ab + log['sinkhorn_loss_a'] = sinkhorn_loss_a + log['sinkhorn_loss_b'] = sinkhorn_loss_b + log['log_sinkhorn_ab'] = log_ab + log['log_sinkhorn_a'] = log_a + log['log_sinkhorn_b'] = log_b + + return max(0, sinkhorn_div), log + + else: + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) + return max(0, sinkhorn_div) diff --git a/ot/stochastic.py b/ot/stochastic.py index 0db39c8..85c4230 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -348,8 +348,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + s.t. \gamma 1 = a + \gamma^T 1= b + \gamma \geq 0 Where : |