summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authorVincent Rouvreau <10407034+VincentRouvreau@users.noreply.github.com>2019-11-04 17:41:48 +0100
committerGitHub <noreply@github.com>2019-11-04 17:41:48 +0100
commit6e5f3f2c5ed908774c9005fa3ba07694bb2c6b0c (patch)
tree3e9242cf413e1ca63c258dd704ca04049fccf7a8 /src/python/test/test_wasserstein_distance.py
parent8e7fabec7a8b79b8f0248ec580e4cd7950f9cec1 (diff)
parentee4934750e8c9dbdee4874d56921aeb9bf7b7bb7 (diff)
Merge pull request #95 from tlacombe/wdist-theo
wasserstein distance added on fork
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
new file mode 100755
index 00000000..c1b568e2
--- /dev/null
+++ b/src/python/test/test_wasserstein_distance.py
@@ -0,0 +1,50 @@
+import gudhi
+import numpy as np
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2019 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2019 Inria"
+__license__ = "MIT"
+
+
+def test_basic_wasserstein():
+ diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
+ diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
+ diag3 = np.array([[0, 2], [4, 6]])
+ diag4 = np.array([[0, 3], [4, 8]])
+ emptydiag = np.array([[]])
+
+ assert gudhi.wasserstein_distance(emptydiag, emptydiag, q=2., p=1.) == 0.
+ assert gudhi.wasserstein_distance(emptydiag, emptydiag, q=np.inf, p=1.) == 0.
+ assert gudhi.wasserstein_distance(emptydiag, emptydiag, q=np.inf, p=2.) == 0.
+ assert gudhi.wasserstein_distance(emptydiag, emptydiag, q=2., p=2.) == 0.
+
+ assert gudhi.wasserstein_distance(diag3, emptydiag, q=np.inf, p=1.) == 2.
+ assert gudhi.wasserstein_distance(diag3, emptydiag, q=1., p=1.) == 4.
+
+ assert gudhi.wasserstein_distance(diag4, emptydiag, q=1., p=2.) == 5. # thank you Pythagorician triplets
+ assert gudhi.wasserstein_distance(diag4, emptydiag, q=np.inf, p=2.) == 2.5
+ assert gudhi.wasserstein_distance(diag4, emptydiag, q=2., p=2.) == 3.5355339059327378
+
+ assert gudhi.wasserstein_distance(diag1, diag2, q=2., p=1.) == 1.4453593023967701
+ assert gudhi.wasserstein_distance(diag1, diag2, q=2.35, p=1.74) == 0.9772734057168739
+
+ assert gudhi.wasserstein_distance(diag1, emptydiag, q=2.35, p=1.7863) == 3.141592214572228
+
+ assert gudhi.wasserstein_distance(diag3, diag4, q=1., p=1.) == 3.
+ assert gudhi.wasserstein_distance(diag3, diag4, q=np.inf, p=1.) == 3. # no diag matching here
+ assert gudhi.wasserstein_distance(diag3, diag4, q=np.inf, p=2.) == np.sqrt(5)
+ assert gudhi.wasserstein_distance(diag3, diag4, q=1., p=2.) == np.sqrt(5)
+ assert gudhi.wasserstein_distance(diag3, diag4, q=4.5, p=2.) == np.sqrt(5)
+
+
+