summaryrefslogtreecommitdiff
path: root/phstuff/barcode.py
blob: adce5e5498afb4d4f72ce7bdd9b869b2880c912e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*- coding: utf-8 -*-

import numpy as np
import matplotlib

class Interval:
    """A persistence interval running from ``birth`` to ``death``.

    A ``death`` of ``None`` marks an open-ended (infinite) interval.  The
    constructor rejects a finite ``death`` that is smaller than ``birth``.
    """

    def __init__(self, birth, death = None):
        # Only a finite interval can be inverted; an open-ended one is always valid.
        if death is not None:
            if death < birth:
                raise ValueError("Death must be at least birth.")
        self.birth = birth
        self.death = death

    def is_finite(self):
        """Return True unless the interval is open-ended (``death is None``)."""
        return not (self.death is None)

    def __str__(self):
        """Render as ``(birth, death)``, writing ``∞`` for an open end."""
        lower = str(self.birth)
        upper = str(self.death) if self.is_finite() else "∞"
        return "(" + lower + ", " + upper + ")"

    def __repr__(self):
        """Render as a constructor call, omitting ``death`` when it is None."""
        parts = [repr(self.birth)]
        if self.is_finite():
            parts.append(repr(self.death))
        return "Interval(" + ", ".join(parts) + ")"

    def size(self):
        """Return ``death - birth``, or ``float("inf")`` for an open-ended interval."""
        if not self.is_finite():
            return float("inf")
        return self.death - self.birth


def finitize(barcode, max_scale):
    """Return a new list in which every open-ended bar is capped at *max_scale*.

    Finite bars are passed through as the same objects.  Each infinite bar is
    replaced by a fresh ``Interval(bar.birth, max_scale)``.  That constructor
    raises ValueError if ``max_scale`` is below the bar's birth.  The input
    iterable is not modified.
    """
    return [
        bar if bar.is_finite() else Interval(bar.birth, max_scale)
        for bar in barcode
    ]
    
def betti_curve(barcode, min_scale, max_scale):
    """Compute the step function counting how many bars are alive at each scale.

    Parameters
    ----------
    barcode : iterable of interval objects exposing ``birth``, ``death`` and
        ``is_finite()``.
    min_scale : scale at which the curve starts (with count 0).
    max_scale : scale to which the last value is extended if the final event
        happens earlier.

    Returns
    -------
    (scales, bettis) : tuple of numpy arrays of equal length.
        ``bettis[i]`` is the number of live bars from ``scales[i]`` up to the
        next scale.  ``bettis`` has dtype int.  An empty barcode yields
        ``([min_scale, max_scale], [0, 0])``, or just ``[min_scale]``, ``[0]``
        when ``min_scale >= max_scale``.

    Notes
    -----
    Events are not clamped to ``[min_scale, max_scale]``.  Events past
    ``max_scale`` still appear in ``scales``.  NOTE(review): events before
    ``min_scale`` would make ``scales`` non-monotone.  Callers are presumably
    expected to pass ``min_scale <= every birth``; confirm.
    """
    # Each bar contributes +1 at its birth and, if finite, -1 at its death.
    events = []
    for interval in barcode:
        events.append((interval.birth, 1))
        if interval.is_finite():
            events.append((interval.death, -1))
    events.sort()

    # No special case for an empty barcode: the general path below already
    # yields a well-formed (scales, bettis) tuple with a flat zero curve.
    # This keeps the return type consistent for every input.
    scales = [min_scale]
    bettis = [0]
    for (t, delta) in events:
        if scales[-1] == t:
            # Several events at the same scale collapse into a single step.
            bettis[-1] += delta
        else:
            scales.append(t)
            bettis.append(bettis[-1] + delta)

    # Extend the last value so the curve reaches max_scale.
    if scales[-1] < max_scale:
        scales.append(max_scale)
        bettis.append(bettis[-1])

    return (np.array(scales), np.array(bettis, dtype=int))


def plot(ax, barcode, min_scale, max_scale, color = "black", finite_marker = ".", infinite_marker = "^", alpha = 1.0):
    """Draw *barcode* as a persistence diagram on the matplotlib axes *ax*.

    Drawing rules:

    * Each finite interval is plotted as the point (birth, death) using
      *finite_marker*.
    * Each infinite interval is plotted at (birth, max_scale) using
      *infinite_marker*.
    * The region below the diagonal is shaded light gray.
    * Thin black lines mark x = min_scale and y = max_scale.
    * Both axes are limited to [min_scale, max_scale] and labelled
      "Birth" / "Death".

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    barcode : iterable of interval objects exposing ``birth``, ``death`` and
        ``is_finite()``.
    min_scale, max_scale : initial limits of both axes.  *max_scale* is also
        the death coordinate used for infinite intervals.
    color, alpha : colour and transparency of the interval markers.
    finite_marker, infinite_marker : matplotlib marker styles.

    Returns
    -------
    None.

    Side effects
    ------------
    Registers 'xlim_changed' and 'ylim_changed' callbacks on ``ax.callbacks``.
    """
    # ``import matplotlib`` alone does not load these submodules.  Import them
    # explicitly instead of relying on some other module having done so.
    import matplotlib.lines
    import matplotlib.patches

    finites = []
    infinites = []
    for interval in barcode:
        if interval.is_finite():
            finites.append((interval.birth, interval.death))
        else:
            # Infinite bars have no death; pin them to the top edge (max_scale).
            infinites.append((interval.birth, max_scale))

    finites = np.array(finites)
    infinites = np.array(infinites)

    # zorder=10 presumably keeps the markers above the shading and guide lines.
    if finites.shape[0] > 0:
        ax.plot(finites[:, 0], finites[:, 1], color = color, marker = finite_marker, linestyle = "None", alpha = alpha, zorder=10)

    if infinites.shape[0] > 0:
        ax.plot(infinites[:, 0], infinites[:, 1], color = color, marker = infinite_marker, linestyle = "None", alpha = alpha, zorder=10)

    # Shade the triangle below the diagonal (death < birth).  Interval's
    # constructor rejects such bars, so nothing is plotted there.
    below_diag_patch = ax.add_patch(
        matplotlib.patches.Polygon([[min_scale, min_scale],
                                    [max_scale, min_scale],
                                    [max_scale, max_scale]],
                                   closed = True, fill = True, color = "lightgray"))
    # Guide lines: vertical at x = min_scale, horizontal at y = max_scale
    # (the height at which infinite intervals are drawn).
    ax.add_line(matplotlib.lines.Line2D([min_scale, min_scale],
                                        [min_scale, max_scale],
                                        color = "black", linewidth=0.5))
    ax.add_line(matplotlib.lines.Line2D([min_scale, max_scale],
                                        [max_scale, max_scale],
                                        color = "black", linewidth=0.5))

    def callback(cax):
        # Re-fit the shaded triangle to the current view so it still covers
        # the area below the diagonal after zooming or panning.
        (xmin, xmax) = cax.get_xlim()
        (ymin, ymax) = cax.get_ylim()
        low = min(xmin, ymin)
        high = max(xmax, ymax)
        below_diag_patch.set_xy([[low, low],
                                 [high, low],
                                 [high, high]])

    ax.callbacks.connect('xlim_changed', callback)
    ax.callbacks.connect('ylim_changed', callback)
    ax.set_xlim(min_scale, max_scale)
    ax.set_ylim(min_scale, max_scale)

    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")