summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2016-06-18 16:32:39 -0700
committerMario Mulansky <mario.mulansky@gmx.net>2016-06-18 16:32:39 -0700
commita63839708d6574567bcd518a822a9f10629ed80b (patch)
tree02c28e0856b62c08fd4d3622ee68d145d8eee6eb /test
parentd9e2125dcc76693056ab04264add29227f398f4f (diff)
parent4691d0e77a024fbc73d1098ee557d65f8f2ddc89 (diff)
Merge branch 'import_time_series' into develop
Conflicts: pyspike/__init__.py resolved
Diffstat (limited to 'test')
-rw-r--r--test/test_generic_interfaces.py105
-rw-r--r--test/test_regression/test_regression_15.py1
-rw-r--r--test/test_spikes.py19
3 files changed, 125 insertions, 0 deletions
diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py
new file mode 100644
index 0000000..7f08067
--- /dev/null
+++ b/test/test_generic_interfaces.py
@@ -0,0 +1,105 @@
+""" test_generic_interface.py
+
+Tests the generic interfaces of the profile and distance functions
+
+Copyright 2016, Mario Mulansky <mario.mulansky@gmx.net>
+
+Distributed under the BSD License
+
+"""
+
+from __future__ import print_function
+from numpy.testing import assert_equal
+
+import pyspike as spk
+from pyspike import SpikeTrain
+
+
+class dist_from_prof:
+ """ Simple functor that turns profile function into distance function by
+ calling profile.avrg().
+ """
+ def __init__(self, prof_func):
+ self.prof_func = prof_func
+
+ def __call__(self, *args, **kwargs):
+ if "interval" in kwargs:
+ # forward interval arg into avrg function
+ interval = kwargs.pop("interval")
+ return self.prof_func(*args, **kwargs).avrg(interval=interval)
+ else:
+ return self.prof_func(*args, **kwargs).avrg()
+
+
+def check_func(dist_func):
+ """ generic checker that tests the given distance function.
+ """
+ # generate spike trains:
+ t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0)
+ t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0)
+ t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0)
+ t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0)
+ spike_trains = [t1, t2, t3, t4]
+
+ isi12 = dist_func(t1, t2)
+ isi12_ = dist_func([t1, t2])
+ assert_equal(isi12, isi12_)
+
+ isi12_ = dist_func(spike_trains, indices=[0, 1])
+ assert_equal(isi12, isi12_)
+
+ isi123 = dist_func(t1, t2, t3)
+ isi123_ = dist_func([t1, t2, t3])
+ assert_equal(isi123, isi123_)
+
+ isi123_ = dist_func(spike_trains, indices=[0, 1, 2])
+ assert_equal(isi123, isi123_)
+
+ # run the same test with an additional interval parameter
+
+ isi12 = dist_func(t1, t2, interval=[0.0, 0.5])
+ isi12_ = dist_func([t1, t2], interval=[0.0, 0.5])
+ assert_equal(isi12, isi12_)
+
+ isi12_ = dist_func(spike_trains, indices=[0, 1], interval=[0.0, 0.5])
+ assert_equal(isi12, isi12_)
+
+ isi123 = dist_func(t1, t2, t3, interval=[0.0, 0.5])
+ isi123_ = dist_func([t1, t2, t3], interval=[0.0, 0.5])
+ assert_equal(isi123, isi123_)
+
+ isi123_ = dist_func(spike_trains, indices=[0, 1, 2], interval=[0.0, 0.5])
+ assert_equal(isi123, isi123_)
+
+
+def test_isi_profile():
+ check_func(dist_from_prof(spk.isi_profile))
+
+
+def test_isi_distance():
+ check_func(spk.isi_distance)
+
+
+def test_spike_profile():
+ check_func(dist_from_prof(spk.spike_profile))
+
+
+def test_spike_distance():
+ check_func(spk.spike_distance)
+
+
+def test_spike_sync_profile():
+ check_func(dist_from_prof(spk.spike_sync_profile))
+
+
+def test_spike_sync():
+ check_func(spk.spike_sync)
+
+
+if __name__ == "__main__":
+ test_isi_profile()
+ test_isi_distance()
+ test_spike_profile()
+ test_spike_distance()
+ test_spike_sync_profile()
+ test_spike_sync()
diff --git a/test/test_regression/test_regression_15.py b/test/test_regression/test_regression_15.py
index dcacae2..54adf23 100644
--- a/test/test_regression/test_regression_15.py
+++ b/test/test_regression/test_regression_15.py
@@ -20,6 +20,7 @@ import os
TEST_PATH = os.path.dirname(os.path.realpath(__file__))
TEST_DATA = os.path.join(TEST_PATH, "..", "SPIKE_Sync_Test.txt")
+
def test_regression_15_isi():
# load spike trains
spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=[0, 4000])
diff --git a/test/test_spikes.py b/test/test_spikes.py
index 609a819..bcface2 100644
--- a/test/test_spikes.py
+++ b/test/test_spikes.py
@@ -17,6 +17,10 @@ import os
TEST_PATH = os.path.dirname(os.path.realpath(__file__))
TEST_DATA = os.path.join(TEST_PATH, "PySpike_testdata.txt")
+TIME_SERIES_DATA = os.path.join(TEST_PATH, "time_series.txt")
+TIME_SERIES_SPIKES = os.path.join(TEST_PATH, "time_series_spike_trains.txt")
+
+
def test_load_from_txt():
spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=(0, 4000))
assert len(spike_trains) == 40
@@ -33,6 +37,21 @@ def test_load_from_txt():
assert spike_train.t_end == 4000
+def test_load_time_series():
+ spike_trains = spk.import_spike_trains_from_time_series(TIME_SERIES_DATA,
+ start_time=0,
+ time_bin=1)
+ assert len(spike_trains) == 40
+ spike_trains_check = spk.load_spike_trains_from_txt(TIME_SERIES_SPIKES,
+ edges=(0, 4000))
+
+ # check spike trains
+ for n in range(len(spike_trains)):
+ assert_equal(spike_trains[n].spikes, spike_trains_check[n].spikes)
+ assert_equal(spike_trains[n].t_start, 0)
+ assert_equal(spike_trains[n].t_end, 4000)
+
+
def check_merged_spikes(merged_spikes, spike_trains):
# create a flat array with all spike events
all_spikes = np.array([])