summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authortlacombe <lacombe1993@gmail.com>2019-09-23 11:14:24 +0200
committertlacombe <lacombe1993@gmail.com>2019-09-23 11:14:24 +0200
commit1b007fc59f08bd01e1521eb1c0773b598bdf158b (patch)
tree6a5169be2dd87c28a27cd0054a8d0cd3b251c178 /src
parent36d82a6ffe7c099da9241f7268637feaeef6bf55 (diff)
wasserstein distance added on fork
Diffstat (limited to 'src')
-rw-r--r--src/python/doc/wasserstein_distance_sum.inc14
-rw-r--r--src/python/doc/wasserstein_distance_user.rst39
-rw-r--r--src/python/gudhi/wasserstein.py75
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py22
4 files changed, 150 insertions, 0 deletions
diff --git a/src/python/doc/wasserstein_distance_sum.inc b/src/python/doc/wasserstein_distance_sum.inc
new file mode 100644
index 00000000..0263f80f
--- /dev/null
+++ b/src/python/doc/wasserstein_distance_sum.inc
@@ -0,0 +1,14 @@
+.. table::
+ :widths: 30 50 20
+
+ +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
+ | .. figure:: | The p-Wasserstein distance measures the similarity between two | :Author: Theo Lacombe |
+ | ../../doc/Bottleneck_distance/perturb_pd.png | persistence diagrams. It's the minimum value c that can be achieve by| |
+ | :figclass: align-center | a perfect matching between the points of the two diagrams (+ all the | :Introduced in: GUDHI 2.0.0 |
+ | | diagonal points), where the value of a matching is defined as the | |
+ | Wasserstein distance is the p-th root of the sum of the | p-th root of the sum of all edges lengths to the power p. Edges | :Copyright: MIT (`GPL v3 </licensing/>`_) |
+ | edges lengths to the power p. | lengths are measured in norm q, for $1 \leq q \leq \infty$. | |
+ | | | :Requires: `Python Optimal Transport (POT)` |
+ +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+
+ | * :doc:`wasserstein_distance_user` | |
+ +-----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst
new file mode 100644
index 00000000..a51cfb71
--- /dev/null
+++ b/src/python/doc/wasserstein_distance_user.rst
@@ -0,0 +1,39 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+Wasserstein distance user manual
+===============================
+Definition
+----------
+
+.. include:: wasserstein_distance_sum.inc
+
+This implementation is based on ideas from "Large Scale Computation of Means and Cluster for Persistence Diagrams via Optimal Transport".
+
+Function
+--------
+.. autofunction:: gudhi.wasserstein_distance
+
+
+Basic example
+-------------
+
+This example computes the 1-Wasserstein distance from 2 persistence diagrams with euclidean ground metric.
+Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values.
+
+.. testcode::
+
+ import gudhi
+
+ diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]])
+ diag2 = np.array([[2.8, 4.45],[9.5, 14.1]])
+
+ message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein_distance(diag1, diag2, q=2., p=1.)
+ print(message)
+
+The output is:
+
+.. testoutput::
+
+ Wasserstein distance value = 1.45
diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py
new file mode 100644
index 00000000..cc527ed8
--- /dev/null
+++ b/src/python/gudhi/wasserstein.py
@@ -0,0 +1,75 @@
+import numpy as np
+import scipy.spatial.distance as sc
+try:
+ import ot
+except ImportError:
+ print("POT (Python Optimal Transport) package is not installed. Try to run $ pip install POT")
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2016 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+def proj_on_diag(X):
+ '''
+ param X: (n x 2) array encoding the points of a persistent diagram.
+ return: (n x 2) arary encoding the (respective orthogonal) projections of the points onto the diagonal
+ '''
+ Z = (X[:,0] + X[:,1]) / 2.
+ return np.array([Z , Z]).T
+
+
+def build_dist_matrix(X, Y, p=2., q=2.):
+ '''
+ param X: (n x 2) np.array encoding the (points of the) first diagram.
+ param Y: (m x 2) np.array encoding the second diagram.
+ param q: Ground metric (i.e. norm l_q).
+ param p: exponent for the Wasserstein metric.
+ return: (n+1) x (m+1) np.array encoding the cost matrix C.
+ For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal.
+ note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal).
+ '''
+ Xdiag = proj_on_diag(X)
+ Ydiag = proj_on_diag(Y)
+ if np.isinf(p):
+ C = sc.cdist(X,Y, metric='chebyshev', p=q)**p
+ Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p
+ Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p
+ else:
+ C = sc.cdist(X,Y, metric='minkowski', p=q)**p
+ Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p
+ Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p
+ Cf = np.hstack((C, Cxd[:,None]))
+ Cdy = np.append(Cdy, 0)
+
+ Cf = np.vstack((Cf, Cdy[None,:]))
+
+ return Cf
+
+
+def wasserstein_distance(X, Y, p=2., q=2.):
+ '''
+ param X, Y: (n x 2) and (m x 2) numpy array (points of persistence diagrams)
+ param q: Ground metric (i.e. norm l_q); Default value is 2 (euclidean norm).
+ param p: exponent for Wasserstein; Default value is 2.
+ return: float, the p-Wasserstein distance (1 <= p < infty) with respect to the q-norm as ground metric.
+ '''
+ M = build_dist_matrix(X, Y, p=p, q=q)
+ n = len(X)
+ m = len(Y)
+ a = 1.0 / (n + m) * np.ones(n) # weight vector of the input diagram. Uniform here.
+ hat_a = np.append(a, m/(n+m)) # so that we have a probability measure, required by POT
+ b = 1.0 / (n + m) * np.ones(m) # weight vector of the input diagram. Uniform here.
+ hat_b = np.append(b, n/(m+n)) # so that we have a probability measure, required by POT
+
+ # Comptuation of the otcost using the ot.emd2 library.
+ # Note: it is the squared Wasserstein distance.
+ ot_cost = (n+m) * ot.emd2(hat_a, hat_b, M)
+
+ return np.power(ot_cost, 1./p)
+
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
new file mode 100755
index 00000000..a5f7cf77
--- /dev/null
+++ b/src/python/test/test_wasserstein_distance.py
@@ -0,0 +1,22 @@
+import gudhi
+
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Theo Lacombe
+
+ Copyright (C) 2016 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+__author__ = "Theo Lacombe"
+__copyright__ = "Copyright (C) 2016 Inria"
+__license__ = "MIT"
+
+
+def test_basic_bottleneck():
+ diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]])
+ diag2 = np.array([[2.8, 4.45], [9.5, 14.1]])
+
+ assert gudhi.wasserstein_distance(diag1, diag2) == 1.4453593023967701