summaryrefslogtreecommitdiff
path: root/test/test_merge_spikes.py
blob: 3162700d253142982d112302cb854040a4c0a42a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
""" test_merge_spikes.py

Tests merging spikes

Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>
"""
from __future__ import print_function
import numpy as np

import pyspike as spk

def check_merged_spikes( merged_spikes, spike_trains ):
    """ Verify that every event in merged_spikes corresponds to exactly one
    event of the original spike_trains (duplicates included).

    Each merged event consumes one matching slot in the flattened original
    events; the final assertion fires if any original event was never matched.
    """
    # flatten all original trains into a single event array
    all_spikes = np.array([])
    for train in spike_trains:
        all_spikes = np.append(all_spikes, train)
    matched = np.zeros_like(all_spikes, dtype='bool')
    # every merged event must be found among the (not yet consumed) originals
    for event in merged_spikes:
        hits = np.where(all_spikes == event)[0]
        first = hits[0]
        # invalidate the consumed slot so duplicate events match distinct slots
        all_spikes[first] = -1.0
        matched[first] = True
    assert matched.all()

def test_merge_spike_trains():
    """ Load the SPIKY test data, merge pairs / all spike trains and verify
    the merged result is sorted and contains exactly the original events.
    """
    # first load the data; use a context manager so the file handle is
    # always closed (the original left it open)
    spike_trains = []
    with open("SPIKY_testdata.txt", 'r') as spike_file:
        for line in spike_file:
            spike_trains.append(spk.spike_train_from_string(line))

    spikes = spk.merge_spike_trains([spike_trains[0], spike_trains[1]])
    # test if result is sorted
    assert((spikes == np.sort(spikes)).all())
    # check merging
    check_merged_spikes( spikes, [spike_trains[0], spike_trains[1]] )

    spikes = spk.merge_spike_trains(spike_trains)
    # test if result is sorted
    assert((spikes == np.sort(spikes)).all())
    # check merging
    check_merged_spikes( spikes, spike_trains )


# Bug fix: __name__ is "__main__" (not "main") when the file is executed as a
# script, so the original guard could never fire and the test never ran.
if __name__ == "__main__":
    test_merge_spike_trains()