From 97feeb32b6c069d7bb44cd995531c2b820d59771 Mon Sep 17 00:00:00 2001 From: tgnassou <66993815+tgnassou@users.noreply.github.com> Date: Mon, 16 Jan 2023 18:09:44 +0100 Subject: [MRG] OT for Gaussian distributions (#428) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add gaussian modules * add gaussian modules * add PR to release.md * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * Update ot/gaussian.py * Update ot/gaussian.py * add empirical bures wassertsein distance, fix docstring and test * update to fit with new networkx API * add test for jax et tf" * fix test * fix test? * add empirical_bures_wasserstein_mapping * fix docs * fix doc * fix docstring * add tgnassou to contributors * add more coverage for gaussian.py * add deprecated function * fix doc math" " * fix doc math" " * add remi flamary to authors of gaussiansmodule * fix equation Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- ot/__init__.py | 3 +- ot/da.py | 118 ++------------------ ot/gaussian.py | 333 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 342 insertions(+), 112 deletions(-) create mode 100644 ot/gaussian.py (limited to 'ot') diff --git a/ot/__init__.py b/ot/__init__.py index 51eb726..0b55e0c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import regpath from . import weak from . import factored from . import solvers +from . import gaussian # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -56,7 +57,7 @@ __version__ = "0.8.3dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd2_1d', 'wasserstein_1d', 'backend', + 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', diff --git a/ot/da.py b/ot/da.py index 083663c..35e303b 100644 --- a/ot/da.py +++ b/ot/da.py @@ -17,8 +17,9 @@ from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots -from .utils import list_to_array, check_params, BaseEstimator +from .utils import list_to_array, check_params, BaseEstimator, deprecated from .unbalanced import sinkhorn_unbalanced +from .gaussian import empirical_bures_wasserstein_mapping from .optim import cg from .optim import gcg @@ -679,112 +680,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L -def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, - wt=None, bias=True, log=False): - r"""Return OT linear operator between samples. - - The function estimates the optimal linear operator that aligns the two - empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` - and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in - :ref:`[14] ` and discussed in remark 2.29 in - :ref:`[15] `. - - The linear operator from source to target :math:`M` - - .. math:: - M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} - - where : - - .. math:: - \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} - \Sigma_s^{-1/2} - - \mathbf{b} &= \mu_t - \mathbf{A} \mu_s - - Parameters - ---------- - xs : array-like (ns,d) - samples in the source domain - xt : array-like (nt,d) - samples in the target domain - reg : float,optional - regularization added to the diagonals of covariances (>0) - ws : array-like (ns,1), optional - weights for the source samples - wt : array-like (ns,1), optional - weights for the target samples - bias: boolean, optional - estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) - log : bool, optional - record log if True - - - Returns - ------- - A : (d, d) array-like - Linear operator - b : (1, d) array-like - bias - log : dict - log dictionary return only if log==True in parameters - - - .. _references-OT-mapping-linear: - References - ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of - distributions", Journal of Optimization Theory and Applications - Vol 43, 1984 - - .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - - """ - xs, xt = list_to_array(xs, xt) - nx = get_backend(xs, xt) - - d = xs.shape[1] - - if bias: - mxs = nx.mean(xs, axis=0)[None, :] - mxt = nx.mean(xt, axis=0)[None, :] - - xs = xs - mxs - xt = xt - mxt - else: - mxs = nx.zeros((1, d), type_as=xs) - mxt = nx.zeros((1, d), type_as=xs) - - if ws is None: - ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] - - if wt is None: - wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - - Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) - Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) - - Cs12 = nx.sqrtm(Cs) - Cs_12 = nx.inv(Cs12) - - M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) - - A = dots(Cs_12, M0, Cs_12) - - b = mxt - nx.dot(mxs, A) - - if log: - log = {} - log['Cs'] = Cs - log['Ct'] = Ct - log['Cs12'] = Cs12 - log['Cs_12'] = Cs_12 - return A, b, log - else: - return A, b +OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping) def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5, @@ -1378,10 +1274,10 @@ class LinearTransport(BaseTransport): self.mu_t = self.distribution_estimation(Xt) # coupling estimation - returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, - ws=nx.reshape(self.mu_s, (-1, 1)), - wt=nx.reshape(self.mu_t, (-1, 1)), - bias=self.bias, log=self.log) + returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg, + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), + bias=self.bias, log=self.log) # deal with the value of log if self.log: diff --git a/ot/gaussian.py b/ot/gaussian.py new file mode 100644 index 0000000..4ffb726 --- /dev/null +++ b/ot/gaussian.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- +""" +Optimal transport for Gaussian distributions +""" + +# Author: Theo Gnassounou +# Remi Flamary +# +# License: MIT License + +from .backend import get_backend +from .utils import dots +from .utils import list_to_array + + +def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False): + r"""Return OT linear operator between samples. + + The function estimates the optimal linear operator that aligns the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[1] ` and discussed in remark 2.29 in + :ref:`[2] `. + + The linear operator from source to target :math:`M` + + .. math:: + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} + + where : + + .. math:: + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} + \Sigma_s^{-1/2} + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s + + Parameters + ---------- + ms : array-like (d,) + mean of the source distribution + mt : array-like (d,) + mean of the target distribution + Cs : array-like (d,) + covariance of the source distribution + Ct : array-like (d,) + covariance of the target distribution + log : bool, optional + record log if True + + + Returns + ------- + A : (d, d) array-like + Linear operator + b : (1, d) array-like + bias + log : dict + log dictionary return only if log==True in parameters + + + .. _references-OT-mapping-linear: + References + ---------- + .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) + nx = get_backend(ms, mt, Cs, Ct) + + Cs12 = nx.sqrtm(Cs) + Cs12inv = nx.inv(Cs12) + + M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) + + A = dots(Cs12inv, M0, Cs12inv) + + b = mt - nx.dot(ms, A) + + if log: + log = {} + log['Cs12'] = Cs12 + log['Cs12inv'] = Cs12inv + return A, b, log + else: + return A, b + + +def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, + wt=None, bias=True, log=False): + r"""Return OT linear operator between samples. + + The function estimates the optimal linear operator that aligns the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[1] ` and discussed in remark 2.29 in + :ref:`[2] `. + + The linear operator from source to target :math:`M` + + .. math:: + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} + + where : + + .. math:: + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} + \Sigma_s^{-1/2} + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s + + Parameters + ---------- + xs : array-like (ns,d) + samples in the source domain + xt : array-like (nt,d) + samples in the target domain + reg : float,optional + regularization added to the diagonals of covariances (>0) + ws : array-like (ns,1), optional + weights for the source samples + wt : array-like (ns,1), optional + weights for the target samples + bias: boolean, optional + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) + log : bool, optional + record log if True + + + Returns + ------- + A : (d, d) array-like + Linear operator + b : (1, d) array-like + bias + log : dict + log dictionary return only if log==True in parameters + + + .. _references-OT-mapping-linear: + References + ---------- + .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) + + d = xs.shape[1] + + if bias: + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] + + xs = xs - mxs + xt = xt - mxt + else: + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) + + if ws is None: + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] + + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) + + if log: + A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log) + log['Cs'] = Cs + log['Ct'] = Ct + return A, b, log + else: + A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct) + return A, b + + +def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): + r"""Return Bures Wasserstein distance between samples. + + The function estimates the Bures-Wasserstein distance between two + empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, + discussed in remark 2.31 :ref:`[1] `. + + The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}` + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + Parameters + ---------- + ms : array-like (d,) + mean of the source distribution + mt : array-like (d,) + mean of the target distribution + Cs : array-like (d,) + covariance of the source distribution + Ct : array-like (d,) + covariance of the target distribution + log : bool, optional + record log if True + + + Returns + ------- + W : float + Bures Wasserstein distance + log : dict + log dictionary return only if log==True in parameters + + + .. _references-bures-wasserstein-distance: + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) + nx = get_backend(ms, mt, Cs, Ct) + + Cs12 = nx.sqrtm(Cs) + + B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) + W = nx.sqrt(nx.norm(ms - mt)**2 + B) + if log: + log = {} + log['Cs12'] = Cs12 + return W, log + else: + return W + + +def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, + wt=None, bias=True, log=False): + r"""Return Bures Wasserstein distance from mean and covariance of distribution. + + The function estimates the Bures-Wasserstein distance between two + empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, + discussed in remark 2.31 :ref:`[1] `. + + The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}` + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + + where : + + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + Parameters + ---------- + xs : array-like (ns,d) + samples in the source domain + xt : array-like (nt,d) + samples in the target domain + reg : float,optional + regularization added to the diagonals of covariances (>0) + ws : array-like (ns,1), optional + weights for the source samples + wt : array-like (ns,1), optional + weights for the target samples + bias: boolean, optional + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) + log : bool, optional + record log if True + + + Returns + ------- + W : float + Bures Wasserstein distance + log : dict + log dictionary return only if log==True in parameters + + + .. _references-bures-wasserstein-distance: + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) + + d = xs.shape[1] + + if bias: + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] + + xs = xs - mxs + xt = xt - mxt + else: + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) + + if ws is None: + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] + + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) + + if log: + W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log) + log['Cs'] = Cs + log['Ct'] = Ct + return W, log + else: + W = bures_wasserstein_distance(mxs, mxt, Cs, Ct) + return W -- cgit v1.2.3