summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-01-19 23:07:09 +0100
committerMario Mulansky <mario.mulansky@gmx.net>2015-01-19 23:07:09 +0100
commitf2d742c06fd013a013c811593257b67502ea9486 (patch)
tree5d347cc4dc562e7389416cb6a0ce4a907c04a70e
parent6c0f966649c8dedd4115d6809e569732ee5709c9 (diff)
fixed bug for multiple intervals
-rw-r--r--pyspike/function.py60
-rw-r--r--test/test_function.py7
2 files changed, 36 insertions, 31 deletions
diff --git a/pyspike/function.py b/pyspike/function.py
index ebf4189..62b0e2c 100644
--- a/pyspike/function.py
+++ b/pyspike/function.py
@@ -420,26 +420,44 @@ class DiscreteFunction(object):
function, this amounts to the sum over all values divided by the total
multiplicity.
- :param interval: integration interval given as a pair of floats, if
+ :param interval: integration interval given as a pair of floats, or a
+ sequence of pairs in case of multiple intervals, if
None the integral over the whole function is computed.
- :type interval: Pair of floats or None.
+ :type interval: Pair, sequence of pairs, or None.
:returns: the integral
:rtype: float
"""
+
+ def get_indices(ival):
+ start_ind = np.searchsorted(self.x, ival[0], side='right')
+ end_ind = np.searchsorted(self.x, ival[1], side='left')
+ assert start_ind > 0 and end_ind < len(self.x), \
+ "Invalid averaging interval"
+ return start_ind, end_ind
+
if interval is None:
# no interval given, integrate over the whole spike train
# don't count the first value, which is zero by definition
- a = 1.0 * np.sum(self.y[1:-1]) / np.sum(self.mp[1:-1])
- else:
+ return 1.0 * np.sum(self.y[1:-1]) / np.sum(self.mp[1:-1])
+
+ # check if interval is as sequence
+ assert isinstance(interval, collections.Sequence), \
+ "Invalid value for `interval`. None, Sequence or Tuple expected."
+ # check if interval is a sequence of intervals
+ if not isinstance(interval[0], collections.Sequence):
# find the indices corresponding to the interval
- start_ind = np.searchsorted(self.x, interval[0], side='right')
- end_ind = np.searchsorted(self.x, interval[1], side='left')
- assert start_ind > 0 and end_ind < len(self.x), \
- "Invalid averaging interval"
- # first the contribution from between the indices
- a = np.sum(self.y[start_ind:end_ind]) / \
- np.sum(self.mp[start_ind:end_ind])
- return a
+ start_ind, end_ind = get_indices(interval)
+ return (np.sum(self.y[start_ind:end_ind]) /
+ np.sum(self.mp[start_ind:end_ind]))
+ else:
+ value = 0.0
+ multiplicity = 0.0
+ for ival in interval:
+ # find the indices corresponding to the interval
+ start_ind, end_ind = get_indices(ival)
+ value += np.sum(self.y[start_ind:end_ind])
+ multiplicity += np.sum(self.mp[start_ind:end_ind])
+ return value/multiplicity
def avrg(self, interval=None):
""" Computes the average of the interval sequence:
@@ -453,23 +471,7 @@ class DiscreteFunction(object):
:returns: the average a.
:rtype: float
"""
- if interval is None:
- # no interval given, average over the whole spike train
- return self.integral()
-
- # check if interval is as sequence
- assert isinstance(interval, collections.Sequence), \
- "Invalid value for `interval`. None, Sequence or Tuple expected."
- # check if interval is a sequence of intervals
- if not isinstance(interval[0], collections.Sequence):
- # just one interval
- a = self.integral(interval)
- else:
- # several intervals
- a = 0.0
- for ival in interval:
- a += self.integral(ival)
- return a
+ return self.integral(interval)
def add(self, f):
""" Adds another `DiscreteFunction` function to this function.
diff --git a/test/test_function.py b/test/test_function.py
index da3d851..933fd2e 100644
--- a/test/test_function.py
+++ b/test/test_function.py
@@ -216,8 +216,7 @@ def test_df():
assert_array_almost_equal(xp, xp_expected, decimal=16)
assert_array_almost_equal(yp, yp_expected, decimal=16)
- avrg_expected = 2.0 / 5.0
- assert_almost_equal(f.avrg(), avrg_expected, decimal=16)
+ assert_almost_equal(f.avrg(), 2.0/5.0, decimal=16)
# interval averaging
a = f.avrg([0.5, 2.4])
@@ -229,6 +228,10 @@ def test_df():
a = f.avrg([1.1, 4.0])
assert_almost_equal(a, 1.0/3.0, decimal=16)
+ # averaging over multiple intervals
+ a = f.avrg([(0.5, 1.5), (1.5, 2.6)])
+ assert_almost_equal(a, 2.0/5.0, decimal=16)
+
if __name__ == "__main__":
test_pwc()