From 0d7255c1fb7398720de10653efee617075c30892 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Mon, 17 Aug 2015 15:05:08 +0200 Subject: fix for #14 test case and fix for Issue #14. Spike-Sync function now correctly deal with empty intervals as well. --- pyspike/DiscreteFunc.py | 43 +++++++++++++++++++++++-------------------- test/test_empty.py | 12 ++++++++++++ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/pyspike/DiscreteFunc.py b/pyspike/DiscreteFunc.py index a8c054e..4fd496d 100644 --- a/pyspike/DiscreteFunc.py +++ b/pyspike/DiscreteFunc.py @@ -137,9 +137,8 @@ class DiscreteFunc(object): :rtype: pair of float """ - if len(self.y) <= 2: - # no actual values in the profile, return spike sync of 1 - return 1.0, 1.0 + value = 0.0 + multiplicity = 0.0 def get_indices(ival): """ Retuns the indeces surrounding the given interval""" @@ -152,25 +151,29 @@ class DiscreteFunc(object): if interval is None: # no interval given, integrate over the whole spike train # don't count the first value, which is zero by definition - return (1.0 * np.sum(self.y[1:-1]), np.sum(self.mp[1:-1])) - - # check if interval is as sequence - assert isinstance(interval, collections.Sequence), \ - "Invalid value for `interval`. None, Sequence or Tuple expected." - # check if interval is a sequence of intervals - if not isinstance(interval[0], collections.Sequence): - # find the indices corresponding to the interval - start_ind, end_ind = get_indices(interval) - return (np.sum(self.y[start_ind:end_ind]), - np.sum(self.mp[start_ind:end_ind])) + value = 1.0 * np.sum(self.y[1:-1]) + multiplicity = np.sum(self.mp[1:-1]) else: - value = 0.0 - multiplicity = 0.0 - for ival in interval: + # check if interval is as sequence + assert isinstance(interval, collections.Sequence), \ + "Invalid value for `interval`. None, Sequence or Tuple \ +expected." + # check if interval is a sequence of intervals + if not isinstance(interval[0], collections.Sequence): # find the indices corresponding to the interval - start_ind, end_ind = get_indices(ival) - value += np.sum(self.y[start_ind:end_ind]) - multiplicity += np.sum(self.mp[start_ind:end_ind]) + start_ind, end_ind = get_indices(interval) + value = np.sum(self.y[start_ind:end_ind]) + multiplicity = np.sum(self.mp[start_ind:end_ind]) + else: + for ival in interval: + # find the indices corresponding to the interval + start_ind, end_ind = get_indices(ival) + value += np.sum(self.y[start_ind:end_ind]) + multiplicity += np.sum(self.mp[start_ind:end_ind]) + if multiplicity == 0.0: + # empty profile, return spike sync of 1 + value = 1.0 + multiplicity = 1.0 return (value, multiplicity) def avrg(self, interval=None): diff --git a/test/test_empty.py b/test/test_empty.py index 48be25d..af7fb36 100644 --- a/test/test_empty.py +++ b/test/test_empty.py @@ -139,6 +139,18 @@ def test_spike_sync_empty(): assert_array_almost_equal(prof.x, [0.0, 0.2, 0.8, 1.0], decimal=15) assert_array_almost_equal(prof.y, [0.0, 0.0, 0.0, 0.0], decimal=15) + # test with empty intervals + st1 = SpikeTrain([2.0, 5.0], [0, 10.0]) + st2 = SpikeTrain([2.1, 7.0], [0, 10.0]) + st3 = SpikeTrain([5.1, 6.0], [0, 10.0]) + res = spk.spike_sync_profile(st1, st2).avrg(interval=[3.0, 4.0]) + assert_equal(res, 1.0) + res = spk.spike_sync(st1, st2, interval=[3.0, 4.0]) + assert_equal(res, 1.0) + + sync_matrix = spk.spike_sync_matrix([st1, st2, st3], interval=[3.0, 4.0]) + assert_array_equal(sync_matrix, np.ones((3, 3)) - np.diag(np.ones(3))) + if __name__ == "__main__": test_get_non_empty() -- cgit v1.2.3