summaryrefslogtreecommitdiff
path: root/test/test_sync_filter.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_sync_filter.py')
-rw-r--r--test/test_sync_filter.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/test/test_sync_filter.py b/test/test_sync_filter.py
index e259903..0b915db 100644
--- a/test/test_sync_filter.py
+++ b/test/test_sync_filter.py
@@ -10,7 +10,7 @@ Distributed under the BSD License
from __future__ import print_function
import numpy as np
-from numpy.testing import assert_equal, assert_almost_equal, \
+from numpy.testing import assert_allclose, assert_almost_equal, \
assert_array_almost_equal
import pyspike as spk
@@ -36,21 +36,21 @@ def test_single_prof():
coincidences = np.array(coincidence_impl(st1, st2, 0, 5.0, 0.0))
print(coincidences)
for i, t in enumerate(st1):
- assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t],
- "At index %d" % i)
+ assert_allclose(coincidences[i], sync_prof.y[sync_prof.x == t],
+ err_msg="At index %d" % i)
coincidences = np.array(coincidence_impl(st2, st1, 0, 5.0, 0.0))
for i, t in enumerate(st2):
- assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t],
- "At index %d" % i)
+ assert_allclose(coincidences[i], sync_prof.y[sync_prof.x == t],
+ err_msg="At index %d" % i)
sync_prof = spk.spike_sync_profile(SpikeTrain(st1, 5.0),
SpikeTrain(st3, 5.0))
coincidences = np.array(coincidence_impl(st1, st3, 0, 5.0, 0.0))
for i, t in enumerate(st1):
- assert_equal(coincidences[i], sync_prof.y[sync_prof.x == t],
- "At index %d" % i)
+ assert_allclose(coincidences[i], sync_prof.y[sync_prof.x == t],
+ err_msg="At index %d" % i)
st1 = np.array([1.0, 2.0, 3.0, 4.0])
st2 = np.array([1.0, 2.0, 4.0])
@@ -61,8 +61,8 @@ def test_single_prof():
coincidences = np.array(coincidence_impl(st1, st2, 0, 5.0, 0.0))
for i, t in enumerate(st1):
expected = sync_prof.y[sync_prof.x == t]/sync_prof.mp[sync_prof.x == t]
- assert_equal(coincidences[i], expected,
- "At index %d" % i)
+ assert_allclose(coincidences[i], expected,
+ err_msg="At index %d" % i)
def test_filter():
@@ -72,22 +72,22 @@ def test_filter():
# filtered_spike_trains = spk.filter_by_spike_sync([st1, st2], 0.5)
- # assert_equal(filtered_spike_trains[0].spikes, [1.0, 2.0, 4.0])
- # assert_equal(filtered_spike_trains[1].spikes, [1.1, 2.1, 3.8])
+ # assert_allclose(filtered_spike_trains[0].spikes, [1.0, 2.0, 4.0])
+ # assert_allclose(filtered_spike_trains[1].spikes, [1.1, 2.1, 3.8])
# filtered_spike_trains = spk.filter_by_spike_sync([st2, st1], 0.5)
- # assert_equal(filtered_spike_trains[0].spikes, [1.1, 2.1, 3.8])
- # assert_equal(filtered_spike_trains[1].spikes, [1.0, 2.0, 4.0])
+ # assert_allclose(filtered_spike_trains[0].spikes, [1.1, 2.1, 3.8])
+ # assert_allclose(filtered_spike_trains[1].spikes, [1.0, 2.0, 4.0])
filtered_spike_trains = spk.filter_by_spike_sync([st1, st2, st3], 0.75)
for st in filtered_spike_trains:
print(st.spikes)
- assert_equal(filtered_spike_trains[0].spikes, [1.0, 4.0])
- assert_equal(filtered_spike_trains[1].spikes, [1.1, 3.8])
- assert_equal(filtered_spike_trains[2].spikes, [0.9, 4.1])
+ assert_allclose(filtered_spike_trains[0].spikes, [1.0, 4.0])
+ assert_allclose(filtered_spike_trains[1].spikes, [1.1, 3.8])
+ assert_allclose(filtered_spike_trains[2].spikes, [0.9, 4.1])
if __name__ == "main":