summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_barycenter.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-02-14 14:53:51 +0100
committertlacombe <lacombe1993@gmail.com>2020-02-14 14:53:51 +0100
commitdc4442bc402ac25290eb529b57407607434bb7ae (patch)
tree6b6dee70d9a6804dfdafda1448ea8872189ab751 /src/python/test/test_wasserstein_barycenter.py
parent5ad0b45a5b5387f49fe72843a74e96a346c6d6fe (diff)
barycenter update, adding more tests and details about log (assigments, cost, nb iter)
Diffstat (limited to 'src/python/test/test_wasserstein_barycenter.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py15
1 files changed, 13 insertions, 2 deletions
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
index 910d23ff..07242582 100755
--- a/src/python/test/test_wasserstein_barycenter.py
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -27,7 +27,18 @@ def test_lagrangian_barycenter():
res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+ dg8 = np.array([[0., 4.]])
+
+ # error crit.
+ eps = 0.000001
- assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < 0.001
+
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps
assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
- assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < 0.001
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
+ Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
+ assert np.linalg.norm(Y - np.array([[1,3]])) < eps
+ assert np.abs(log["energy"] - 2) < eps
+ assert np.array_equal(log["groupings"][0] , np.array([[0, -1]]))
+ assert np.array_equal(log["groupings"][1] , np.array([[0, 0]]))
+ assert lagrangian_barycenter(pdiagset = []) is None