diff options
Diffstat (limited to 'src/python/test')
-rwxr-xr-x | src/python/test/test_wasserstein_barycenter.py | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py index 910d23ff..07242582 100755 --- a/src/python/test/test_wasserstein_barycenter.py +++ b/src/python/test/test_wasserstein_barycenter.py @@ -27,7 +27,18 @@ def test_lagrangian_barycenter(): res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]]) dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]]) + dg8 = np.array([[0., 4.]]) + + # error crit. + eps = 0.000001 - assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < 0.001 + + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < eps assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2))) - assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < 0.001 + assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < eps + Y, log = lagrangian_barycenter(pdiagset=[dg4, dg8], verbose=True) + assert np.linalg.norm(Y - np.array([[1,3]])) < eps + assert np.abs(log["energy"] - 2) < eps + assert np.array_equal(log["groupings"][0] , np.array([[0, -1]])) + assert np.array_equal(log["groupings"][1] , np.array([[0, 0]])) + assert lagrangian_barycenter(pdiagset = []) is None |