diff options
author | tgnassou <66993815+tgnassou@users.noreply.github.com> | 2023-01-16 18:09:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-16 18:09:44 +0100 |
commit | 97feeb32b6c069d7bb44cd995531c2b820d59771 (patch) | |
tree | 18f28e89a925534884c6ed97bfd986bbb61d1279 /test | |
parent | 058d275565f0f65c23e06853812d5eb3a6ebdcef (diff) |
[MRG] OT for Gaussian distributions (#428)
* add gaussian modules
* add gaussian modules
* add PR to release.md
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/gaussian.py
* Update ot/gaussian.py
* add empirical bures wassertsein distance, fix docstring and test
* update to fit with new networkx API
* add test for jax et tf"
* fix test
* fix test?
* add empirical_bures_wasserstein_mapping
* fix docs
* fix doc
* fix docstring
* add tgnassou to contributors
* add more coverage for gaussian.py
* add deprecated function
* fix doc math"
"
* fix doc math"
"
* add remi flamary to authors of gaussiansmodule
* fix equation
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test')
-rw-r--r-- | test/test_da.py | 21 | ||||
-rw-r--r-- | test/test_gaussian.py | 98 |
2 files changed, 98 insertions, 21 deletions
diff --git a/test/test_da.py b/test/test_da.py index 138936f..c5f08d6 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -577,27 +577,6 @@ def test_mapping_transport_class_specific_seed(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") -def test_linear_mapping(nx): - ns = 50 - nt = 50 - - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) - - Xsb, Xtb = nx.from_numpy(Xs, Xt) - - A, b = ot.da.OT_mapping_linear(Xsb, Xtb) - - Xst = nx.to_numpy(nx.dot(Xsb, A) + b) - - Ct = np.cov(Xt.T) - Cst = np.cov(Xst.T) - - np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) - - -@pytest.skip_backend("jax") -@pytest.skip_backend("tf") def test_linear_mapping_class(nx): ns = 50 nt = 50 diff --git a/test/test_gaussian.py b/test/test_gaussian.py new file mode 100644 index 0000000..be7a806 --- /dev/null +++ b/test/test_gaussian.py @@ -0,0 +1,98 @@ +"""Tests for module gaussian""" + +# Author: Theo Gnassounou <theo.gnassounou@inria.fr> +# Remi Flamary <remi.flamary@polytehnique.edu> +# +# License: MIT License + +import numpy as np + +import pytest + +import ot +from ot.datasets import make_data_classif + + +def test_bures_wasserstein_mapping(nx): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + Cs = np.cov(Xs.T) + Ct = np.cov(Xt.T) + + Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct) + + A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) + A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) + Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log) + + Cst = np.cov(Xst.T) + Cst_log = np.cov(Xst_log.T) + + np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_empirical_bures_wasserstein_mapping(nx, bias): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + if not bias: + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + + Xs = Xs - ms + Xt = Xt - mt + + Xsb, Xtb = nx.from_numpy(Xs, Xt) + + A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias) + A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) + Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log) + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + Cst_log = np.cov(Xst_log.T) + + np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_bures_wasserstein_distance(nx): + ms, mt = np.array([0]), np.array([10]) + Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32) + msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct) + Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True) + Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False) + + np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_empirical_bures_wasserstein_distance(nx, bias): + ns = 400 + nt = 400 + + rng = np.random.RandomState(10) + Xs = rng.normal(0, 1, ns)[:, np.newaxis] + Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis] + + Xsb, Xtb = nx.from_numpy(Xs, Xt) + Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias) + Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias) + + np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) |