summaryrefslogtreecommitdiff
path: root/test/test_generic_interfaces.py
blob: 7f080671a7e1604e8aef5cb301613d9c4ba09d5c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
""" test_generic_interface.py

Tests the generic interfaces of the profile and distance functions

Copyright 2016, Mario Mulansky <mario.mulansky@gmx.net>

Distributed under the BSD License

"""

from __future__ import print_function
from numpy.testing import assert_equal

import pyspike as spk
from pyspike import SpikeTrain


class dist_from_prof:
    """ Adapter that gives a profile function the calling interface of a
    distance function.

    Calling an instance computes the profile from the given arguments and
    reduces it with the profile's ``avrg()`` method. An ``interval`` keyword
    argument, when present, is not handed to the profile function; it is
    forwarded to ``avrg(interval=...)`` instead.
    """
    # Private marker that separates "no interval keyword was given" from an
    # explicitly passed value (which may itself be None).
    _NO_INTERVAL = object()

    def __init__(self, prof_func):
        # profile function; its return value must provide an avrg() method
        self.prof_func = prof_func

    def __call__(self, *args, **kwargs):
        # strip the interval before the profile function sees the kwargs
        interval = kwargs.pop("interval", self._NO_INTERVAL)
        profile = self.prof_func(*args, **kwargs)
        if interval is self._NO_INTERVAL:
            return profile.avrg()
        return profile.avrg(interval=interval)


def check_func(dist_func):
    """ Assert that dist_func gives the same result for every calling
    convention of the generic interface.

    For the spike-train selections (0, 1) and (0, 1, 2), three calls must
    agree:

    * the trains passed as separate positional arguments,
    * the trains passed as a single list,
    * the full list of trains together with ``indices``.

    The whole check runs twice: once plain, and once with
    ``interval=[0.0, 0.5]`` added to every call.
    """
    def call_kwargs(with_interval):
        # Build a fresh keyword dict, including a fresh interval list, for
        # each call, just as separate literal arguments would.
        if with_interval:
            return {"interval": [0.0, 0.5]}
        return {}

    # generate spike trains:
    pool = [SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0),
            SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0),
            SpikeTrain([0.2, 0.4, 0.6], 1.0),
            SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0)]

    for with_interval in (False, True):
        for selection in ((0, 1), (0, 1, 2)):
            picked = [pool[i] for i in selection]
            reference = dist_func(*picked, **call_kwargs(with_interval))
            assert_equal(reference,
                         dist_func(picked, **call_kwargs(with_interval)))
            assert_equal(reference,
                         dist_func(pool, indices=list(selection),
                                   **call_kwargs(with_interval)))


def test_isi_profile():
    """ The ISI profile, averaged via avrg(), follows the generic interface. """
    averaged_profile = dist_from_prof(spk.isi_profile)
    check_func(averaged_profile)


def test_isi_distance():
    """ spk.isi_distance follows the generic calling interface. """
    distance = spk.isi_distance
    check_func(distance)


def test_spike_profile():
    """ The SPIKE profile, averaged via avrg(), follows the generic
    interface.
    """
    averaged_profile = dist_from_prof(spk.spike_profile)
    check_func(averaged_profile)


def test_spike_distance():
    """ spk.spike_distance follows the generic calling interface. """
    distance = spk.spike_distance
    check_func(distance)


def test_spike_sync_profile():
    """ The SPIKE-Sync profile, averaged via avrg(), follows the generic
    interface.
    """
    averaged_profile = dist_from_prof(spk.spike_sync_profile)
    check_func(averaged_profile)


def test_spike_sync():
    """ spk.spike_sync follows the generic calling interface. """
    measure = spk.spike_sync
    check_func(measure)


if __name__ == "__main__":
    # When executed as a script, run every generic-interface test in order.
    for run_test in (test_isi_profile, test_isi_distance,
                     test_spike_profile, test_spike_distance,
                     test_spike_sync_profile, test_spike_sync):
        run_test()