summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/gromov.py2
-rw-r--r--test/test_gromov.py27
2 files changed, 28 insertions, 1 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index e03fa5b..65b2e29 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -613,7 +613,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
The function solves the following optimization problem:
.. math::
- C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+ C = argmin_C\in R^{NxN} \sum_s \lambda_s GW(C,Cs,p,ps)
Where :
diff --git a/test/test_gromov.py b/test/test_gromov.py
index d865380..bb23469 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -116,3 +116,30 @@ def test_gromov_barycenter():
'kl_loss', # 5e-4,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_gromov_entropic_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', nt)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'square_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
+ [ot.unif(ns), ot.unif(nt)
+ ], ot.unif(n_samples), [.5, .5],
+ 'kl_loss', 1e-3,
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))