summaryrefslogtreecommitdiff
path: root/test/test_spikes.py
blob: 579f8e1489995af1cf485a32ab01ce34c09c3214 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
""" test_load.py

Test loading of spike trains from text files

Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>

Distributed under the BSD License
"""

from __future__ import print_function
import numpy as np
from numpy.testing import assert_allclose

import pyspike as spk

import os
TEST_PATH = os.path.dirname(os.path.realpath(__file__))
TEST_DATA = os.path.join(TEST_PATH, "PySpike_testdata.txt")

TIME_SERIES_DATA = os.path.join(TEST_PATH, "time_series.txt")
TIME_SERIES_SPIKES = os.path.join(TEST_PATH, "time_series_spike_trains.txt")


def test_load_from_txt():
    spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))
    assert len(spike_trains) == 40

    # check the first spike train
    spike_times = [64.886, 305.81, 696, 937.77, 1059.7, 1322.2, 1576.1,
                   1808.1, 2121.5, 2381.1, 2728.6, 2966.9, 3223.7, 3473.7,
                   3644.3, 3936.3]
    assert_allclose(spike_times, spike_trains[0].spikes)

    # check auxiliary spikes
    for spike_train in spike_trains:
        assert spike_train.t_start == 0.0
        assert spike_train.t_end == 4000


def test_load_time_series():
    spike_trains = spk.import_spike_trains_from_time_series(TIME_SERIES_DATA,
                                                            start_time=0,
                                                            time_bin=1)
    assert len(spike_trains) == 40
    spike_trains_check = spk.load_spike_trains_from_txt(TIME_SERIES_SPIKES,
                                                        edges=(0, 4000))

    # check spike trains
    for n in range(len(spike_trains)):
        assert_allclose(spike_trains[n].spikes, spike_trains_check[n].spikes)
        assert_allclose(spike_trains[n].t_start, 0)
        assert_allclose(spike_trains[n].t_end, 4000)


def check_merged_spikes(merged_spikes, spike_trains):
    # create a flat array with all spike events
    all_spikes = np.array([])
    for spike_train in spike_trains:
        all_spikes = np.append(all_spikes, spike_train)
    indices = np.zeros_like(all_spikes, dtype='bool')
    # check if we find all the spike events in the original spike trains
    for x in merged_spikes:
        i = np.where(all_spikes == x)[0][0]  # first axis and first entry
        # change to something impossible so we dont find this event again
        all_spikes[i] = -1.0
        indices[i] = True
    assert indices.all()


def test_merge_spike_trains():
    # first load the data
    spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))

    merged_spikes = spk.merge_spike_trains([spike_trains[0], spike_trains[1]])
    # test if result is sorted
    assert((merged_spikes.spikes == np.sort(merged_spikes.spikes)).all())
    # check merging
    check_merged_spikes(merged_spikes.spikes, [spike_trains[0].spikes,
                                               spike_trains[1].spikes])

    merged_spikes = spk.merge_spike_trains(spike_trains)
    # test if result is sorted
    assert((merged_spikes.spikes == np.sort(merged_spikes.spikes)).all())
    # check merging
    check_merged_spikes(merged_spikes.spikes,
                        [st.spikes for st in spike_trains])

def test_merge_empty_spike_trains():
    # first load the data
    spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))
    # take two non-empty trains, and one empty one
    empty = spk.SpikeTrain([],[spike_trains[0].t_start,spike_trains[0].t_end])
    merged_spikes = spk.merge_spike_trains([spike_trains[0], empty, spike_trains[1]])
    # test if result is sorted
    assert((merged_spikes.spikes == np.sort(merged_spikes.spikes)).all())
    # we don't need to check more, that's done by test_merge_spike_trains


if __name__ == "main":
    test_load_from_txt()
    test_merge_spike_trains()