diff options
Diffstat (limited to 'test/test_generic_interfaces.py')
-rw-r--r-- | test/test_generic_interfaces.py | 105 |
1 files changed, 105 insertions, 0 deletions
diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py new file mode 100644 index 0000000..7f08067 --- /dev/null +++ b/test/test_generic_interfaces.py @@ -0,0 +1,105 @@ +""" test_generic_interface.py + +Tests the generic interfaces of the profile and distance functions + +Copyright 2016, Mario Mulansky <mario.mulansky@gmx.net> + +Distributed under the BSD License + +""" + +from __future__ import print_function +from numpy.testing import assert_equal + +import pyspike as spk +from pyspike import SpikeTrain + + +class dist_from_prof: + """ Simple functor that turns profile function into distance function by + calling profile.avrg(). + """ + def __init__(self, prof_func): + self.prof_func = prof_func + + def __call__(self, *args, **kwargs): + if "interval" in kwargs: + # forward interval arg into avrg function + interval = kwargs.pop("interval") + return self.prof_func(*args, **kwargs).avrg(interval=interval) + else: + return self.prof_func(*args, **kwargs).avrg() + + +def check_func(dist_func): + """ generic checker that tests the given distance function. + """ + # generate spike trains: + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) + t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0) + t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0) + spike_trains = [t1, t2, t3, t4] + + isi12 = dist_func(t1, t2) + isi12_ = dist_func([t1, t2]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3) + isi123_ = dist_func([t1, t2, t3]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2]) + assert_equal(isi123, isi123_) + + # run the same test with an additional interval parameter + + isi12 = dist_func(t1, t2, interval=[0.0, 0.5]) + isi12_ = dist_func([t1, t2], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3, interval=[0.0, 0.5]) + isi123_ = dist_func([t1, t2, t3], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + + +def test_isi_profile(): + check_func(dist_from_prof(spk.isi_profile)) + + +def test_isi_distance(): + check_func(spk.isi_distance) + + +def test_spike_profile(): + check_func(dist_from_prof(spk.spike_profile)) + + +def test_spike_distance(): + check_func(spk.spike_distance) + + +def test_spike_sync_profile(): + check_func(dist_from_prof(spk.spike_sync_profile)) + + +def test_spike_sync(): + check_func(spk.spike_sync) + + +if __name__ == "__main__": + test_isi_profile() + test_isi_distance() + test_spike_profile() + test_spike_distance() + test_spike_sync_profile() + test_spike_sync() |