summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2016-02-03 13:07:26 +0100
committerMario Mulansky <mario.mulansky@gmx.net>2016-02-03 13:07:26 +0100
commitb09561705ab9c67c93a384248f7c3bc9ad5bdd32 (patch)
tree077648ff6011c0ad48114d01f65e755223f0827a /test
parent2f48f27b55f63726216b6e674fb88b3790b59147 (diff)
fixed spike-sync bug
fixed ugly bugs in code for computing multi-variate spike sync profile and multi-variate spike sync value.
Diffstat (limited to 'test')
-rw-r--r--test/numeric/test_regression_random_spikes.py27
-rw-r--r--test/test_distance.py33
2 files changed, 52 insertions, 8 deletions
diff --git a/test/numeric/test_regression_random_spikes.py b/test/numeric/test_regression_random_spikes.py
index 6156bb4..73d39a6 100644
--- a/test/numeric/test_regression_random_spikes.py
+++ b/test/numeric/test_regression_random_spikes.py
@@ -35,7 +35,9 @@ def test_regression_random():
spike = spk.spike_distance_multi(spike_trains)
spike_prof = spk.spike_profile_multi(spike_trains).avrg()
- # spike_sync = spk.spike_sync_multi(spike_trains)
+
+ spike_sync = spk.spike_sync_multi(spike_trains)
+ spike_sync_prof = spk.spike_sync_profile_multi(spike_trains).avrg()
assert_almost_equal(isi, results_cSPIKY[i][0], decimal=14,
err_msg="Index: %d, ISI" % i)
@@ -47,6 +49,9 @@ def test_regression_random():
assert_almost_equal(spike_prof, results_cSPIKY[i][1], decimal=14,
err_msg="Index: %d, SPIKE" % i)
+ assert_almost_equal(spike_sync, spike_sync_prof, decimal=14,
+ err_msg="Index: %d, SPIKE-Sync" % i)
+
def check_regression_dataset(spike_file="benchmark.mat",
spikes_name="spikes",
@@ -109,19 +114,25 @@ def check_single_spike_train_set(index):
spike_train_data = spike_train_sets[index]
spike_trains = []
+ N = 0
for spikes in spike_train_data[0]:
- print("Spikes:", spikes.flatten())
- spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))
+ N += len(spikes.flatten())
+ print("Spikes:", len(spikes.flatten()))
+ spikes_array = spikes.flatten()
+ if len(spikes_array > 0) and (spikes_array[-1] > 100.0):
+ spikes_array[-1] = 100.0
+ spike_trains.append(spk.SpikeTrain(spikes_array, 100.0))
+ print(spike_trains[-1].spikes)
- print(spk.spike_distance_multi(spike_trains))
+ print(N)
- print(results_cSPIKY[index][1])
+ print(spk.spike_sync_multi(spike_trains))
- print(spike_trains[1].spikes)
+ print(spk.spike_sync_profile_multi(spike_trains).integral())
if __name__ == "__main__":
- test_regression_random()
+ # test_regression_random()
# check_regression_dataset()
- # check_single_spike_train_set(7633)
+ check_single_spike_train_set(4)
diff --git a/test/test_distance.py b/test/test_distance.py
index 083d8a3..fe09f34 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -217,6 +217,28 @@ def test_spike_sync():
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.4, decimal=16)
+ spikes1 = SpikeTrain([1.0, 2.0, 4.0], 4.0)
+ spikes2 = SpikeTrain([3.8], 4.0)
+ spikes3 = SpikeTrain([3.9, ], 4.0)
+
+ expected_x = np.array([0.0, 1.0, 2.0, 3.8, 4.0, 4.0])
+ expected_y = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
+
+ f = spk.spike_sync_profile(spikes1, spikes2)
+
+ assert_array_almost_equal(f.x, expected_x, decimal=16)
+ assert_array_almost_equal(f.y, expected_y, decimal=16)
+
+ f2 = spk.spike_sync_profile(spikes2, spikes3)
+
+ i1 = f.integral()
+ i2 = f2.integral()
+ f.add(f2)
+ i12 = f.integral()
+
+ assert_equal(i1[0]+i2[0], i12[0])
+ assert_equal(i1[1]+i2[1], i12[1])
+
def check_multi_profile(profile_func, profile_func_multi, dist_func_multi):
# generate spike trains:
@@ -313,6 +335,17 @@ def test_multi_spike_sync():
assert_equal(np.sum(f.y[1:-1]), 39932)
assert_equal(np.sum(f.mp[1:-1]), 85554)
+ # example with 2 empty spike trains
+ sts = []
+ sts.append(SpikeTrain([1, 9], [0, 10]))
+ sts.append(SpikeTrain([1, 3], [0, 10]))
+ sts.append(SpikeTrain([], [0, 10]))
+ sts.append(SpikeTrain([], [0, 10]))
+
+ assert_almost_equal(spk.spike_sync_multi(sts), 1.0/6.0, decimal=15)
+ assert_almost_equal(spk.spike_sync_profile_multi(sts).avrg(), 1.0/6.0,
+ decimal=15)
+
def check_dist_matrix(dist_func, dist_matrix_func):
# generate spike trains: