summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_barycenter.py
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2020-03-16 13:16:04 +0100
committertlacombe <lacombe1993@gmail.com>2020-03-16 13:16:04 +0100
commit4b546a43fe14178dcfb2b327e27a580fc9811499 (patch)
tree42376206ef6750113f76a828d1dfeb7b3bfc48eb /src/python/test/test_wasserstein_barycenter.py
parent16e04f4f14cc73956064aef283e29134f00d306c (diff)
update doc (indentation, mention of -1 for the diag) and added a few more tests
Diffstat (limited to 'src/python/test/test_wasserstein_barycenter.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
index a58a4d62..5167cb84 100755
--- a/src/python/test/test_wasserstein_barycenter.py
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -27,19 +27,20 @@ def test_lagrangian_barycenter():
res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
- dg8 = np.array([[0., 4.]])
+ dg8 = np.array([[0., 4.], [4, 8]])
# error crit.
- eps = 0.000001
+ eps = 1e-7
assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps
assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2)))
assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps
Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True)
- assert np.linalg.norm(Y - np.array([[1,3]])) < eps
- assert np.abs(log["energy"] - 2) < eps
- assert np.array_equal(log["groupings"][0] , np.array([[0, -1]]))
- assert np.array_equal(log["groupings"][1] , np.array([[0, 0]]))
- assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3]])) < eps
+ assert np.linalg.norm(Y - np.array([[1,3], [5, 7]])) < eps
+ assert np.abs(log["energy"] - 4) < eps
+ assert np.array_equal(log["groupings"][0] , np.array([[0, -1], [1, -1]]))
+ assert np.array_equal(log["groupings"][1] , np.array([[0, 0], [1, 1]]))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg8, dg4], init=np.array([[0.2, 0.6], [0.5, 0.7]]), verbose=False) - np.array([[1, 3], [5, 7]])) < eps
assert lagrangian_barycenter(pdiagset = []) is None
+