From 818c7ace20da36d8042b0d7ad7a712b27f7afd59 Mon Sep 17 00:00:00 2001 From: Eduardo Fernandes Montesuma Date: Wed, 27 Jul 2022 11:16:14 +0200 Subject: [MRG] Free support Sinkhorn barycenters (#387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding function for computing Sinkhorn Free Support barycenters * Adding exampel on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to REALEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: RĂ©mi Flamary --- test/test_bregman.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) (limited to 'test') diff --git a/test/test_bregman.py b/test/test_bregman.py index 112bfca..e128ea2 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -3,6 +3,7 @@ # Author: Remi Flamary # Kilian Fatras # Quang Huy Tran +# Eduardo Fernandes Montesuma # # License: MIT License @@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) +def test_free_support_sinkhorn_barycenter(): + measures_locations = [ + np.array([-1.]).reshape((1, 1)), # First dirac support + np.array([1.]).reshape((1, 1)) # Second dirac support + ] + + measures_weights = [ + np.array([1.]), # First dirac sample weights + np.array([1.]) # Second dirac sample weights + ] + + # Barycenter initialization + X_init = np.array([-12.]).reshape((1, 1)) + + # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter + bar_locations = np.array([0.]).reshape((1, 1)) + + # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization + # term to 1, but this should be, in general, fine-tuned to the problem. + X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) + + # Verifies if calculated barycenter matches ground-truth + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], [True, False], [True, False])) -- cgit v1.2.3