summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 112bfca..e128ea2 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -3,6 +3,7 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License
@@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+def test_free_support_sinkhorn_barycenter():
+ measures_locations = [
+ np.array([-1.]).reshape((1, 1)), # First dirac support
+ np.array([1.]).reshape((1, 1)) # Second dirac support
+ ]
+
+ measures_weights = [
+ np.array([1.]), # First dirac sample weights
+ np.array([1.]) # Second dirac sample weights
+ ]
+
+ # Barycenter initialization
+ X_init = np.array([-12.]).reshape((1, 1))
+
+ # Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
+ bar_locations = np.array([0.]).reshape((1, 1))
+
+ # Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
+ # term to 1, but this should be, in general, fine-tuned to the problem.
+ X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)
+
+ # Verifies if calculated barycenter matches ground-truth
+ np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+
@pytest.mark.parametrize("method, verbose, warn",
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
[True, False], [True, False]))