Diffstat (limited to 'test/test_spikes.py')
-rw-r--r--  test/test_spikes.py  113
1 file changed, 113 insertions, 0 deletions
diff --git a/test/test_spikes.py b/test/test_spikes.py
new file mode 100644
index 0000000..ee505b5
--- /dev/null
+++ b/test/test_spikes.py
@@ -0,0 +1,113 @@
+""" test_load.py
+
+Test loading of spike trains from text files
+
+Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>
+
+Distributed under the BSD License
+"""
+
+from __future__ import print_function
+
+import os
+
+import numpy as np
+from numpy.testing import assert_equal
+
+import pyspike as spk
+
+TEST_PATH = os.path.dirname(os.path.realpath(__file__))
+TEST_DATA = os.path.join(TEST_PATH, "PySpike_testdata.txt")
+
+TIME_SERIES_DATA = os.path.join(TEST_PATH, "time_series.txt")
+TIME_SERIES_SPIKES = os.path.join(TEST_PATH, "time_series_spike_trains.txt")
+
+
+def test_load_from_txt():
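+    """Load spike trains from a text file and check their content."""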
+    spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))
+    assert len(spike_trains) == 40
+
+    # check the first spike train
+    spike_times = [64.886, 305.81, 696, 937.77, 1059.7, 1322.2, 1576.1,
+                   1808.1, 2121.5, 2381.1, 2728.6, 2966.9, 3223.7, 3473.7,
+                   3644.3, 3936.3]
+    assert_equal(spike_times, spike_trains[0].spikes)
+
+    # check auxiliary spikes
+    for spike_train in spike_trains:
+        assert spike_train.t_start == 0.0
+        assert spike_train.t_end == 4000
+
+
+def test_load_time_series():
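+    """Import spike trains from a time series and compare to reference."""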
+    spike_trains = spk.import_spike_trains_from_time_series(TIME_SERIES_DATA,
+                                                            start_time=0,
+                                                            time_bin=1)
+    assert len(spike_trains) == 40
+    spike_trains_check = spk.load_spike_trains_from_txt(TIME_SERIES_SPIKES,
+                                                        edges=(0, 4000))
+
+    # check spike trains
+    for n in range(len(spike_trains)):
+        assert_equal(spike_trains[n].spikes, spike_trains_check[n].spikes)
+        assert_equal(spike_trains[n].t_start, 0)
+        assert_equal(spike_trains[n].t_end, 4000)
+
+
+def check_merged_spikes(merged_spikes, spike_trains):
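+    """Check that merged_spikes holds exactly the events of spike_trains."""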
+    # create a flat array with all spike events
+    all_spikes = np.array([])
+    for spike_train in spike_trains:
+        all_spikes = np.append(all_spikes, spike_train)
+    indices = np.zeros_like(all_spikes, dtype='bool')
+    # check if we find all the spike events in the original spike trains
+    for x in merged_spikes:
+        i = np.where(all_spikes == x)[0][0]  # first axis and first entry
+        # change to something impossible so we don't find this event again
+        all_spikes[i] = -1.0
+        indices[i] = True
+    assert indices.all()
+
+
+def test_merge_spike_trains():
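+    """Merge two and then all spike trains and check the result."""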
+    # first load the data
+    spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))
+
+    merged_spikes = spk.merge_spike_trains([spike_trains[0], spike_trains[1]])
+    # test if the result is sorted
+    assert (merged_spikes.spikes == np.sort(merged_spikes.spikes)).all()
+    # check merging
+    check_merged_spikes(merged_spikes.spikes, [spike_trains[0].spikes,
+                                               spike_trains[1].spikes])
+
+    merged_spikes = spk.merge_spike_trains(spike_trains)
+    # test if the result is sorted
+    assert (merged_spikes.spikes == np.sort(merged_spikes.spikes)).all()
+    # check merging
+    check_merged_spikes(merged_spikes.spikes,
+                        [st.spikes for st in spike_trains])
+
+
+def test_merge_empty_spike_trains():
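+    """Check that merging works in the presence of an empty spike train."""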
+    # first load the data
+    spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))
+    # take two non-empty trains and one empty one
+    empty = spk.SpikeTrain([], [spike_trains[0].t_start,
+                                spike_trains[0].t_end])
+    merged_spikes = spk.merge_spike_trains([spike_trains[0], empty,
+                                            spike_trains[1]])
+    # test if the result is sorted
+    assert (merged_spikes.spikes == np.sort(merged_spikes.spikes)).all()
+    # we don't need to check more; that's done by test_merge_spike_trains
+
+
+if __name__ == "__main__":
+    test_load_from_txt()
+    test_load_time_series()
+    test_merge_spike_trains()
+    test_merge_empty_spike_trains()