From b23813b90aaf1b0ce2b21bdfb33d2a6ea5bfe4cc Mon Sep 17 00:00:00 2001 From: tlacombe Date: Mon, 16 Dec 2019 19:32:26 +0100 Subject: correction test --- src/python/gudhi/barycenter.py | 6 ++++-- src/python/test/test_wasserstein_barycenter.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py index 41418454..b76166c0 100644 --- a/src/python/gudhi/barycenter.py +++ b/src/python/gudhi/barycenter.py @@ -318,10 +318,12 @@ def _sanity_check(verbose): dg1 = np.array([[0.1, 0.12], [0.21, 0.7], [0.4, 0.5], [0.3, 0.4], [0.35, 0.7], [0.5, 0.55], [0.32, 0.42], [0.1, 0.4], [0.2, 0.4]]) dg2 = np.array([[0.09, 0.11], [0.3, 0.43], [0.5, 0.61], [0.3, 0.7], [0.42, 0.5], [0.35, 0.41], [0.74, 0.9], [0.5, 0.95], [0.35, 0.45], [0.13, 0.48], [0.32, 0.45]]) dg3 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]]) - X = [dg3] + dg4 = np.array([]) + X = [dg4] Y, a = lagrangian_barycenter(X, verbose=verbose) - _plot_barycenter(X, Y, a) + #_plot_barycenter(X, Y, a) print(Y) + print(np.array_equal(Y, np.empty(shape=(0,2) ))) #dg1 = np.array([[0.2, 0.5]]) diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py index dc82a57c..910d23ff 100755 --- a/src/python/test/test_wasserstein_barycenter.py +++ b/src/python/test/test_wasserstein_barycenter.py @@ -29,5 +29,5 @@ def test_lagrangian_barycenter(): dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]]) assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < 0.001 - assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.array([], shape=(0,2))) + assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.empty(shape=(0,2))) assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < 0.001 -- cgit v1.2.3