diff options
author | Mario Mulansky <mario.mulansky@gmx.net> | 2016-02-03 13:07:26 +0100 |
---|---|---|
committer | Mario Mulansky <mario.mulansky@gmx.net> | 2016-02-03 13:07:26 +0100 |
commit | b09561705ab9c67c93a384248f7c3bc9ad5bdd32 (patch) | |
tree | 077648ff6011c0ad48114d01f65e755223f0827a /test/test_distance.py | |
parent | 2f48f27b55f63726216b6e674fb88b3790b59147 (diff) |
fixed spike-sync bug
fixed ugly bugs in code for computing multi-variate spike sync profile and
multi-variate spike sync value.
Diffstat (limited to 'test/test_distance.py')
-rw-r--r-- | test/test_distance.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/test/test_distance.py b/test/test_distance.py index 083d8a3..fe09f34 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -217,6 +217,28 @@ def test_spike_sync(): assert_almost_equal(spk.spike_sync(spikes1, spikes2), 0.4, decimal=16) + spikes1 = SpikeTrain([1.0, 2.0, 4.0], 4.0) + spikes2 = SpikeTrain([3.8], 4.0) + spikes3 = SpikeTrain([3.9, ], 4.0) + + expected_x = np.array([0.0, 1.0, 2.0, 3.8, 4.0, 4.0]) + expected_y = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]) + + f = spk.spike_sync_profile(spikes1, spikes2) + + assert_array_almost_equal(f.x, expected_x, decimal=16) + assert_array_almost_equal(f.y, expected_y, decimal=16) + + f2 = spk.spike_sync_profile(spikes2, spikes3) + + i1 = f.integral() + i2 = f2.integral() + f.add(f2) + i12 = f.integral() + + assert_equal(i1[0]+i2[0], i12[0]) + assert_equal(i1[1]+i2[1], i12[1]) + def check_multi_profile(profile_func, profile_func_multi, dist_func_multi): # generate spike trains: @@ -313,6 +335,17 @@ def test_multi_spike_sync(): assert_equal(np.sum(f.y[1:-1]), 39932) assert_equal(np.sum(f.mp[1:-1]), 85554) + # example with 2 empty spike trains + sts = [] + sts.append(SpikeTrain([1, 9], [0, 10])) + sts.append(SpikeTrain([1, 3], [0, 10])) + sts.append(SpikeTrain([], [0, 10])) + sts.append(SpikeTrain([], [0, 10])) + + assert_almost_equal(spk.spike_sync_multi(sts), 1.0/6.0, decimal=15) + assert_almost_equal(spk.spike_sync_profile_multi(sts).avrg(), 1.0/6.0, + decimal=15) + def check_dist_matrix(dist_func, dist_matrix_func): # generate spike trains: |