summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortgnassou <66993815+tgnassou@users.noreply.github.com>2023-01-16 18:09:44 +0100
committerGitHub <noreply@github.com>2023-01-16 18:09:44 +0100
commit97feeb32b6c069d7bb44cd995531c2b820d59771 (patch)
tree18f28e89a925534884c6ed97bfd986bbb61d1279
parent058d275565f0f65c23e06853812d5eb3a6ebdcef (diff)
[MRG] OT for Gaussian distributions (#428)
* add gaussian modules * add gaussian modules * add PR to release.md * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/gaussian.py * Update ot/gaussian.py * add empirical bures wassertsein distance, fix docstring and test * update to fit with new networkx API * add test for jax et tf" * fix test * fix test? * add empirical_bures_wasserstein_mapping * fix docs * fix doc * fix docstring * add tgnassou to contributors * add more coverage for gaussian.py * add deprecated function * fix doc math" " * fix doc math" " * add remi flamary to authors of gaussiansmodule * fix equation Co-authored-by: Rémi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
-rw-r--r--CONTRIBUTORS.md1
-rw-r--r--RELEASES.md1
-rw-r--r--docs/source/all.rst1
-rw-r--r--docs/source/quickstart.rst6
-rw-r--r--examples/domain-adaptation/plot_otda_linear_mapping.py2
-rw-r--r--examples/gromov/plot_barycenter_fgw.py2
-rw-r--r--ot/__init__.py3
-rw-r--r--ot/da.py118
-rw-r--r--ot/gaussian.py333
-rw-r--r--test/test_da.py21
-rw-r--r--test/test_gaussian.py98
11 files changed, 448 insertions, 138 deletions
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 0524151..67d8337 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -40,6 +40,7 @@ The contributors to this library are:
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
+* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
## Acknowledgments
diff --git a/RELEASES.md b/RELEASES.md
index c78319d..4ed3625 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -4,6 +4,7 @@
#### New features
+- Added Bures Wasserstein distance in `ot.gaussian` (PR ##428)
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
- New API for OT solver using function `ot.solve` (PR #388)
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 1ec6be3..60cc85c 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -31,6 +31,7 @@ API and modules
sliced
weak
factored
+ gaussian
.. autosummary::
:toctree: ../modules/generated/
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index b4cc8ab..c8eac30 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -279,7 +279,7 @@ distributions. In this case there exists a close form solution given in Remark
2.29 in [15]_ and the Monge mapping is an affine function and can be
also computed from the covariances and means of the source and target
distributions. In the case when the finite sample dataset is supposed Gaussian,
-we provide :any:`ot.da.OT_mapping_linear` that returns the parameters for the
+we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the
Monge mapping.
@@ -628,7 +628,7 @@ approximate a Monge mapping from finite distributions.
First note that when the source and target distributions are supposed to be Gaussian
distributions, there exists a close form solution for the mapping and its an
affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function
-:any:`ot.da.OT_mapping_linear` that returns the operator :math:`A` and vector
+:any:`ot.gaussian.bures_wasserstein_mapping` that returns the operator :math:`A` and vector
:math:`b`. Note that if the number of samples is too small there is a parameter
:code:`reg` that provides a regularization for the covariance matrix estimation.
@@ -640,7 +640,7 @@ method proposed in [8]_ that estimates a continuous mapping approximating the
barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping.
-.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.da.OT_mapping_linear
+.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.gaussian.bures_wasserstein_mapping
:add-heading: Examples of Monge mapping estimation
:heading-level: "
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index a44096a..8284a2a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -61,7 +61,7 @@ plt.plot(xt[:, 0], xt[:, 1], 'o')
# Estimate linear mapping and transport
# -------------------------------------
-Ae, be = ot.da.OT_mapping_linear(xs, xt)
+Ae, be = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)
xst = xs.dot(Ae) + be
diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py
index 556e08f..dc3c6aa 100644
--- a/examples/gromov/plot_barycenter_fgw.py
+++ b/examples/gromov/plot_barycenter_fgw.py
@@ -174,7 +174,7 @@ A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
# -------------------------
#%% Create the barycenter
-bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+bary = nx.from_numpy_array(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
for i, v in enumerate(A.ravel()):
bary.add_node(i, attr_name=v)
diff --git a/ot/__init__.py b/ot/__init__.py
index 51eb726..0b55e0c 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -35,6 +35,7 @@ from . import regpath
from . import weak
from . import factored
from . import solvers
+from . import gaussian
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -56,7 +57,7 @@ __version__ = "0.8.3dev"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd2_1d', 'wasserstein_1d', 'backend',
+ 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
diff --git a/ot/da.py b/ot/da.py
index 083663c..35e303b 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -17,8 +17,9 @@ from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import list_to_array, check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator, deprecated
from .unbalanced import sinkhorn_unbalanced
+from .gaussian import empirical_bures_wasserstein_mapping
from .optim import cg
from .optim import gcg
@@ -679,112 +680,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
- wt=None, bias=True, log=False):
- r"""Return OT linear operator between samples.
-
- The function estimates the optimal linear operator that aligns the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
- and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
- :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
- :ref:`[15] <references-OT-mapping-linear>`.
-
- The linear operator from source to target :math:`M`
-
- .. math::
- M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
-
- where :
-
- .. math::
- \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
- \Sigma_s^{-1/2}
-
- \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
-
- Parameters
- ----------
- xs : array-like (ns,d)
- samples in the source domain
- xt : array-like (nt,d)
- samples in the target domain
- reg : float,optional
- regularization added to the diagonals of covariances (>0)
- ws : array-like (ns,1), optional
- weights for the source samples
- wt : array-like (ns,1), optional
- weights for the target samples
- bias: boolean, optional
- estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
- log : bool, optional
- record log if True
-
-
- Returns
- -------
- A : (d, d) array-like
- Linear operator
- b : (1, d) array-like
- bias
- log : dict
- log dictionary return only if log==True in parameters
-
-
- .. _references-OT-mapping-linear:
- References
- ----------
- .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
- distributions", Journal of Optimization Theory and Applications
- Vol 43, 1984
-
- .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
-
- """
- xs, xt = list_to_array(xs, xt)
- nx = get_backend(xs, xt)
-
- d = xs.shape[1]
-
- if bias:
- mxs = nx.mean(xs, axis=0)[None, :]
- mxt = nx.mean(xt, axis=0)[None, :]
-
- xs = xs - mxs
- xt = xt - mxt
- else:
- mxs = nx.zeros((1, d), type_as=xs)
- mxt = nx.zeros((1, d), type_as=xs)
-
- if ws is None:
- ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
-
- if wt is None:
- wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
-
- Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
- Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
-
- Cs12 = nx.sqrtm(Cs)
- Cs_12 = nx.inv(Cs12)
-
- M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
-
- A = dots(Cs_12, M0, Cs_12)
-
- b = mxt - nx.dot(mxs, A)
-
- if log:
- log = {}
- log['Cs'] = Cs
- log['Ct'] = Ct
- log['Cs12'] = Cs12
- log['Cs_12'] = Cs_12
- return A, b, log
- else:
- return A, b
+OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping)
def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
@@ -1378,10 +1274,10 @@ class LinearTransport(BaseTransport):
self.mu_t = self.distribution_estimation(Xt)
# coupling estimation
- returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
- ws=nx.reshape(self.mu_s, (-1, 1)),
- wt=nx.reshape(self.mu_t, (-1, 1)),
- bias=self.bias, log=self.log)
+ returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg,
+ ws=nx.reshape(self.mu_s, (-1, 1)),
+ wt=nx.reshape(self.mu_t, (-1, 1)),
+ bias=self.bias, log=self.log)
# deal with the value of log
if self.log:
diff --git a/ot/gaussian.py b/ot/gaussian.py
new file mode 100644
index 0000000..4ffb726
--- /dev/null
+++ b/ot/gaussian.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+"""
+Optimal transport for Gaussian distributions
+"""
+
+# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
+# Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .utils import dots
+from .utils import list_to_array
+
+
+def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
+ r"""Return OT linear operator between samples.
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[1] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[2] <references-OT-mapping-linear>`.
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+ where :
+
+ .. math::
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+ \Sigma_s^{-1/2}
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+ Parameters
+ ----------
+ ms : array-like (d,)
+ mean of the source distribution
+ mt : array-like (d,)
+ mean of the target distribution
+ Cs : array-like (d,)
+ covariance of the source distribution
+ Ct : array-like (d,)
+ covariance of the target distribution
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d, d) array-like
+ Linear operator
+ b : (1, d) array-like
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-OT-mapping-linear:
+ References
+ ----------
+ .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
+ nx = get_backend(ms, mt, Cs, Ct)
+
+ Cs12 = nx.sqrtm(Cs)
+ Cs12inv = nx.inv(Cs12)
+
+ M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
+
+ A = dots(Cs12inv, M0, Cs12inv)
+
+ b = mt - nx.dot(ms, A)
+
+ if log:
+ log = {}
+ log['Cs12'] = Cs12
+ log['Cs12inv'] = Cs12inv
+ return A, b, log
+ else:
+ return A, b
+
+
+def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ r"""Return OT linear operator between samples.
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[1] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[2] <references-OT-mapping-linear>`.
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+ where :
+
+ .. math::
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+ \Sigma_s^{-1/2}
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+ Parameters
+ ----------
+ xs : array-like (ns,d)
+ samples in the source domain
+ xt : array-like (nt,d)
+ samples in the target domain
+ reg : float,optional
+ regularization added to the diagonals of covariances (>0)
+ ws : array-like (ns,1), optional
+ weights for the source samples
+ wt : array-like (ns,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d, d) array-like
+ Linear operator
+ b : (1, d) array-like
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-OT-mapping-linear:
+ References
+ ----------
+ .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
+
+ if ws is None:
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
+
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
+
+ if log:
+ A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log)
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ return A, b, log
+ else:
+ A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct)
+ return A, b
+
+
+def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
+ r"""Return Bures Wasserstein distance between samples.
+
+ The function estimates the Bures-Wasserstein distance between two
+ empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
+ discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
+
+ The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}`
+
+ .. math::
+ \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
+
+ where :
+
+ .. math::
+ \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+
+ Parameters
+ ----------
+ ms : array-like (d,)
+ mean of the source distribution
+ mt : array-like (d,)
+ mean of the target distribution
+ Cs : array-like (d,)
+ covariance of the source distribution
+ Ct : array-like (d,)
+ covariance of the target distribution
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : float
+ Bures Wasserstein distance
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-bures-wasserstein-distance:
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
+ nx = get_backend(ms, mt, Cs, Ct)
+
+ Cs12 = nx.sqrtm(Cs)
+
+ B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
+ W = nx.sqrt(nx.norm(ms - mt)**2 + B)
+ if log:
+ log = {}
+ log['Cs12'] = Cs12
+ return W, log
+ else:
+ return W
+
+
+def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ r"""Return Bures Wasserstein distance from mean and covariance of distribution.
+
+ The function estimates the Bures-Wasserstein distance between two
+ empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
+ discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
+
+ The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}`
+
+ .. math::
+ \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
+
+ where :
+
+ .. math::
+ \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+
+ Parameters
+ ----------
+ xs : array-like (ns,d)
+ samples in the source domain
+ xt : array-like (nt,d)
+ samples in the target domain
+ reg : float,optional
+ regularization added to the diagonals of covariances (>0)
+ ws : array-like (ns,1), optional
+ weights for the source samples
+ wt : array-like (ns,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : float
+ Bures Wasserstein distance
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-bures-wasserstein-distance:
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
+
+ if ws is None:
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
+
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
+
+ if log:
+ W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log)
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ return W, log
+ else:
+ W = bures_wasserstein_distance(mxs, mxt, Cs, Ct)
+ return W
diff --git a/test/test_da.py b/test/test_da.py
index 138936f..c5f08d6 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -577,27 +577,6 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
-def test_linear_mapping(nx):
- ns = 50
- nt = 50
-
- Xs, ys = make_data_classif('3gauss', ns)
- Xt, yt = make_data_classif('3gauss2', nt)
-
- Xsb, Xtb = nx.from_numpy(Xs, Xt)
-
- A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
-
- Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
-
- Ct = np.cov(Xt.T)
- Cst = np.cov(Xst.T)
-
- np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-
-
-@pytest.skip_backend("jax")
-@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
ns = 50
nt = 50
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
new file mode 100644
index 0000000..be7a806
--- /dev/null
+++ b/test/test_gaussian.py
@@ -0,0 +1,98 @@
+"""Tests for module gaussian"""
+
+# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
+# Remi Flamary <remi.flamary@polytehnique.edu>
+#
+# License: MIT License
+
+import numpy as np
+
+import pytest
+
+import ot
+from ot.datasets import make_data_classif
+
+
+def test_bures_wasserstein_mapping(nx):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+ Cs = np.cov(Xs.T)
+ Ct = np.cov(Xt.T)
+
+ Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct)
+
+ A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True)
+ A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_mapping(nx, bias):
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ if not bias:
+ ms = np.mean(Xs, axis=0)[None, :]
+ mt = np.mean(Xt, axis=0)[None, :]
+
+ Xs = Xs - ms
+ Xt = Xt - mt
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+
+ A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias)
+ A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
+ Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log)
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+ Cst_log = np.cov(Xst_log.T)
+
+ np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_bures_wasserstein_distance(nx):
+ ms, mt = np.array([0]), np.array([10])
+ Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32)
+ msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct)
+ Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True)
+ Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+def test_empirical_bures_wasserstein_distance(nx, bias):
+ ns = 400
+ nt = 400
+
+ rng = np.random.RandomState(10)
+ Xs = rng.normal(0, 1, ns)[:, np.newaxis]
+ Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis]
+
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+ Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias)
+ Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias)
+
+ np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)