summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGard Spreemann <gspreemann@gmail.com>2016-03-22 17:54:43 +0100
committerGard Spreemann <gspreemann@gmail.com>2016-03-22 17:54:43 +0100
commit2579d2d7c30465ee0ffa66371f36a656d2b4aa2e (patch)
tree282a1463072897b6d6cab2eda4c9da8a7ddfafc3
parentf95c4cd479caabf54d425b3f62901f977098abe0 (diff)
Persistence diagram plotting.
-rw-r--r--phstuff/barcode.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/phstuff/barcode.py b/phstuff/barcode.py
index 9c7b05e..9442a87 100644
--- a/phstuff/barcode.py
+++ b/phstuff/barcode.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
+import matplotlib
class Interval:
def __init__(self, birth, death = None):
@@ -66,3 +67,53 @@ def betti_curve(barcode, min_scale, max_scale):
bettis.append(bettis[-1])
return (np.array(scales), np.array(bettis, dtype=int))
+
+
+def plot(ax, barcode, min_scale, max_scale, color = "black", finite_marker = ".", infinite_marker = "^", alpha = 1.0):
+ finites = []
+ infinites = []
+
+ for interval in barcode:
+ if interval.is_finite():
+ finites.append((interval.birth, interval.death))
+ else:
+ infinites.append((interval.birth, max_scale))
+
+ finites = np.array(finites)
+ infinites = np.array(infinites)
+
+ if finites.shape[0] > 0:
+ ax.scatter(finites[:, 0], finites[:, 1], color = color, marker = finite_marker, alpha = alpha, zorder=10)
+
+ if infinites.shape[0] > 0:
+ ax.scatter(infinites[:, 0], infinites[:, 1], color = color, marker = infinite_marker, alpha = alpha, zorder=10)
+
+
+ below_diag_patch = ax.add_patch(
+ matplotlib.patches.Polygon([[min_scale, min_scale],
+ [max_scale, min_scale],
+ [max_scale, max_scale]],
+ closed = True, fill = True, color = "lightgray"))
+ min_scale_line = ax.add_line(matplotlib.lines.Line2D([min_scale, min_scale],
+ [min_scale, max_scale],
+ color = "black", linewidth=0.5))
+ max_scale_line = ax.add_line(matplotlib.lines.Line2D([min_scale, max_scale],
+ [max_scale, max_scale],
+ color = "black", linewidth=0.5))
+
+ def callback(cax):
+ (xmin, xmax) = cax.get_xlim()
+ (ymin, ymax) = cax.get_ylim()
+ low = min(xmin, ymin)
+ high = max(xmax, ymax)
+ below_diag_patch.set_xy([[low, low],
+ [high, low],
+ [high, high]])
+
+ ax.callbacks.connect('xlim_changed', callback)
+ ax.callbacks.connect('ylim_changed', callback)
+ ax.set_xlim(min_scale, max_scale)
+ ax.set_ylim(min_scale, max_scale)
+
+ ax.set_xlabel("Birth")
+ ax.set_ylabel("Death")