summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-05-11 15:54:19 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-05-11 15:54:19 +0200
commitb6521869cad89eae119391557bfa57e818dc9894 (patch)
treef27c04399919b9d770af68393fa0584689866557
parent017893e782d1747e1f031131077a40a48f882e86 (diff)
treatment of empty spike trains in spike distance
-rw-r--r--pyspike/spike_distance.py13
-rw-r--r--test/test_empty.py48
2 files changed, 54 insertions, 7 deletions
diff --git a/pyspike/spike_distance.py b/pyspike/spike_distance.py
index d727fa2..ac2d260 100644
--- a/pyspike/spike_distance.py
+++ b/pyspike/spike_distance.py
@@ -41,10 +41,11 @@ Falling back to slow python backend.")
from cython.python_backend import spike_distance_python \
as spike_profile_impl
- times, y_starts, y_ends = spike_profile_impl(spike_train1.spikes,
- spike_train2.spikes,
- spike_train1.t_start,
- spike_train1.t_end)
+ times, y_starts, y_ends = spike_profile_impl(
+ spike_train1.get_spikes_non_empty(),
+ spike_train2.get_spikes_non_empty(),
+ spike_train1.t_start, spike_train1.t_end)
+
return PieceWiseLinFunc(times, y_starts, y_ends)
@@ -74,8 +75,8 @@ def spike_distance(spike_train1, spike_train2, interval=None):
try:
from cython.cython_distances import spike_distance_cython \
as spike_distance_impl
- return spike_distance_impl(spike_train1.spikes,
- spike_train2.spikes,
+ return spike_distance_impl(spike_train1.get_spikes_non_empty(),
+ spike_train2.get_spikes_non_empty(),
spike_train1.t_start,
spike_train1.t_end)
except ImportError:
diff --git a/test/test_empty.py b/test/test_empty.py
index 22982c7..e08bd1a 100644
--- a/test/test_empty.py
+++ b/test/test_empty.py
@@ -10,7 +10,6 @@ Distributed under the BSD License
from __future__ import print_function
import numpy as np
-from copy import copy
from numpy.testing import assert_equal, assert_almost_equal, \
assert_array_equal, assert_array_almost_equal
@@ -57,6 +56,53 @@ def test_isi_empty():
assert_array_almost_equal(prof.y, [0.2/0.6, 0.0, 0.2/0.6], decimal=15)
+def test_spike_empty():
+ st1 = SpikeTrain([], edges=(0.0, 1.0))
+ st2 = SpikeTrain([], edges=(0.0, 1.0))
+ d = spk.spike_distance(st1, st2)
+ assert_equal(d, 0.0)
+ prof = spk.spike_profile(st1, st2)
+ assert_equal(d, prof.avrg())
+ assert_array_equal(prof.x, [0.0, 1.0])
+ assert_array_equal(prof.y1, [0.0, ])
+ assert_array_equal(prof.y2, [0.0, ])
+
+ st1 = SpikeTrain([], edges=(0.0, 1.0))
+ st2 = SpikeTrain([0.4, ], edges=(0.0, 1.0))
+ d = spk.spike_distance(st1, st2)
+ assert_almost_equal(d, 0.4*0.4*1.0/(0.4+1.0)**2 + 0.6*0.4*1.0/(0.6+1.0)**2,
+ decimal=15)
+ prof = spk.spike_profile(st1, st2)
+ assert_equal(d, prof.avrg())
+ assert_array_equal(prof.x, [0.0, 0.4, 1.0])
+ assert_array_almost_equal(prof.y1, [0.0, 2*0.4*1.0/(0.6+1.0)**2],
+ decimal=15)
+ assert_array_almost_equal(prof.y2, [2*0.4*1.0/(0.4+1.0)**2, 0.0],
+ decimal=15)
+
+ st1 = SpikeTrain([0.6, ], edges=(0.0, 1.0))
+ st2 = SpikeTrain([0.4, ], edges=(0.0, 1.0))
+ d = spk.spike_distance(st1, st2)
+ s1 = np.array([0.0, 0.4*0.2/0.6, 0.2, 0.0])
+ s2 = np.array([0.0, 0.2, 0.2*0.4/0.6, 0.0])
+ isi1 = np.array([0.6, 0.6, 0.4])
+ isi2 = np.array([0.4, 0.6, 0.6])
+ expected_y1 = (s1[:-1]*isi2+s2[:-1]*isi1) / (0.5*(isi1+isi2)**2)
+ expected_y2 = (s1[1:]*isi2+s2[1:]*isi1) / (0.5*(isi1+isi2)**2)
+ expected_times = np.array([0.0, 0.4, 0.6, 1.0])
+ expected_spike_val = sum((expected_times[1:] - expected_times[:-1]) *
+ (expected_y1+expected_y2)/2)
+ expected_spike_val /= (expected_times[-1]-expected_times[0])
+
+ assert_almost_equal(d, expected_spike_val, decimal=15)
+ prof = spk.spike_profile(st1, st2)
+ assert_equal(d, prof.avrg())
+ assert_array_almost_equal(prof.x, [0.0, 0.4, 0.6, 1.0], decimal=15)
+ assert_array_almost_equal(prof.y1, expected_y1, decimal=15)
+ assert_array_almost_equal(prof.y2, expected_y2, decimal=15)
+
+
if __name__ == "__main__":
test_get_non_empty()
test_isi_empty()
+ test_spike_empty()