summaryrefslogtreecommitdiff
path: root/phstuff/barcode.py
blob: 9c7b05eecf0e2c86cbd4a1a49d16e2a117f381c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-

import numpy as np

class Interval:
    """A single bar of a barcode, running from ``birth`` to ``death``.

    ``death`` may be ``None``, in which case the bar is unbounded (infinite):
    ``is_finite()`` is False, ``size()`` is ``inf`` and the bar prints with
    ``∞`` as its upper end.

    Raises:
        ValueError: if a finite ``death`` is smaller than ``birth``.
    """

    def __init__(self, birth, death=None):
        # Only a bar with an upper end has anything to validate.
        if death is not None:
            if death < birth:
                raise ValueError("Death must be at least birth.")
        self.birth = birth
        self.death = death

    def is_finite(self):
        """Return True if the bar has a death value, False if it is unbounded."""
        return self.death is not None

    def __str__(self):
        """Return ``(birth, death)``, or ``(birth, ∞)`` for an unbounded bar."""
        if not self.is_finite():
            return "(%s, ∞)" %(str(self.birth))
        return "(%s, %s)" %(str(self.birth), str(self.death))

    def __repr__(self):
        """Return a constructor-style form, omitting the death if it is absent."""
        if not self.is_finite():
            return "Interval(%s)" %(repr(self.birth))
        return "Interval(%s, %s)" %(repr(self.birth), repr(self.death))

    def size(self):
        """Return ``death - birth``, or positive infinity for an unbounded bar."""
        if not self.is_finite():
            return float("inf")
        return self.death - self.birth


def finitize(barcode, max_scale):
    """Return a new list in which every infinite bar is cut off at *max_scale*.

    Finite bars are passed through unchanged (the same objects, not copies).
    Each infinite bar is replaced by ``Interval(bar.birth, max_scale)``, and
    constructing that interval raises ``ValueError`` if *max_scale* is below
    the bar's birth.
    """
    return [
        bar if bar.is_finite() else Interval(bar.birth, max_scale)
        for bar in barcode
    ]
    
def betti_curve(barcode, min_scale, max_scale):
    """Compute the Betti curve (number of live bars) of *barcode* as a step function.

    Every bar contributes +1 at its birth and, if finite, -1 at its death.
    All events at the same scale are merged, so ``bettis[i]`` is the count
    after applying every event at ``scales[i]``. That count holds until the
    next scale.

    Parameters
    ----------
    barcode : iterable
        Bars providing ``birth``, ``death`` and ``is_finite()``
        (e.g. ``Interval`` objects).
    min_scale : number
        Left end of the curve, always ``scales[0]``. Events earlier than
        ``min_scale`` are folded into the value at ``min_scale``.
    max_scale : number
        The curve is extended flat to ``max_scale`` if the last event lies
        before it. Events after ``max_scale`` are kept, not truncated.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``(scales, bettis)``, with ``bettis`` an integer array. An empty
        barcode yields the zero curve on ``[min_scale, max_scale]``.
    """
    events = []
    for interval in barcode:
        # Clamp to min_scale so ``scales`` stays monotone. Bars born (or dead)
        # before the window contribute to the starting value instead.
        events.append((max(interval.birth, min_scale), 1))
        if interval.is_finite():
            events.append((max(interval.death, min_scale), -1))
    events.sort()

    scales = [min_scale]
    bettis = [0]
    for t, delta in events:
        if scales[-1] == t:
            # Several events at one scale collapse into a single step.
            bettis[-1] += delta
        else:
            scales.append(t)
            bettis.append(bettis[-1] + delta)

    if scales[-1] < max_scale:
        scales.append(max_scale)
        bettis.append(bettis[-1])

    return (np.array(scales), np.array(bettis, dtype=int))