From f2d742c06fd013a013c811593257b67502ea9486 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Mon, 19 Jan 2015 23:07:09 +0100 Subject: fixed bug for multiple intervals --- pyspike/function.py | 60 ++++++++++++++++++++++++++------------------------- test/test_function.py | 7 ++++-- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/pyspike/function.py b/pyspike/function.py index ebf4189..62b0e2c 100644 --- a/pyspike/function.py +++ b/pyspike/function.py @@ -420,26 +420,44 @@ class DiscreteFunction(object): function, this amounts to the sum over all values divided by the total multiplicity. - :param interval: integration interval given as a pair of floats, if + :param interval: integration interval given as a pair of floats, or a + sequence of pairs in case of multiple intervals, if None the integral over the whole function is computed. - :type interval: Pair of floats or None. + :type interval: Pair, sequence of pairs, or None. :returns: the integral :rtype: float """ + + def get_indices(ival): + start_ind = np.searchsorted(self.x, ival[0], side='right') + end_ind = np.searchsorted(self.x, ival[1], side='left') + assert start_ind > 0 and end_ind < len(self.x), \ + "Invalid averaging interval" + return start_ind, end_ind + if interval is None: # no interval given, integrate over the whole spike train # don't count the first value, which is zero by definition - a = 1.0 * np.sum(self.y[1:-1]) / np.sum(self.mp[1:-1]) - else: + return 1.0 * np.sum(self.y[1:-1]) / np.sum(self.mp[1:-1]) + + # check if interval is as sequence + assert isinstance(interval, collections.Sequence), \ + "Invalid value for `interval`. None, Sequence or Tuple expected." + # check if interval is a sequence of intervals + if not isinstance(interval[0], collections.Sequence): # find the indices corresponding to the interval - start_ind = np.searchsorted(self.x, interval[0], side='right') - end_ind = np.searchsorted(self.x, interval[1], side='left') - assert start_ind > 0 and end_ind < len(self.x), \ - "Invalid averaging interval" - # first the contribution from between the indices - a = np.sum(self.y[start_ind:end_ind]) / \ - np.sum(self.mp[start_ind:end_ind]) - return a + start_ind, end_ind = get_indices(interval) + return (np.sum(self.y[start_ind:end_ind]) / + np.sum(self.mp[start_ind:end_ind])) + else: + value = 0.0 + multiplicity = 0.0 + for ival in interval: + # find the indices corresponding to the interval + start_ind, end_ind = get_indices(ival) + value += np.sum(self.y[start_ind:end_ind]) + multiplicity += np.sum(self.mp[start_ind:end_ind]) + return value/multiplicity def avrg(self, interval=None): """ Computes the average of the interval sequence: @@ -453,23 +471,7 @@ class DiscreteFunction(object): :returns: the average a. :rtype: float """ - if interval is None: - # no interval given, average over the whole spike train - return self.integral() - - # check if interval is as sequence - assert isinstance(interval, collections.Sequence), \ - "Invalid value for `interval`. None, Sequence or Tuple expected." - # check if interval is a sequence of intervals - if not isinstance(interval[0], collections.Sequence): - # just one interval - a = self.integral(interval) - else: - # several intervals - a = 0.0 - for ival in interval: - a += self.integral(ival) - return a + return self.integral(interval) def add(self, f): """ Adds another `DiscreteFunction` function to this function. diff --git a/test/test_function.py b/test/test_function.py index da3d851..933fd2e 100644 --- a/test/test_function.py +++ b/test/test_function.py @@ -216,8 +216,7 @@ def test_df(): assert_array_almost_equal(xp, xp_expected, decimal=16) assert_array_almost_equal(yp, yp_expected, decimal=16) - avrg_expected = 2.0 / 5.0 - assert_almost_equal(f.avrg(), avrg_expected, decimal=16) + assert_almost_equal(f.avrg(), 2.0/5.0, decimal=16) # interval averaging a = f.avrg([0.5, 2.4]) @@ -229,6 +228,10 @@ def test_df(): a = f.avrg([1.1, 4.0]) assert_almost_equal(a, 1.0/3.0, decimal=16) + # averaging over multiple intervals + a = f.avrg([(0.5, 1.5), (1.5, 2.6)]) + assert_almost_equal(a, 2.0/5.0, decimal=16) + if __name__ == "__main__": test_pwc() -- cgit v1.2.3