summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authortgnassou <66993815+tgnassou@users.noreply.github.com>2023-01-16 18:09:44 +0100
committerGitHub <noreply@github.com>2023-01-16 18:09:44 +0100
commit97feeb32b6c069d7bb44cd995531c2b820d59771 (patch)
tree18f28e89a925534884c6ed97bfd986bbb61d1279 /test/test_da.py
parent058d275565f0f65c23e06853812d5eb3a6ebdcef (diff)
[MRG] OT for Gaussian distributions (#428)
* add gaussian modules * add gaussian modules * add PR to release.md * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/gaussian.py * Update ot/gaussian.py * add empirical bures wassertsein distance, fix docstring and test * update to fit with new networkx API * add test for jax et tf" * fix test * fix test? * add empirical_bures_wasserstein_mapping * fix docs * fix doc * fix docstring * add tgnassou to contributors * add more coverage for gaussian.py * add deprecated function * fix doc math" " * fix doc math" " * add remi flamary to authors of gaussiansmodule * fix equation Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py21
1 files changed, 0 insertions, 21 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 138936f..c5f08d6 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -577,27 +577,6 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
-def test_linear_mapping(nx):
- ns = 50
- nt = 50
-
- Xs, ys = make_data_classif('3gauss', ns)
- Xt, yt = make_data_classif('3gauss2', nt)
-
- Xsb, Xtb = nx.from_numpy(Xs, Xt)
-
- A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
-
- Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
-
- Ct = np.cov(Xt.T)
- Cst = np.cov(Xst.T)
-
- np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-
-
-@pytest.skip_backend("jax")
-@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
ns = 50
nt = 50