summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2019-12-16 17:53:59 +0100
committertlacombe <lacombe1993@gmail.com>2019-12-16 17:53:59 +0100
commit5877b4d3b7aca645ba906dfe0be598b1881d8798 (patch)
tree46e2edb8322b535a7ddb710f2924be7079580425
parentefebda596ae5a03dd0f15317ebfe74b5f19c78aa (diff)
update CMakeLists and create test_wasserstein_bary
-rw-r--r--src/python/CMakeLists.txt3
-rw-r--r--src/python/gudhi/barycenter.py26
-rwxr-xr-xsrc/python/test/test_wasserstein_barycenter.py33
3 files changed, 50 insertions, 12 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 9af85eac..7f9ff38f 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -52,6 +52,7 @@ if(PYTHONINTERP_FOUND)
# Modules that should not be auto-imported in __init__.py
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
+ set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'barycenter', ")
add_gudhi_debug_info("Python version ${PYTHON_VERSION_STRING}")
add_gudhi_debug_info("Cython version ${CYTHON_VERSION}")
@@ -210,6 +211,7 @@ if(PYTHONINTERP_FOUND)
file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/representations" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/")
file(COPY "gudhi/wasserstein.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
+ file(COPY "gudhi/barycenter.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
add_custom_command(
OUTPUT gudhi.so
@@ -385,6 +387,7 @@ if(PYTHONINTERP_FOUND)
# Wasserstein
if(OT_FOUND)
add_gudhi_py_test(test_wasserstein_distance)
+ add_gudhi_py_test(test_wasserstein_barycenter)
endif(OT_FOUND)
# Representations
diff --git a/src/python/gudhi/barycenter.py b/src/python/gudhi/barycenter.py
index b4afdb6a..41418454 100644
--- a/src/python/gudhi/barycenter.py
+++ b/src/python/gudhi/barycenter.py
@@ -293,12 +293,12 @@ def _test_perf():
def _sanity_check(verbose):
- dg1 = np.array([[0.2, 0.5]])
- dg2 = np.array([[0.2, 0.7], [0.73, 0.88]])
- dg3 = np.array([[0.3, 0.6], [0.7, 0.85], [0.2, 0.3]])
- X = [dg1, dg2, dg3]
- Y, a = lagrangian_barycenter(X, verbose=verbose)
- _plot_barycenter(X, Y, a)
+ #dg1 = np.array([[0.2, 0.5]])
+ #dg2 = np.array([[0.2, 0.7], [0.73, 0.88]])
+ #dg3 = np.array([[0.3, 0.6], [0.7, 0.85], [0.2, 0.3]])
+ #X = [dg1, dg2, dg3]
+ #Y, a = lagrangian_barycenter(X, verbose=verbose)
+ #_plot_barycenter(X, Y, a)
#dg1 = np.array([[0.2, 0.5]])
#dg2 = np.array([]) # The empty diagram
@@ -313,13 +313,15 @@ def _sanity_check(verbose):
#X = [dg1, dg2, dg3]
#Y, a = lagrangian_barycenter(X, verbose=verbose)
#_plot_barycenter(X, Y, a)
+ #print(Y)
- #dg1 = np.array([[0.1, 0.12], [0.21, 0.7], [0.4, 0.5], [0.3, 0.4], [0.35, 0.7], [0.5, 0.55], [0.32, 0.42], [0.1, 0.4], [0.2, 0.4]])
- #dg2 = np.array([[0.09, 0.11], [0.3, 0.43], [0.5, 0.61], [0.3, 0.7], [0.42, 0.5], [0.35, 0.41], [0.74, 0.9], [0.5, 0.95], [0.35, 0.45], [0.13, 0.48], [0.32, 0.45]])
- #dg3 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
- #X = [dg1, dg2, dg3]
- #Y, a = lagrangian_barycenter(X, init=1, verbose=verbose)
- #_plot_barycenter(X, Y, a)
+ dg1 = np.array([[0.1, 0.12], [0.21, 0.7], [0.4, 0.5], [0.3, 0.4], [0.35, 0.7], [0.5, 0.55], [0.32, 0.42], [0.1, 0.4], [0.2, 0.4]])
+ dg2 = np.array([[0.09, 0.11], [0.3, 0.43], [0.5, 0.61], [0.3, 0.7], [0.42, 0.5], [0.35, 0.41], [0.74, 0.9], [0.5, 0.95], [0.35, 0.45], [0.13, 0.48], [0.32, 0.45]])
+ dg3 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+ X = [dg3]
+ Y, a = lagrangian_barycenter(X, verbose=verbose)
+ _plot_barycenter(X, Y, a)
+ print(Y)
#dg1 = np.array([[0.2, 0.5]])
diff --git a/src/python/test/test_wasserstein_barycenter.py b/src/python/test/test_wasserstein_barycenter.py
new file mode 100755
index 00000000..6074f250
--- /dev/null
+++ b/src/python/test/test_wasserstein_barycenter.py
@@ -0,0 +1,33 @@
+from gudhi.barycenter import lagrangian_barycenter
+import numpy as np
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2019 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2019 Inria"
+__license__ = "MIT"
+
+
+def test_lagrangian_barycenter():
+
+ dg1 = np.array([[0.2, 0.5]])
+ dg2 = np.array([[0.2, 0.7]])
+ dg3 = np.array([[0.3, 0.6], [0.7, 0.8], [0.2, 0.3]])
+ dg4 = np.array([])
+ dg5 = np.array([])
+ dg6 = np.array([])
+ res = np.array([[0.27916667, 0.55416667], [0.7375, 0.7625], [0.2375, 0.2625]])
+
+ dg7 = np.array([[0.1, 0.15], [0.1, 0.7], [0.2, 0.22], [0.55, 0.84], [0.11, 0.91], [0.61, 0.75], [0.33, 0.46], [0.12, 0.41], [0.32, 0.48]])
+
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg1, dg2, dg3, dg4],init=3, verbose=False) - res) < 0.001
+ assert np.array_equal(lagrangian_barycenter(pdiagset=[dg4, dg5, dg6], verbose=False), np.array([]))
+ assert np.linalg.norm(lagrangian_barycenter(pdiagset=[dg7], verbose=False) - dg7) < 0.001