diff options
author | Mario Mulansky <mario.mulansky@gmx.net> | 2016-01-31 15:05:21 +0100 |
---|---|---|
committer | Mario Mulansky <mario.mulansky@gmx.net> | 2016-01-31 15:05:21 +0100 |
commit | 5a556a11fbf8434bd38fa73e05054d581018a4da (patch) | |
tree | 6114d1c64d29484a4f4fe3a7309f16f8d437f372 /test/test_generic_interfaces.py | |
parent | 2f48f27b55f63726216b6e674fb88b3790b59147 (diff) |
generalized interface to isi profile and distance
isi profile and distance functionc an now compute bi-variate and multi-variate
results. Therefore, it can be called with different "overloads".
Diffstat (limited to 'test/test_generic_interfaces.py')
-rw-r--r-- | test/test_generic_interfaces.py | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py new file mode 100644 index 0000000..caa9ee4 --- /dev/null +++ b/test/test_generic_interfaces.py @@ -0,0 +1,64 @@ +""" test_isi_interface.py + +Tests the generic interfaces of the profile and distance functions + +Copyright 2016, Mario Mulansky <mario.mulansky@gmx.net> + +Distributed under the BSD License + +""" + +from __future__ import print_function +from numpy.testing import assert_equal + +import pyspike as spk +from pyspike import SpikeTrain + + +class dist_from_prof: + """ Simple functor that turns profile function into distance function by + calling profile.avrg(). + """ + def __init__(self, prof_func): + self.prof_func = prof_func + + def __call__(self, *args, **kwargs): + return self.prof_func(*args, **kwargs).avrg() + + +def check_func(dist_func): + """ generic checker that tests the given distance function. + """ + # generate spike trains: + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) + t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0) + t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0) + spike_trains = [t1, t2, t3, t4] + + isi12 = dist_func(t1, t2) + isi12_ = dist_func([t1, t2]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3) + isi123_ = dist_func([t1, t2, t3]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2]) + assert_equal(isi123, isi123_) + + +def test_isi_profile(): + check_func(dist_from_prof(spk.isi_profile)) + + +def test_isi_distance(): + check_func(spk.isi_distance) + + +if __name__ == "__main__": + test_isi_profile() + test_isi_distance() |