summaryrefslogtreecommitdiff
path: root/test/test_distance.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_distance.py')
-rw-r--r--test/test_distance.py79
1 files changed, 41 insertions, 38 deletions
diff --git a/test/test_distance.py b/test/test_distance.py
index d98069d..4f8f6e8 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -135,53 +135,23 @@ def test_spike_sync():
spikes2 = np.array([2.1])
spikes1 = spk.add_auxiliary_spikes(spikes1, 4.0)
spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
- for k in xrange(1, 3):
- assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2, k=k),
- 0.5, decimal=16)
+ assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2),
+ 0.5, decimal=16)
spikes2 = np.array([3.1])
spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
- for k in xrange(1, 3):
- assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2, k=k),
- 0.5, decimal=16)
+ assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2),
+ 0.5, decimal=16)
spikes2 = np.array([1.1])
spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
- for k in xrange(1, 3):
- assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2, k=k),
- 0.5, decimal=16)
+ assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2),
+ 0.5, decimal=16)
spikes2 = np.array([0.9])
spikes2 = spk.add_auxiliary_spikes(spikes2, 4.0)
- for k in xrange(1, 3):
- assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2, k=k),
- 0.5, decimal=16)
-
- spikes1 = np.array([100, 300, 400, 405, 410, 500, 700, 800,
- 805, 810, 815, 900])
- spikes2 = np.array([100, 200, 205, 210, 295, 350, 400, 510,
- 600, 605, 700, 910])
- spikes3 = np.array([100, 180, 198, 295, 412, 420, 510, 640,
- 695, 795, 820, 920])
- spikes1 = spk.add_auxiliary_spikes(spikes1, 1000)
- spikes2 = spk.add_auxiliary_spikes(spikes2, 1000)
- spikes3 = spk.add_auxiliary_spikes(spikes3, 1000)
- for k in xrange(1, 10):
- assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2, k=k),
- 0.5, decimal=15)
- assert_almost_equal(spk.spike_sync_distance(spikes1, spikes3, k=k),
- 0.5, decimal=15)
- assert_almost_equal(spk.spike_sync_distance(spikes2, spikes3, k=k),
- 0.5, decimal=15)
-
- f1 = spk.spike_sync_profile(spikes1, spikes2, k=1)
- f2 = spk.spike_sync_profile(spikes1, spikes3, k=1)
- f3 = spk.spike_sync_profile(spikes2, spikes3, k=1)
- f = spk.spike_sync_profile_multi([spikes1, spikes2, spikes3], k=1)
- # hands on definition of the average multivariate spike synchronization
- expected = (f1.integral() + f2.integral() + f3.integral()) / \
- (len(f1.y)+len(f2.y)+len(f3.y)-3)
- assert_almost_equal(f.avrg(), expected, decimal=15)
+ assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2),
+ 0.5, decimal=16)
def check_multi_profile(profile_func, profile_func_multi):
@@ -226,6 +196,39 @@ def test_multi_spike():
check_multi_profile(spk.spike_profile, spk.spike_profile_multi)
+def test_multi_spike_sync():
+ # some basic multivariate check
+ spikes1 = np.array([100, 300, 400, 405, 410, 500, 700, 800,
+ 805, 810, 815, 900])
+ spikes2 = np.array([100, 200, 205, 210, 295, 350, 400, 510,
+ 600, 605, 700, 910])
+ spikes3 = np.array([100, 180, 198, 295, 412, 420, 510, 640,
+ 695, 795, 820, 920])
+ spikes1 = spk.add_auxiliary_spikes(spikes1, 1000)
+ spikes2 = spk.add_auxiliary_spikes(spikes2, 1000)
+ spikes3 = spk.add_auxiliary_spikes(spikes3, 1000)
+ assert_almost_equal(spk.spike_sync_distance(spikes1, spikes2),
+ 0.5, decimal=15)
+ assert_almost_equal(spk.spike_sync_distance(spikes1, spikes3),
+ 0.5, decimal=15)
+ assert_almost_equal(spk.spike_sync_distance(spikes2, spikes3),
+ 0.5, decimal=15)
+
+ f = spk.spike_sync_profile_multi([spikes1, spikes2, spikes3])
+ # hands on definition of the average multivariate spike synchronization
+ # expected = (f1.integral() + f2.integral() + f3.integral()) / \
+ # (np.sum(f1.mp[1:-1])+np.sum(f2.mp[1:-1])+np.sum(f3.mp[1:-1]))
+ expected = 0.5
+ assert_almost_equal(f.avrg(), expected, decimal=15)
+
+ # multivariate regression test
+ spike_trains = spk.load_spike_trains_from_txt("test/SPIKE_Sync_Test.txt",
+ time_interval=(0, 4000))
+ f = spk.spike_sync_profile_multi(spike_trains)
+ assert_equal(np.sum(f.y[1:-1]), 39932)
+ assert_equal(np.sum(f.mp[1:-1]), 85554)
+
+
def check_dist_matrix(dist_func, dist_matrix_func):
# generate spike trains:
t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0)