diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/numeric/regression_random_results_cSPIKY.mat | bin | 0 -> 149104 bytes | |||
-rw-r--r-- | test/numeric/regression_random_spikes.mat | bin | 0 -> 16241579 bytes | |||
-rw-r--r-- | test/numeric/test_regression_random_spikes.py | 127 | ||||
-rw-r--r-- | test/test_distance.py | 1 | ||||
-rw-r--r-- | test/test_empty.py | 16 | ||||
-rw-r--r-- | test/test_generic_interfaces.py | 105 | ||||
-rw-r--r-- | test/test_regression/test_regression_15.py | 1 |
7 files changed, 244 insertions, 6 deletions
diff --git a/test/numeric/regression_random_results_cSPIKY.mat b/test/numeric/regression_random_results_cSPIKY.mat Binary files differnew file mode 100644 index 0000000..26f29ff --- /dev/null +++ b/test/numeric/regression_random_results_cSPIKY.mat diff --git a/test/numeric/regression_random_spikes.mat b/test/numeric/regression_random_spikes.mat Binary files differnew file mode 100644 index 0000000..e5ebeb1 --- /dev/null +++ b/test/numeric/regression_random_spikes.mat diff --git a/test/numeric/test_regression_random_spikes.py b/test/numeric/test_regression_random_spikes.py new file mode 100644 index 0000000..6156bb4 --- /dev/null +++ b/test/numeric/test_regression_random_spikes.py @@ -0,0 +1,127 @@ +""" regression benchmark + +Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net> + +Distributed under the BSD License +""" +from __future__ import print_function + +import numpy as np +from scipy.io import loadmat +import pyspike as spk + +from numpy.testing import assert_almost_equal + +spk.disable_backend_warning = True + + +def test_regression_random(): + + spike_file = "test/numeric/regression_random_spikes.mat" + spikes_name = "spikes" + result_name = "Distances" + result_file = "test/numeric/regression_random_results_cSPIKY.mat" + + spike_train_sets = loadmat(spike_file)[spikes_name][0] + results_cSPIKY = loadmat(result_file)[result_name] + + for i, spike_train_data in enumerate(spike_train_sets): + spike_trains = [] + for spikes in spike_train_data[0]: + spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0)) + + isi = spk.isi_distance_multi(spike_trains) + isi_prof = spk.isi_profile_multi(spike_trains).avrg() + + spike = spk.spike_distance_multi(spike_trains) + spike_prof = spk.spike_profile_multi(spike_trains).avrg() + # spike_sync = spk.spike_sync_multi(spike_trains) + + assert_almost_equal(isi, results_cSPIKY[i][0], decimal=14, + err_msg="Index: %d, ISI" % i) + assert_almost_equal(isi_prof, results_cSPIKY[i][0], decimal=14, + err_msg="Index: %d, ISI" % i) + + assert_almost_equal(spike, results_cSPIKY[i][1], decimal=14, + err_msg="Index: %d, SPIKE" % i) + assert_almost_equal(spike_prof, results_cSPIKY[i][1], decimal=14, + err_msg="Index: %d, SPIKE" % i) + + +def check_regression_dataset(spike_file="benchmark.mat", + spikes_name="spikes", + result_file="results_cSPIKY.mat", + result_name="Distances"): + """ Debuging function """ + np.set_printoptions(precision=15) + + spike_train_sets = loadmat(spike_file)[spikes_name][0] + + results_cSPIKY = loadmat(result_file)[result_name] + + err_max = 0.0 + err_max_ind = -1 + err_count = 0 + + for i, spike_train_data in enumerate(spike_train_sets): + spike_trains = [] + for spikes in spike_train_data[0]: + spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0)) + + isi = spk.isi_distance_multi(spike_trains) + spike = spk.spike_distance_multi(spike_trains) + # spike_sync = spk.spike_sync_multi(spike_trains) + + if abs(isi - results_cSPIKY[i][0]) > 1E-14: + print("Error in ISI:", i, isi, results_cSPIKY[i][0]) + print("Spike trains:") + for st in spike_trains: + print(st.spikes) + + err = abs(spike - results_cSPIKY[i][1]) + if err > 1E-14: + err_count += 1 + if err > err_max: + err_max = err + err_max_ind = i + + print("Total Errors:", err_count) + + if err_max_ind > -1: + print("Max SPIKE distance error:", err_max, "at index:", err_max_ind) + spike_train_data = spike_train_sets[err_max_ind] + for spikes in spike_train_data[0]: + print(spikes.flatten()) + + +def check_single_spike_train_set(index): + """ Debuging function """ + np.set_printoptions(precision=15) + spike_file = "regression_random_spikes.mat" + spikes_name = "spikes" + result_name = "Distances" + result_file = "regression_random_results_cSPIKY.mat" + + spike_train_sets = loadmat(spike_file)[spikes_name][0] + + results_cSPIKY = loadmat(result_file)[result_name] + + spike_train_data = spike_train_sets[index] + + spike_trains = [] + for spikes in spike_train_data[0]: + print("Spikes:", spikes.flatten()) + spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0)) + + print(spk.spike_distance_multi(spike_trains)) + + print(results_cSPIKY[index][1]) + + print(spike_trains[1].spikes) + + +if __name__ == "__main__": + + test_regression_random() + # check_regression_dataset() + # check_single_spike_train_set(7633) diff --git a/test/test_distance.py b/test/test_distance.py index 8cf81e2..083d8a3 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -20,6 +20,7 @@ from pyspike import SpikeTrain import os TEST_PATH = os.path.dirname(os.path.realpath(__file__)) + def test_isi(): # generate two spike trains: t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) diff --git a/test/test_empty.py b/test/test_empty.py index 5a0042f..4d0a5cf 100644 --- a/test/test_empty.py +++ b/test/test_empty.py @@ -24,7 +24,9 @@ def test_get_non_empty(): st = SpikeTrain([0.5, ], edges=(0.0, 1.0)) spikes = st.get_spikes_non_empty() - assert_array_equal(spikes, [0.0, 0.5, 1.0]) + # assert_array_equal(spikes, [0.0, 0.5, 1.0]) + # spike trains with one spike don't get edge spikes anymore + assert_array_equal(spikes, [0.5, ]) def test_isi_empty(): @@ -70,21 +72,23 @@ def test_spike_empty(): st1 = SpikeTrain([], edges=(0.0, 1.0)) st2 = SpikeTrain([0.4, ], edges=(0.0, 1.0)) d = spk.spike_distance(st1, st2) - d_expect = 0.4*0.4*1.0/(0.4+1.0)**2 + 0.6*0.4*1.0/(0.6+1.0)**2 + d_expect = 2*0.4*0.4*1.0/(0.4+1.0)**2 + 2*0.6*0.4*1.0/(0.6+1.0)**2 assert_almost_equal(d, d_expect, decimal=15) prof = spk.spike_profile(st1, st2) assert_equal(d, prof.avrg()) assert_array_equal(prof.x, [0.0, 0.4, 1.0]) - assert_array_almost_equal(prof.y1, [0.0, 2*0.4*1.0/(0.6+1.0)**2], + assert_array_almost_equal(prof.y1, [2*0.4*1.0/(0.4+1.0)**2, + 2*0.4*1.0/(0.6+1.0)**2], decimal=15) - assert_array_almost_equal(prof.y2, [2*0.4*1.0/(0.4+1.0)**2, 0.0], + assert_array_almost_equal(prof.y2, [2*0.4*1.0/(0.4+1.0)**2, + 2*0.4*1.0/(0.6+1.0)**2], decimal=15) st1 = SpikeTrain([0.6, ], edges=(0.0, 1.0)) st2 = SpikeTrain([0.4, ], edges=(0.0, 1.0)) d = spk.spike_distance(st1, st2) - s1 = np.array([0.0, 0.4*0.2/0.6, 0.2, 0.0]) - s2 = np.array([0.0, 0.2, 0.2*0.4/0.6, 0.0]) + s1 = np.array([0.2, 0.2, 0.2, 0.2]) + s2 = np.array([0.2, 0.2, 0.2, 0.2]) isi1 = np.array([0.6, 0.6, 0.4]) isi2 = np.array([0.4, 0.6, 0.6]) expected_y1 = (s1[:-1]*isi2+s2[:-1]*isi1) / (0.5*(isi1+isi2)**2) diff --git a/test/test_generic_interfaces.py b/test/test_generic_interfaces.py new file mode 100644 index 0000000..7f08067 --- /dev/null +++ b/test/test_generic_interfaces.py @@ -0,0 +1,105 @@ +""" test_generic_interface.py + +Tests the generic interfaces of the profile and distance functions + +Copyright 2016, Mario Mulansky <mario.mulansky@gmx.net> + +Distributed under the BSD License + +""" + +from __future__ import print_function +from numpy.testing import assert_equal + +import pyspike as spk +from pyspike import SpikeTrain + + +class dist_from_prof: + """ Simple functor that turns profile function into distance function by + calling profile.avrg(). + """ + def __init__(self, prof_func): + self.prof_func = prof_func + + def __call__(self, *args, **kwargs): + if "interval" in kwargs: + # forward interval arg into avrg function + interval = kwargs.pop("interval") + return self.prof_func(*args, **kwargs).avrg(interval=interval) + else: + return self.prof_func(*args, **kwargs).avrg() + + +def check_func(dist_func): + """ generic checker that tests the given distance function. + """ + # generate spike trains: + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) + t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0) + t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0) + spike_trains = [t1, t2, t3, t4] + + isi12 = dist_func(t1, t2) + isi12_ = dist_func([t1, t2]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3) + isi123_ = dist_func([t1, t2, t3]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2]) + assert_equal(isi123, isi123_) + + # run the same test with an additional interval parameter + + isi12 = dist_func(t1, t2, interval=[0.0, 0.5]) + isi12_ = dist_func([t1, t2], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi12_ = dist_func(spike_trains, indices=[0, 1], interval=[0.0, 0.5]) + assert_equal(isi12, isi12_) + + isi123 = dist_func(t1, t2, t3, interval=[0.0, 0.5]) + isi123_ = dist_func([t1, t2, t3], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + + isi123_ = dist_func(spike_trains, indices=[0, 1, 2], interval=[0.0, 0.5]) + assert_equal(isi123, isi123_) + + +def test_isi_profile(): + check_func(dist_from_prof(spk.isi_profile)) + + +def test_isi_distance(): + check_func(spk.isi_distance) + + +def test_spike_profile(): + check_func(dist_from_prof(spk.spike_profile)) + + +def test_spike_distance(): + check_func(spk.spike_distance) + + +def test_spike_sync_profile(): + check_func(dist_from_prof(spk.spike_sync_profile)) + + +def test_spike_sync(): + check_func(spk.spike_sync) + + +if __name__ == "__main__": + test_isi_profile() + test_isi_distance() + test_spike_profile() + test_spike_distance() + test_spike_sync_profile() + test_spike_sync() diff --git a/test/test_regression/test_regression_15.py b/test/test_regression/test_regression_15.py index dcacae2..54adf23 100644 --- a/test/test_regression/test_regression_15.py +++ b/test/test_regression/test_regression_15.py @@ -20,6 +20,7 @@ import os TEST_PATH = os.path.dirname(os.path.realpath(__file__)) TEST_DATA = os.path.join(TEST_PATH, "..", "SPIKE_Sync_Test.txt") + def test_regression_15_isi(): # load spike trains spike_trains = spk.load_spike_trains_from_txt(TEST_DATA, edges=[0, 4000]) |