diff options
author | Hicham Janati <hicham.janati100@gmail.com> | 2022-05-09 14:30:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-09 14:30:04 +0200 |
commit | c1ccfc45350f8db3fa78d91b84eb4286bcf36e69 (patch) | |
tree | ab1086116414350346cc678d8e9c3c68a31224cf /test | |
parent | 726e84e1e9f2832ea5ad156f62a5e3636c1fd3d3 (diff) |
[MRG] Fix barycenter mass (#375)
* fix transpose in sinkhorn barycenters
* add test for assymetric cost barycenters
* fix pep8
Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
Diffstat (limited to 'test')
-rw-r--r-- | test/test_bregman.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 6c37984..112bfca 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -491,6 +491,41 @@ def test_barycenter(nx, method, verbose, warn): @pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_assymetric_cost(nx, method, verbose, warn): + n_bins = 20 # nb bins + + # Gaussian distributions + A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + + # creating matrix A containing all distributions + A = A[:, None] + + # assymetric loss matrix + normalization + rng = np.random.RandomState(42) + M = rng.randn(n_bins, n_bins) ** 2 + M /= M.max() + + A_nx, M_nx = nx.from_numpy(A, M) + reg = 1e-2 + + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + +@pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False])) def test_barycenter_debiased(nx, method, verbose, warn): |