summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorEduardo Fernandes Montesuma <edumontesuma@gmail.com>2022-07-27 11:16:14 +0200
committerGitHub <noreply@github.com>2022-07-27 11:16:14 +0200
commit818c7ace20da36d8042b0d7ad7a712b27f7afd59 (patch)
tree58dd4e0c9f990ea0c851712d85748de99ce6b236 /test
parent7c2a9523747c90aebfef711fdf34b5bbdb6f2f4d (diff)
[MRG] Free support Sinkhorn barycenters (#387)
* Adding function for computing Sinkhorn Free Support barycenters * Adding exampel on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to REALEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_bregman.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 112bfca..e128ea2 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -3,6 +3,7 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+def test_free_support_sinkhorn_barycenter():
+ measures_locations = [
+ np.array([-1.]).reshape((1, 1)), # First dirac support
+ np.array([1.]).reshape((1, 1)) # Second dirac support
+ ]
+
+ measures_weights = [
+ np.array([1.]), # First dirac sample weights
+ np.array([1.]) # Second dirac sample weights
+ ]
+
+ # Barycenter initialization
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
+ # term to 1, but this should be, in general, fine-tuned to the problem.
+ X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)
+
+ # Verifies if calculated barycenter matches ground-truth
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
@pytest.mark.parametrize("method, verbose, warn",
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
[True, False], [True, False]))