summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-04-24 12:08:05 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-04-24 12:08:05 +0200
commit7da6da8533f9f76a99b959c9de37138377119ffc (patch)
treecb62f32240db22be53b6faea844f5039bf28a161 /test
parented85a9b72edcb7bba6ae1105e213b3b0a2f78d3a (diff)
changed spike sync implementation to SpikeTrain
Diffstat (limited to 'test')
-rw-r--r--test/test_distance.py62
1 files changed, 37 insertions, 25 deletions
diff --git a/test/test_distance.py b/test/test_distance.py
index 4af0e63..dbb72f1 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -138,10 +138,17 @@ def test_spike():
def test_spike_sync():
- spikes1 = np.array([1.0, 2.0, 3.0])
- spikes2 = np.array([2.1])
- spikes1 = spk.add_auxiliary_spikes(spikes1, 4.0)
- spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
+ spikes1 = SpikeTrain([1.0, 2.0, 3.0], 4.0)
+ spikes2 = SpikeTrain([2.1], 4.0)
+
+ expected_x = np.array([0.0, 1.0, 2.0, 2.1, 3.0, 4.0])
+ expected_y = np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])
+
+ f = spk.spike_sync_profile(spikes1, spikes2)
+
+ assert_array_almost_equal(f.x, expected_x, decimal=16)
+ assert_array_almost_equal(f.y, expected_y, decimal=16)
+
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.5, decimal=16)
@@ -149,28 +156,34 @@ def test_spike_sync():
assert_almost_equal(spk.spike_sync(spikes1, spikes2, max_tau=0.05),
0.0, decimal=16)
- spikes2 = np.array([3.1])
- spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
+ spikes2 = SpikeTrain([3.1], 4.0)
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.5, decimal=16)
- spikes2 = np.array([1.1])
- spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
+ spikes2 = SpikeTrain([1.1], 4.0)
+
+ expected_x = np.array([0.0, 1.0, 1.1, 2.0, 3.0, 4.0])
+ expected_y = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
+
+ f = spk.spike_sync_profile(spikes1, spikes2)
+
+ assert_array_almost_equal(f.x, expected_x, decimal=16)
+ assert_array_almost_equal(f.y, expected_y, decimal=16)
+
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.5, decimal=16)
- spikes2 = np.array([0.9])
- spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
+ spikes2 = SpikeTrain([0.9], 4.0)
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.5, decimal=16)
def check_multi_profile(profile_func, profile_func_multi):
# generate spike trains:
- t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0)
- t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0)
- t3 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6]), 1.0)
- t4 = spk.add_auxiliary_spikes(np.array([0.1, 0.4, 0.5, 0.6]), 1.0)
+ t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0)
+ t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0)
+ t3 = SpikeTrain([0.2, 0.4, 0.6], 1.0)
+ t4 = SpikeTrain([0.1, 0.4, 0.5, 0.6], 1.0)
spike_trains = [t1, t2, t3, t4]
f12 = profile_func(t1, t2)
@@ -213,15 +226,12 @@ def test_multi_spike():
def test_multi_spike_sync():
# some basic multivariate check
- spikes1 = np.array([100, 300, 400, 405, 410, 500, 700, 800,
- 805, 810, 815, 900], dtype=float)
- spikes2 = np.array([100, 200, 205, 210, 295, 350, 400, 510,
- 600, 605, 700, 910], dtype=float)
- spikes3 = np.array([100, 180, 198, 295, 412, 420, 510, 640,
- 695, 795, 820, 920], dtype=float)
- spikes1 = spk.add_auxiliary_spikes(spikes1, 1000)
- spikes2 = spk.add_auxiliary_spikes(spikes2, 1000)
- spikes3 = spk.add_auxiliary_spikes(spikes3, 1000)
+ spikes1 = SpikeTrain([100, 300, 400, 405, 410, 500, 700, 800,
+ 805, 810, 815, 900], 1000)
+ spikes2 = SpikeTrain([100, 200, 205, 210, 295, 350, 400, 510,
+ 600, 605, 700, 910], 1000)
+ spikes3 = SpikeTrain([100, 180, 198, 295, 412, 420, 510, 640,
+ 695, 795, 820, 920], 1000)
assert_almost_equal(spk.spike_sync(spikes1, spikes2),
0.5, decimal=15)
assert_almost_equal(spk.spike_sync(spikes1, spikes3),
@@ -326,5 +336,7 @@ def test_multi_variate_subsets():
if __name__ == "__main__":
test_isi()
test_spike()
- # test_multi_isi()
- # test_multi_spike()
+ test_spike_sync()
+ test_multi_isi()
+ test_multi_spike()
+ test_multi_spike_sync()