# -*- coding: utf-8 -*-
import numpy as np
import matplotlib
# Importing a package does not import its submodules; plot() uses both of
# these, so bring them in explicitly rather than relying on another import.
import matplotlib.lines
import matplotlib.patches


class Interval:
    """A persistence interval (bar) with a birth value and an optional death.

    A ``death`` of ``None`` denotes an infinite interval.

    Raises:
        ValueError: if ``death`` is given and is smaller than ``birth``.
    """

    def __init__(self, birth, death = None):
        if death is not None and death < birth:
            raise ValueError("Death must be at least birth.")
        self.birth = birth
        self.death = death

    def is_finite(self):
        """Return True if the interval has a death value (is not infinite)."""
        return self.death is not None

    def __str__(self):
        # Infinite intervals are rendered with an explicit infinity symbol.
        if self.is_finite():
            return "(%s, %s)" %(str(self.birth), str(self.death))
        else:
            return "(%s, ∞)" %(str(self.birth))

    def __repr__(self):
        # Mirrors the constructor call: the death argument is omitted when infinite.
        if self.is_finite():
            return "Interval(%s, %s)" %(repr(self.birth), repr(self.death))
        else:
            return "Interval(%s)" %(repr(self.birth))

    def size(self):
        """Return ``death - birth``, or ``float("inf")`` for an infinite interval."""
        if self.is_finite():
            return self.death - self.birth
        else:
            return float("inf")


def finitize(barcode, max_scale):
    """Return a new list where every infinite interval is cut off at ``max_scale``.

    Finite intervals are passed through unchanged (the same objects). Infinite
    intervals are replaced by ``Interval(bar.birth, max_scale)``, which raises
    ValueError if ``max_scale`` is smaller than that interval's birth.
    """
    ret = []
    for bar in barcode:
        if bar.is_finite():
            ret.append(bar)
        else:
            ret.append(Interval(bar.birth, max_scale))
    return ret


def betti_curve(barcode, min_scale, max_scale):
    """Compute the Betti curve (number of alive intervals) of ``barcode``.

    Returns a tuple ``(scales, bettis)`` of numpy arrays describing a step
    function. ``scales[0] == min_scale``, and ``bettis[i]`` is the number of
    intervals alive from ``scales[i]`` up to the next scale. An interval
    counts as alive from its birth (inclusive) to its death (exclusive).
    ``max_scale`` is appended as the final scale, carrying the last value,
    when the last event happens before it. ``bettis`` has integer dtype.

    An empty barcode yields ``([min_scale, max_scale], [0, 0])``, or just
    ``([min_scale], [0])`` when ``min_scale >= max_scale``.
    """
    # Each birth is a +1 event and each finite death a -1 event.
    events = []
    for interval in barcode:
        events.append((interval.birth, 1))
        if interval.is_finite():
            events.append((interval.death, -1))
    # Tuples sort by scale first; at equal scales deaths (-1) precede births (+1).
    events.sort()

    scales = [min_scale]
    bettis = [0]
    for (t, delta) in events:
        # Events before min_scale only change the starting value. Appending
        # them as-is would make `scales` non-monotonic.
        t = max(t, min_scale)
        if scales[-1] == t:
            # Merge simultaneous events into a single step.
            bettis[-1] += delta
        else:
            scales.append(t)
            bettis.append(bettis[-1] + delta)
    if scales[-1] < max_scale:
        scales.append(max_scale)
        bettis.append(bettis[-1])
    return (np.array(scales), np.array(bettis, dtype=int))


def plot(ax, barcode, min_scale, max_scale, color = "black", finite_marker = ".", infinite_marker = "^", alpha = 1.0):
    """Draw ``barcode`` as a persistence diagram on the matplotlib axes ``ax``.

    Finite intervals are scattered at ``(birth, death)`` using
    ``finite_marker``. Infinite intervals are scattered at
    ``(birth, max_scale)`` using ``infinite_marker``. The triangle below the
    diagonal is shaded light gray and is kept fitted to the view whenever the
    x/y limits change. Thin black lines are drawn at ``x = min_scale`` and
    ``y = max_scale``. Both axes are limited to ``[min_scale, max_scale]``
    and labelled "Birth" and "Death".
    """
    finites = []
    infinites = []
    for interval in barcode:
        if interval.is_finite():
            finites.append((interval.birth, interval.death))
        else:
            infinites.append((interval.birth, max_scale))
    finites = np.array(finites)
    infinites = np.array(infinites)
    # An empty list becomes a shape-(0,) array, so guard before 2-D indexing.
    if finites.shape[0] > 0:
        ax.scatter(finites[:, 0], finites[:, 1], color = color, marker = finite_marker, alpha = alpha, zorder=10)
    if infinites.shape[0] > 0:
        ax.scatter(infinites[:, 0], infinites[:, 1], color = color, marker = infinite_marker, alpha = alpha, zorder=10)

    below_diag_patch = ax.add_patch(
        matplotlib.patches.Polygon([[min_scale, min_scale], [max_scale, min_scale], [max_scale, max_scale]],
                                   closed = True, fill = True, color = "lightgray"))
    # Vertical reference line at x = min_scale.
    ax.add_line(matplotlib.lines.Line2D([min_scale, min_scale], [min_scale, max_scale], color = "black", linewidth=0.5))
    # Horizontal reference line at y = max_scale.
    ax.add_line(matplotlib.lines.Line2D([min_scale, max_scale], [max_scale, max_scale], color = "black", linewidth=0.5))

    def callback(cax):
        # Re-fit the shaded triangle to the current view so the
        # below-diagonal region stays covered after panning or zooming.
        (xmin, xmax) = cax.get_xlim()
        (ymin, ymax) = cax.get_ylim()
        low = min(xmin, ymin)
        high = max(xmax, ymax)
        below_diag_patch.set_xy([[low, low], [high, low], [high, high]])

    ax.callbacks.connect('xlim_changed', callback)
    ax.callbacks.connect('ylim_changed', callback)
    ax.set_xlim(min_scale, max_scale)
    ax.set_ylim(min_scale, max_scale)
    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")