summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-08-17 15:05:08 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-08-17 15:05:08 +0200
commit0d7255c1fb7398720de10653efee617075c30892 (patch)
tree03bbb5c62ee8e48e6634da2387ad4a867cc6f25f
parentbe74318a2b269ec0c1e16981e7286679746f1a49 (diff)
fix for #14
test case and fix for Issue #14. Spike-Sync function now correctly deal with empty intervals as well.
-rw-r--r--pyspike/DiscreteFunc.py43
-rw-r--r--test/test_empty.py12
2 files changed, 35 insertions, 20 deletions
diff --git a/pyspike/DiscreteFunc.py b/pyspike/DiscreteFunc.py
index a8c054e..4fd496d 100644
--- a/pyspike/DiscreteFunc.py
+++ b/pyspike/DiscreteFunc.py
@@ -137,9 +137,8 @@ class DiscreteFunc(object):
:rtype: pair of float
"""
- if len(self.y) <= 2:
- # no actual values in the profile, return spike sync of 1
- return 1.0, 1.0
+ value = 0.0
+ multiplicity = 0.0
def get_indices(ival):
""" Retuns the indeces surrounding the given interval"""
@@ -152,25 +151,29 @@ class DiscreteFunc(object):
if interval is None:
# no interval given, integrate over the whole spike train
# don't count the first value, which is zero by definition
- return (1.0 * np.sum(self.y[1:-1]), np.sum(self.mp[1:-1]))
-
- # check if interval is as sequence
- assert isinstance(interval, collections.Sequence), \
- "Invalid value for `interval`. None, Sequence or Tuple expected."
- # check if interval is a sequence of intervals
- if not isinstance(interval[0], collections.Sequence):
- # find the indices corresponding to the interval
- start_ind, end_ind = get_indices(interval)
- return (np.sum(self.y[start_ind:end_ind]),
- np.sum(self.mp[start_ind:end_ind]))
+ value = 1.0 * np.sum(self.y[1:-1])
+ multiplicity = np.sum(self.mp[1:-1])
else:
- value = 0.0
- multiplicity = 0.0
- for ival in interval:
+ # check if interval is as sequence
+ assert isinstance(interval, collections.Sequence), \
+ "Invalid value for `interval`. None, Sequence or Tuple \
+expected."
+ # check if interval is a sequence of intervals
+ if not isinstance(interval[0], collections.Sequence):
# find the indices corresponding to the interval
- start_ind, end_ind = get_indices(ival)
- value += np.sum(self.y[start_ind:end_ind])
- multiplicity += np.sum(self.mp[start_ind:end_ind])
+ start_ind, end_ind = get_indices(interval)
+ value = np.sum(self.y[start_ind:end_ind])
+ multiplicity = np.sum(self.mp[start_ind:end_ind])
+ else:
+ for ival in interval:
+ # find the indices corresponding to the interval
+ start_ind, end_ind = get_indices(ival)
+ value += np.sum(self.y[start_ind:end_ind])
+ multiplicity += np.sum(self.mp[start_ind:end_ind])
+ if multiplicity == 0.0:
+ # empty profile, return spike sync of 1
+ value = 1.0
+ multiplicity = 1.0
return (value, multiplicity)
def avrg(self, interval=None):
diff --git a/test/test_empty.py b/test/test_empty.py
index 48be25d..af7fb36 100644
--- a/test/test_empty.py
+++ b/test/test_empty.py
@@ -139,6 +139,18 @@ def test_spike_sync_empty():
assert_array_almost_equal(prof.x, [0.0, 0.2, 0.8, 1.0], decimal=15)
assert_array_almost_equal(prof.y, [0.0, 0.0, 0.0, 0.0], decimal=15)
+ # test with empty intervals
+ st1 = SpikeTrain([2.0, 5.0], [0, 10.0])
+ st2 = SpikeTrain([2.1, 7.0], [0, 10.0])
+ st3 = SpikeTrain([5.1, 6.0], [0, 10.0])
+ res = spk.spike_sync_profile(st1, st2).avrg(interval=[3.0, 4.0])
+ assert_equal(res, 1.0)
+ res = spk.spike_sync(st1, st2, interval=[3.0, 4.0])
+ assert_equal(res, 1.0)
+
+ sync_matrix = spk.spike_sync_matrix([st1, st2, st3], interval=[3.0, 4.0])
+ assert_array_equal(sync_matrix, np.ones((3, 3)) - np.diag(np.ones(3)))
+
if __name__ == "__main__":
test_get_non_empty()