diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_bregman.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py index 112bfca..e128ea2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -3,6 +3,7 @@ # Author: Remi Flamary <remi.flamary@unice.fr> # Kilian Fatras <kilian.fatras@irisa.fr> # Quang Huy Tran <quang-huy.tran@univ-ubs.fr> +# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr> # # License: MIT License @@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) +def test_free_support_sinkhorn_barycenter(): + measures_locations = [ + np.array([-1.]).reshape((1, 1)), # First dirac support + np.array([1.]).reshape((1, 1)) # Second dirac support + ] + + measures_weights = [ + np.array([1.]), # First dirac sample weights + np.array([1.]) # Second dirac sample weights + ] + + # Barycenter initialization + X_init = np.array([-12.]).reshape((1, 1)) + + # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter + bar_locations = np.array([0.]).reshape((1, 1)) + + # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization + # term to 1, but this should be, in general, fine-tuned to the problem. + X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) + + # Verifies if calculated barycenter matches ground-truth + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], [True, False], [True, False])) |