From 97feeb32b6c069d7bb44cd995531c2b820d59771 Mon Sep 17 00:00:00 2001 From: tgnassou <66993815+tgnassou@users.noreply.github.com> Date: Mon, 16 Jan 2023 18:09:44 +0100 Subject: [MRG] OT for Gaussian distributions (#428) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add gaussian modules * add gaussian modules * add PR to release.md * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * Update ot/gaussian.py * Update ot/gaussian.py * add empirical bures wassertsein distance, fix docstring and test * update to fit with new networkx API * add test for jax et tf" * fix test * fix test? * add empirical_bures_wasserstein_mapping * fix docs * fix doc * fix docstring * add tgnassou to contributors * add more coverage for gaussian.py * add deprecated function * fix doc math" " * fix doc math" " * add remi flamary to authors of gaussiansmodule * fix equation Co-authored-by: RĂ©mi Flamary Co-authored-by: Alexandre Gramfort --- test/test_gaussian.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 test/test_gaussian.py (limited to 'test/test_gaussian.py') diff --git a/test/test_gaussian.py b/test/test_gaussian.py new file mode 100644 index 0000000..be7a806 --- /dev/null +++ b/test/test_gaussian.py @@ -0,0 +1,98 @@ +"""Tests for module gaussian""" + +# Author: Theo Gnassounou +# Remi Flamary +# +# License: MIT License + +import numpy as np + +import pytest + +import ot +from ot.datasets import make_data_classif + + +def test_bures_wasserstein_mapping(nx): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + Cs = np.cov(Xs.T) + Ct = np.cov(Xt.T) + + Xsb, msb, mtb, Csb, Ctb = nx.from_numpy(Xs, ms, mt, Cs, Ct) + + A_log, b_log, log = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=True) + A, b = ot.gaussian.bures_wasserstein_mapping(msb, mtb, Csb, Ctb, log=False) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) + Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log) + + Cst = np.cov(Xst.T) + Cst_log = np.cov(Xst_log.T) + + np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_empirical_bures_wasserstein_mapping(nx, bias): + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + if not bias: + ms = np.mean(Xs, axis=0)[None, :] + mt = np.mean(Xt, axis=0)[None, :] + + Xs = Xs - ms + Xt = Xt - mt + + Xsb, Xtb = nx.from_numpy(Xs, Xt) + + A, b, log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=True, bias=bias) + A_log, b_log = ot.gaussian.empirical_bures_wasserstein_mapping(Xsb, Xtb, log=False, bias=bias) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) + Xst_log = nx.to_numpy(nx.dot(Xsb, A_log) + b_log) + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + Cst_log = np.cov(Xst_log.T) + + np.testing.assert_allclose(Cst_log, Cst, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_bures_wasserstein_distance(nx): + ms, mt = np.array([0]), np.array([10]) + Cs, Ct = np.array([[1]]).astype(np.float32), np.array([[1]]).astype(np.float32) + msb, mtb, Csb, Ctb = nx.from_numpy(ms, mt, Cs, Ct) + Wb_log, log = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=True) + Wb = ot.gaussian.bures_wasserstein_distance(msb, mtb, Csb, Ctb, log=False) + + np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(10, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + + +@pytest.mark.parametrize("bias", [True, False]) +def test_empirical_bures_wasserstein_distance(nx, bias): + ns = 400 + nt = 400 + + rng = np.random.RandomState(10) + Xs = rng.normal(0, 1, ns)[:, np.newaxis] + Xt = rng.normal(10 * bias, 1, nt)[:, np.newaxis] + + Xsb, Xtb = nx.from_numpy(Xs, Xt) + Wb_log, log = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=True, bias=bias) + Wb = ot.gaussian.empirical_bures_wasserstein_distance(Xsb, Xtb, log=False, bias=bias) + + np.testing.assert_allclose(nx.to_numpy(Wb_log), nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2) -- cgit v1.2.3