diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_directionality.py | 89 | ||||
-rw-r--r-- | test/test_sync_filter.py | 95 |
2 files changed, 184 insertions, 0 deletions
diff --git a/test/test_directionality.py b/test/test_directionality.py new file mode 100644 index 0000000..63865cc --- /dev/null +++ b/test/test_directionality.py @@ -0,0 +1,89 @@ +""" test_directionality.py + +Tests the directionality functions + +Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net> + +Distributed under the BSD License + +""" + +import numpy as np +from numpy.testing import assert_equal, assert_almost_equal, \ + assert_array_equal + +import pyspike as spk +from pyspike import SpikeTrain, DiscreteFunc +# from pyspike.spike_directionality import _spike_directionality_profile + + +def test_spike_directionality(): + st1 = SpikeTrain([100, 200, 300], [0, 1000]) + st2 = SpikeTrain([105, 205, 300], [0, 1000]) + assert_almost_equal(spk.spike_directionality(st1, st2), 2.0/3.0) + assert_almost_equal(spk.spike_directionality(st1, st2, normalize=False), + 2.0) + + # exchange order of spike trains should give exact negative profile + assert_almost_equal(spk.spike_directionality(st2, st1), -2.0/3.0) + assert_almost_equal(spk.spike_directionality(st2, st1, normalize=False), + -2.0) + + st3 = SpikeTrain([105, 195, 500], [0, 1000]) + assert_almost_equal(spk.spike_directionality(st1, st3), 0.0) + assert_almost_equal(spk.spike_directionality(st1, st3, normalize=False), + 0.0) + assert_almost_equal(spk.spike_directionality(st3, st1), 0.0) + + D = spk.spike_directionality_matrix([st1, st2, st3], normalize=False) + D_expected = np.array([[0, 2.0, 0.0], [-2.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) + assert_array_equal(D, D_expected) + + dir_profs = spk.spike_directionality_profiles([st1, st2, st3]) + assert_array_equal(dir_profs[0], [1.0, 0.0, 0.0]) + assert_array_equal(dir_profs[1], [-0.5, -1.0, 0.0]) + + +def test_spike_train_order(): + st1 = SpikeTrain([100, 200, 300], [0, 1000]) + st2 = SpikeTrain([105, 205, 300], [0, 1000]) + st3 = SpikeTrain([105, 195, 500], [0, 1000]) + + expected_x12 = np.array([0, 100, 105, 200, 205, 300, 1000]) + expected_y12 = np.array([1, 1, 1, 1, 1, 0, 0]) + expected_mp12 = np.array([1, 1, 1, 1, 1, 2, 2]) + + f = spk.spike_train_order_profile(st1, st2) + + assert f.almost_equal(DiscreteFunc(expected_x12, expected_y12, + expected_mp12)) + assert_almost_equal(f.avrg(), 2.0/3.0) + assert_almost_equal(f.avrg(normalize=False), 4.0) + assert_almost_equal(spk.spike_train_order(st1, st2), 2.0/3.0) + assert_almost_equal(spk.spike_train_order(st1, st2, normalize=False), 4.0) + + expected_x23 = np.array([0, 105, 195, 205, 300, 500, 1000]) + expected_y23 = np.array([0, 0, -1, -1, 0, 0, 0]) + expected_mp23 = np.array([2, 2, 1, 1, 1, 1, 1]) + + f = spk.spike_train_order_profile(st2, st3) + + assert_array_equal(f.x, expected_x23) + assert_array_equal(f.y, expected_y23) + assert_array_equal(f.mp, expected_mp23) + assert f.almost_equal(DiscreteFunc(expected_x23, expected_y23, + expected_mp23)) + assert_almost_equal(f.avrg(), -1.0/3.0) + assert_almost_equal(f.avrg(normalize=False), -2.0) + assert_almost_equal(spk.spike_train_order(st2, st3), -1.0/3.0) + assert_almost_equal(spk.spike_train_order(st2, st3, normalize=False), -2.0) + + f = spk.spike_train_order_profile_multi([st1, st2, st3]) + + expected_x = np.array([0, 100, 105, 195, 200, 205, 300, 500, 1000]) + expected_y = np.array([2, 2, 2, -2, 0, 0, 0, 0, 0]) + expected_mp = np.array([2, 2, 4, 2, 2, 2, 4, 2, 2]) + + assert_array_equal(f.x, expected_x) + assert_array_equal(f.y, expected_y) + assert_array_equal(f.mp, expected_mp) diff --git a/test/test_sync_filter.py b/test/test_sync_filter.py new file mode 100644 index 0000000..e259903 --- /dev/null +++ b/test/test_sync_filter.py @@ -0,0 +1,95 @@ +""" test_sync_filter.py + +Tests the spike sync based filtering + +Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net> + +Distributed under the BSD License + +""" + +from __future__ import print_function +import numpy as np +from numpy.testing import assert_equal, assert_almost_equal, \ + assert_array_almost_equal + +import pyspike as spk +from pyspike import SpikeTrain + + +def test_single_prof(): + st1 = np.array([1.0, 2.0, 3.0, 4.0]) + st2 = np.array([1.1, 2.1, 3.8]) + st3 = np.array([0.9, 3.1, 4.1]) + + # cython implementation + try: + from pyspike.cython.cython_profiles import \ + coincidence_single_profile_cython as coincidence_impl + except ImportError: + from pyspike.cython.python_backend import \ + coincidence_single_python as coincidence_impl + + sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0), + SpikeTrain(st2, 5.0)) + + coincidences = np.array(coincidence_impl(st1, st2, 0, 5.0, 0.0)) + print(coincidences) + for i, t in enumerate(st1): + assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t], + "At index %d" % i) + + coincidences = np.array(coincidence_impl(st2, st1, 0, 5.0, 0.0)) + for i, t in enumerate(st2): + assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t], + "At index %d" % i) + + sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0), + SpikeTrain(st3, 5.0)) + + coincidences = np.array(coincidence_impl(st1, st3, 0, 5.0, 0.0)) + for i, t in enumerate(st1): + assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t], + "At index %d" % i) + + st1 = np.array([1.0, 2.0, 3.0, 4.0]) + st2 = np.array([1.0, 2.0, 4.0]) + + sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0), + SpikeTrain(st2, 5.0)) + + coincidences = np.array(coincidence_impl(st1, st2, 0, 5.0, 0.0)) + for i, t in enumerate(st1): + expected = sync_prof.y[sync_prof.x == t]/sync_prof.mp[sync_prof.x == t] + assert_equal(coincidences[i], expected, + "At index %d" % i) + + +def test_filter(): + st1 = SpikeTrain(np.array([1.0, 2.0, 3.0, 4.0]), 5.0) + st2 = SpikeTrain(np.array([1.1, 2.1, 3.8]), 5.0) + st3 = SpikeTrain(np.array([0.9, 3.1, 4.1]), 5.0) + + # filtered_spike_trains = spk.filter_by_spike_sync([st1, st2], 0.5) + + # assert_equal(filtered_spike_trains[0].spikes, [1.0, 2.0, 4.0]) + # assert_equal(filtered_spike_trains[1].spikes, [1.1, 2.1, 3.8]) + + # filtered_spike_trains = spk.filter_by_spike_sync([st2, st1], 0.5) + + # assert_equal(filtered_spike_trains[0].spikes, [1.1, 2.1, 3.8]) + # assert_equal(filtered_spike_trains[1].spikes, [1.0, 2.0, 4.0]) + + filtered_spike_trains = spk.filter_by_spike_sync([st1, st2, st3], 0.75) + + for st in filtered_spike_trains: + print(st.spikes) + + assert_equal(filtered_spike_trains[0].spikes, [1.0, 4.0]) + assert_equal(filtered_spike_trains[1].spikes, [1.1, 3.8]) + assert_equal(filtered_spike_trains[2].spikes, [0.9, 4.1]) + + +if __name__ == "main": + test_single_prof() + test_filter() |