diff options
Diffstat (limited to 'test/test_distance.py')
-rw-r--r-- | test/test_distance.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/test/test_distance.py b/test/test_distance.py index 7be0d9b..3b4329c 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -130,6 +130,30 @@ def test_spike(): decimal=16) +def test_spike_sync(): + spikes1 = np.array([1.0, 2.0, 3.0]) + spikes2 = np.array([2.1]) + spikes1 = spk.add_auxiliary_spikes(spikes1, 4.0) + spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.5, decimal=16) + + spikes2 = np.array([3.1]) + spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.5, decimal=16) + + spikes2 = np.array([1.1]) + spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.5, decimal=16) + + spikes2 = np.array([0.9]) + spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.5, decimal=16) + + def check_multi_profile(profile_func, profile_func_multi): # generate spike trains: t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) @@ -148,6 +172,10 @@ def check_multi_profile(profile_func, profile_func_multi): f_multi = profile_func_multi(spike_trains, [0, 1]) assert f_multi.almost_equal(f12, decimal=14) + f_multi1 = profile_func_multi(spike_trains, [1, 2, 3]) + f_multi2 = profile_func_multi(spike_trains[1:]) + assert f_multi1.almost_equal(f_multi2, decimal=14) + f = copy(f12) f.add(f13) f.add(f23) @@ -172,6 +200,41 @@ def test_multi_spike(): check_multi_profile(spk.spike_profile, spk.spike_profile_multi) +def test_multi_spike_sync(): + # some basic multivariate check + spikes1 = np.array([100, 300, 400, 405, 410, 500, 700, 800, + 805, 810, 815, 900], dtype=float) + spikes2 = np.array([100, 200, 205, 210, 295, 350, 400, 510, + 600, 605, 700, 910], dtype=float) + spikes3 = np.array([100, 180, 198, 295, 412, 420, 510, 640, + 695, 795, 820, 920], dtype=float) + spikes1 = spk.add_auxiliary_spikes(spikes1, 1000) + spikes2 = spk.add_auxiliary_spikes(spikes2, 1000) + spikes3 = spk.add_auxiliary_spikes(spikes3, 1000) + assert_almost_equal(spk.spike_sync(spikes1, spikes2), + 0.5, decimal=15) + assert_almost_equal(spk.spike_sync(spikes1, spikes3), + 0.5, decimal=15) + assert_almost_equal(spk.spike_sync(spikes2, spikes3), + 0.5, decimal=15) + + f = spk.spike_sync_profile_multi([spikes1, spikes2, spikes3]) + # hands on definition of the average multivariate spike synchronization + # expected = (f1.integral() + f2.integral() + f3.integral()) / \ + # (np.sum(f1.mp[1:-1])+np.sum(f2.mp[1:-1])+np.sum(f3.mp[1:-1])) + expected = 0.5 + assert_almost_equal(f.avrg(), expected, decimal=15) + assert_almost_equal(spk.spike_sync_multi([spikes1, spikes2, spikes3]), + expected, decimal=15) + + # multivariate regression test + spike_trains = spk.load_spike_trains_from_txt("test/SPIKE_Sync_Test.txt", + time_interval=(0, 4000)) + f = spk.spike_sync_profile_multi(spike_trains) + assert_equal(np.sum(f.y[1:-1]), 39932) + assert_equal(np.sum(f.mp[1:-1]), 85554) + + def check_dist_matrix(dist_func, dist_matrix_func): # generate spike trains: t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) @@ -210,6 +273,10 @@ def test_spike_matrix(): check_dist_matrix(spk.spike_distance, spk.spike_distance_matrix) +def test_spike_sync_matrix(): + check_dist_matrix(spk.spike_sync, spk.spike_sync_matrix) + + def test_regression_spiky(): spike_trains = spk.load_spike_trains_from_txt("test/PySpike_testdata.txt", (0.0, 4000.0)) |