From 1783c047302414bbcd6ff4f7c73dcc5a6501fd81 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 20 Jan 2020 17:51:28 +0100 Subject: Share tests for wasserstein_distance --- src/python/test/test_wasserstein_distance.py | 59 ++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 17 deletions(-) (limited to 'src/python/test') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 43dda77e..46a7079f 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -8,41 +8,66 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein import wasserstein_distance +from gudhi.wasserstein import wasserstein_distance as pot +from gudhi.hera import wasserstein_distance as hera import numpy as np __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" - -def test_basic_wasserstein(): +def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) diag3 = np.array([[0, 2], [4, 6]]) diag4 = np.array([[0, 3], [4, 8]]) - emptydiag = np.array([[]]) + emptydiag = np.array([]) + + # We just need to handle positive numbers here + def approx(a, b): + f = 1 + delta + return a <= b*f and b <= a*f assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=2.) == 0. - assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == 2. - assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == 4. + assert approx(wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.), 2.) + assert approx(wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.), 4.) + + assert approx(wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.), 5.) # thank you Pythagorician triplets + assert approx(wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.), 2.5) + assert approx(wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.), 3.5355339059327378) + + assert approx(wasserstein_distance(diag1, diag2, internal_p=2., order=1.) , 1.4453593023967701) + assert approx(wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74), 0.9772734057168739) + + assert approx(wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863), 3.141592214572228) + + assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=1.), 3.) + assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.), 3.) # no diag matching here + assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.), np.sqrt(5)) + assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=2.), np.sqrt(5)) + assert approx(wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.), np.sqrt(5)) + + if(not test_infinity): + return - assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == 5. # thank you Pythagorician triplets - assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == 2.5 - assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == 3.5355339059327378 + diag5 = np.array([[0, 3], [4, np.inf]]) + diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) - assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == 1.4453593023967701 - assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == 0.9772734057168739 + assert wasserstein_distance(diag4, diag5) == np.inf + assert approx(wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf), 4.) - assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == 3.141592214572228 +def hera_wrap(delta): + def fun(*kargs,**kwargs): + return hera(*kargs,**kwargs,delta=delta) + return fun - assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == 3. - assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == 3. # no diag matching here - assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == np.sqrt(5) - assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == np.sqrt(5) - assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == np.sqrt(5) +def test_wasserstein_distance_pot(): + _basic_wasserstein(pot, 1e-15, False) +def test_wasserstein_distance_hera(): + _basic_wasserstein(hera_wrap(1e-12), 1e-12) + _basic_wasserstein(hera_wrap(.1), .1) -- cgit v1.2.3 From e8c908469cb4ac547d4fd46ad8daf5ee21739f58 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Thu, 6 Feb 2020 22:14:08 +0100 Subject: pytest.approx --- src/python/test/test_wasserstein_distance.py | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) (limited to 'src/python/test') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 46a7079f..6a14c50e 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -11,6 +11,7 @@ from gudhi.wasserstein import wasserstein_distance as pot from gudhi.hera import wasserstein_distance as hera import numpy as np +import pytest __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" @@ -24,32 +25,31 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): emptydiag = np.array([]) # We just need to handle positive numbers here - def approx(a, b): - f = 1 + delta - return a <= b*f and b <= a*f + def approx(x): + return pytest.approx(x, rel=delta) assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=2.) == 0. - assert approx(wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.), 2.) - assert approx(wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.), 4.) + assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == approx(2.) + assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == approx(4.) - assert approx(wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.), 5.) # thank you Pythagorician triplets - assert approx(wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.), 2.5) - assert approx(wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.), 3.5355339059327378) + assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == approx(5.) # thank you Pythagorician triplets + assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == approx(2.5) + assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == approx(3.5355339059327378) - assert approx(wasserstein_distance(diag1, diag2, internal_p=2., order=1.) , 1.4453593023967701) - assert approx(wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74), 0.9772734057168739) + assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == approx(1.4453593023967701) + assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == approx(0.9772734057168739) - assert approx(wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863), 3.141592214572228) + assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == approx(3.141592214572228) - assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=1.), 3.) - assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.), 3.) # no diag matching here - assert approx(wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.), np.sqrt(5)) - assert approx(wasserstein_distance(diag3, diag4, internal_p=1., order=2.), np.sqrt(5)) - assert approx(wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.), np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == approx(3.) + assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == approx(3.) # no diag matching here + assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == approx(np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5)) if(not test_infinity): return @@ -58,7 +58,7 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) assert wasserstein_distance(diag4, diag5) == np.inf - assert approx(wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf), 4.) + assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) def hera_wrap(delta): def fun(*kargs,**kwargs): -- cgit v1.2.3 From 458ee3e95c752f09058d933349851c8a3a730cad Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Fri, 7 Feb 2020 19:41:38 +0100 Subject: Name argument MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Théo Lacombe --- src/python/test/test_wasserstein_distance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src/python/test') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 6a14c50e..4bc7114e 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -66,7 +66,7 @@ def hera_wrap(delta): return fun def test_wasserstein_distance_pot(): - _basic_wasserstein(pot, 1e-15, False) + _basic_wasserstein(pot, 1e-15, test_infinity=False) def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(1e-12), 1e-12) -- cgit v1.2.3 From d6f3165831d20bf3a91f1ff7e9734a574eaa567a Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Tue, 11 Feb 2020 13:06:48 +0100 Subject: License and author --- src/python/gudhi/hera.cc | 13 +++++++++++-- src/python/test/test_wasserstein_distance.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) (limited to 'src/python/test') diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc index 61f0da10..0d562b4c 100644 --- a/src/python/gudhi/hera.cc +++ b/src/python/gudhi/hera.cc @@ -1,9 +1,19 @@ +/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + * Author(s): Marc Glisse + * + * Copyright (C) 2020 Inria + * + * Modification(s): + * - YYYY/MM Author: Description of the modification + */ + #include #include #include -#include +#include // Hera #include @@ -41,7 +51,6 @@ double wasserstein_distance( PYBIND11_MODULE(hera, m) { m.def("wasserstein_distance", &wasserstein_distance, py::arg("X"), py::arg("Y"), - // Should we name those q, p and d instead? py::arg("order") = 1, py::arg("internal_p") = std::numeric_limits::infinity(), py::arg("delta") = .01, diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 4bc7114e..6a6b217b 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -1,6 +1,6 @@ """ This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. - Author(s): Theo Lacombe + Author(s): Theo Lacombe, Marc Glisse Copyright (C) 2019 Inria -- cgit v1.2.3