diff options
author | tgnassou <66993815+tgnassou@users.noreply.github.com> | 2023-01-16 18:09:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-16 18:09:44 +0100 |
commit | 97feeb32b6c069d7bb44cd995531c2b820d59771 (patch) | |
tree | 18f28e89a925534884c6ed97bfd986bbb61d1279 /test/test_da.py | |
parent | 058d275565f0f65c23e06853812d5eb3a6ebdcef (diff) |
[MRG] OT for Gaussian distributions (#428)
* add gaussian modules
* add gaussian modules
* add PR to release.md
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Apply suggestions from code review
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/gaussian.py
* Update ot/gaussian.py
* add empirical bures wassertsein distance, fix docstring and test
* update to fit with new networkx API
* add test for jax et tf"
* fix test
* fix test?
* add empirical_bures_wasserstein_mapping
* fix docs
* fix doc
* fix docstring
* add tgnassou to contributors
* add more coverage for gaussian.py
* add deprecated function
* fix doc math"
"
* fix doc math"
"
* add remi flamary to authors of gaussiansmodule
* fix equation
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_da.py')
-rw-r--r-- | test/test_da.py | 21 |
1 files changed, 0 insertions, 21 deletions
diff --git a/test/test_da.py b/test/test_da.py index 138936f..c5f08d6 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -577,27 +577,6 @@ def test_mapping_transport_class_specific_seed(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") -def test_linear_mapping(nx): - ns = 50 - nt = 50 - - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) - - Xsb, Xtb = nx.from_numpy(Xs, Xt) - - A, b = ot.da.OT_mapping_linear(Xsb, Xtb) - - Xst = nx.to_numpy(nx.dot(Xsb, A) + b) - - Ct = np.cov(Xt.T) - Cst = np.cov(Xst.T) - - np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) - - -@pytest.skip_backend("jax") -@pytest.skip_backend("tf") def test_linear_mapping_class(nx): ns = 50 nt = 50 |