diff options
author | Mario Mulansky <mario.mulansky@gmx.net> | 2015-04-24 12:08:05 +0200 |
---|---|---|
committer | Mario Mulansky <mario.mulansky@gmx.net> | 2015-04-24 12:08:05 +0200 |
commit | 7da6da8533f9f76a99b959c9de37138377119ffc (patch) | |
tree | cb62f32240db22be53b6faea844f5039bf28a161 /test | |
parent | ed85a9b72edcb7bba6ae1105e213b3b0a2f78d3a (diff) |
changed spike sync implementation to SpikeTrain
Diffstat (limited to 'test')
-rw-r--r-- | test/test_distance.py | 62 |
1 files changed, 37 insertions, 25 deletions
diff --git a/test/test_distance.py b/test/test_distance.py index 4af0e63..dbb72f1 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -138,10 +138,17 @@ def test_spike(): def test_spike_sync(): - spikes1 = np.array([1.0, 2.0, 3.0]) - spikes2 = np.array([2.1]) - spikes1 = spk.add_auxiliary_spikes(spikes1, 4.0) - spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + spikes1 = SpikeTrain([1.0, 2.0, 3.0], 4.0) + spikes2 = SpikeTrain([2.1], 4.0) + + expected_x = np.array([0.0, 1.0, 2.0, 2.1, 3.0, 4.0]) + expected_y = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0]) + + f = spk.spike_sync_profile(spikes1, spikes2) + + assert_array_almost_equal(f.x, expected_x, decimal=16) + assert_array_almost_equal(f.y, expected_y, decimal=16) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), 0.5, decimal=16) @@ -149,28 +156,34 @@ def test_spike_sync(): assert_almost_equal(spk.spike_sync(spikes1, spikes2, max_tau=0.05), 0.0, decimal=16) - spikes2 = np.array([3.1]) - spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + spikes2 = SpikeTrain([3.1], 4.0) assert_almost_equal(spk.spike_sync(spikes1, spikes2), 0.5, decimal=16) - spikes2 = np.array([1.1]) - spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + spikes2 = SpikeTrain([1.1], 4.0) + + expected_x = np.array([0.0, 1.0, 1.1, 2.0, 3.0, 4.0]) + expected_y = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0]) + + f = spk.spike_sync_profile(spikes1, spikes2) + + assert_array_almost_equal(f.x, expected_x, decimal=16) + assert_array_almost_equal(f.y, expected_y, decimal=16) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), 0.5, decimal=16) - spikes2 = np.array([0.9]) - spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + spikes2 = SpikeTrain([0.9], 4.0) assert_almost_equal(spk.spike_sync(spikes1, spikes2), 0.5, decimal=16) def check_multi_profile(profile_func, profile_func_multi): # generate spike trains: - t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) - t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0) - t3 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6]), 1.0) - t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0) + t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) + t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) + t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0) + t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0) spike_trains = [t1, t2, t3, t4] f12 = profile_func(t1, t2) @@ -213,15 +226,12 @@ def test_multi_spike(): def test_multi_spike_sync(): # some basic multivariate check - spikes1 = np.array([100, 300, 400, 405, 410, 500, 700, 800, - 805, 810, 815, 900], dtype=float) - spikes2 = np.array([100, 200, 205, 210, 295, 350, 400, 510, - 600, 605, 700, 910], dtype=float) - spikes3 = np.array([100, 180, 198, 295, 412, 420, 510, 640, - 695, 795, 820, 920], dtype=float) - spikes1 = spk.add_auxiliary_spikes(spikes1, 1000) - spikes2 = spk.add_auxiliary_spikes(spikes2, 1000) - spikes3 = spk.add_auxiliary_spikes(spikes3, 1000) + spikes1 = SpikeTrain([100, 300, 400, 405, 410, 500, 700, 800, + 805, 810, 815, 900], 1000) + spikes2 = SpikeTrain([100, 200, 205, 210, 295, 350, 400, 510, + 600, 605, 700, 910], 1000) + spikes3 = SpikeTrain([100, 180, 198, 295, 412, 420, 510, 640, + 695, 795, 820, 920], 1000) assert_almost_equal(spk.spike_sync(spikes1, spikes2), 0.5, decimal=15) assert_almost_equal(spk.spike_sync(spikes1, spikes3), @@ -326,5 +336,7 @@ def test_multi_variate_subsets(): if __name__ == "__main__": test_isi() test_spike() - # test_multi_isi() - # test_multi_spike() + test_spike_sync() + test_multi_isi() + test_multi_spike() + test_multi_spike_sync() |