summaryrefslogtreecommitdiff
path: root/test/test_sync_filter.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sync_filter.py')
-rw-r--r--test/test_sync_filter.py61
1 files changed, 56 insertions, 5 deletions
diff --git a/test/test_sync_filter.py b/test/test_sync_filter.py
index ce03b23..66ffcb6 100644
--- a/test/test_sync_filter.py
+++ b/test/test_sync_filter.py
@@ -17,17 +17,18 @@ import pyspike as spk
from pyspike import SpikeTrain
-def test_cython():
+def test_single_prof():
st1 = np.array([1.0, 2.0, 3.0, 4.0])
st2 = np.array([1.1, 2.1, 3.8])
+ st3 = np.array([0.9, 3.1, 4.1])
# cython implementation
try:
- from pyspike.cython.cython_profiles import coincidence_single_profile_cython \
- as coincidence_impl
+ from pyspike.cython.cython_profiles import \
+ coincidence_single_profile_cython as coincidence_impl
except ImportError:
- from pyspike.cython.python_backend import coincidence_single_profile_python \
- as coincidence_impl
+ from pyspike.cython.python_backend import \
+ coincidence_single_profile_python as coincidence_impl
sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0),
SpikeTrain(st2, 5.0))
@@ -41,3 +42,53 @@ def test_cython():
for i, t in enumerate(st2):
assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t],
"At index %d" % i)
+
+ sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0),
+ SpikeTrain(st3, 5.0))
+
+ coincidences = np.array(coincidence_impl(st1, st3, 0, 5.0, 0.0))
+ for i, t in enumerate(st1):
+ assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t],
+ "At index %d" % i)
+
+ st1 = np.array([1.0, 2.0, 3.0, 4.0])
+ st2 = np.array([1.0, 2.0, 4.0])
+
+ sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0),
+ SpikeTrain(st2, 5.0))
+
+ coincidences = np.array(coincidence_impl(st1, st2, 0, 5.0, 0.0))
+ for i, t in enumerate(st1):
+ expected = sync_prof.y[sync_prof.x == t]/sync_prof.mp[sync_prof.x == t]
+ assert_equal(coincidences[i], expected,
+ "At index %d" % i)
+
+
+def test_filter():
+ st1 = SpikeTrain(np.array([1.0, 2.0, 3.0, 4.0]), 5.0)
+ st2 = SpikeTrain(np.array([1.1, 2.1, 3.8]), 5.0)
+ st3 = SpikeTrain(np.array([0.9, 3.1, 4.1]), 5.0)
+
+ # filtered_spike_trains = spk.filter_by_spike_sync([st1, st2], 0.5)
+
+ # assert_equal(filtered_spike_trains[0].spikes, [1.0, 2.0, 4.0])
+ # assert_equal(filtered_spike_trains[1].spikes, [1.1, 2.1, 3.8])
+
+ # filtered_spike_trains = spk.filter_by_spike_sync([st2, st1], 0.5)
+
+ # assert_equal(filtered_spike_trains[0].spikes, [1.1, 2.1, 3.8])
+ # assert_equal(filtered_spike_trains[1].spikes, [1.0, 2.0, 4.0])
+
+ filtered_spike_trains = spk.filter_by_spike_sync([st1, st2, st3], 0.75)
+
+ for st in filtered_spike_trains:
+ print(st.spikes)
+
+ assert_equal(filtered_spike_trains[0].spikes, [1.0, 4.0])
+ assert_equal(filtered_spike_trains[1].spikes, [1.1, 3.8])
+ assert_equal(filtered_spike_trains[2].spikes, [0.9, 4.1])
+
+
+if __name__ == "main":
+ test_single_prof()
+ test_filter()