From 11f05d9dac3711d89db37a043db3d9437958c6f3 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Wed, 22 Oct 2014 22:17:27 +0200 Subject: avrg functions now take intervals --- pyspike/function.py | 51 ++++++++++++++++++++++++++++++++++++++++++++------- test/test_function.py | 19 ++++++++++++++++++- 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/pyspike/function.py b/pyspike/function.py index 14ad7bd..e340096 100644 --- a/pyspike/function.py +++ b/pyspike/function.py @@ -83,20 +83,24 @@ class PieceWiseConstFunc(object): """ if interval is None: # no interval given, average over the whole spike train - return np.sum((self.x[1:]-self.x[:-1]) * self.y) / \ + a = np.sum((self.x[1:]-self.x[:-1]) * self.y) / \ (self.x[-1]-self.x[0]) else: + # find the indices corresponding to the interval start_ind = np.searchsorted(self.x, interval[0], side='right') end_ind = np.searchsorted(self.x, interval[1], side='left')-1 - print(start_ind, end_ind) assert start_ind > 0 and end_ind < len(self.x), \ "Invalid averaging interval" - a = np.sum((self.x[start_ind+1:end_ind] - - self.x[start_ind:end_ind-1]) * + # first the contribution from between the indices + a = np.sum((self.x[start_ind+1:end_ind+1] - + self.x[start_ind:end_ind]) * self.y[start_ind:end_ind]) - a += (self.x[start_ind]-interval[0]) * self.y[start_ind] + # correction from start to first index + a += (self.x[start_ind]-interval[0]) * self.y[start_ind-1] + # correction from last index to end a += (interval[1]-self.x[end_ind]) * self.y[end_ind] - return a / (interval[1]-interval[0]) + a /= (interval[1]-interval[0]) + return a def add(self, f): """ Adds another PieceWiseConst function to this function. @@ -202,10 +206,43 @@ class PieceWiseLinFunc: :returns: the average a. :rtype: double """ + + def intermediate_value(x0, x1, y0, y1, x): + """ computes the intermediate value of a linear function """ + return y0 + (y1-y0)*(x-x0)/(x1-x0) + if interval is None: # no interval given, average over the whole spike train - return np.sum((self.x[1:]-self.x[:-1]) * 0.5*(self.y1+self.y2)) / \ + a = np.sum((self.x[1:]-self.x[:-1]) * 0.5*(self.y1+self.y2)) / \ (self.x[-1]-self.x[0]) + else: + # find the indices corresponding to the interval + start_ind = np.searchsorted(self.x, interval[0], side='right') + end_ind = np.searchsorted(self.x, interval[1], side='left')-1 + assert start_ind > 0 and end_ind < len(self.x), \ + "Invalid averaging interval" + # first the contribution from between the indices + a = np.sum((self.x[start_ind+1:end_ind+1] - + self.x[start_ind:end_ind]) * + 0.5*(self.y1[start_ind:end_ind] + + self.y2[start_ind:end_ind])) + # correction from start to first index + a += (self.x[start_ind]-interval[0]) * 0.5 * \ + (self.y2[start_ind-1] + + intermediate_value(self.x[start_ind-1], self.x[start_ind], + self.y1[start_ind-1], + self.y2[start_ind-1], + interval[0] + )) + # correction from last index to end + a += (interval[1]-self.x[end_ind]) * 0.5 * \ + (self.y1[end_ind] + + intermediate_value(self.x[end_ind], self.x[end_ind+1], + self.y1[end_ind], self.y2[end_ind], + interval[1] + )) + a /= (interval[1]-interval[0]) + return a def add(self, f): """ Adds another PieceWiseLin function to this function. diff --git a/test/test_function.py b/test/test_function.py index cabcf44..ed0b2ed 100644 --- a/test/test_function.py +++ b/test/test_function.py @@ -31,8 +31,13 @@ def test_pwc(): # interval averaging a = f.avrg([0.5, 3.5]) - print(a) assert_almost_equal(a, (0.5-0.5+0.5*1.5+1.0*0.75)/3.0, decimal=16) + a = f.avrg([1.5, 3.5]) + assert_almost_equal(a, (-0.5*0.5+0.5*1.5+1.0*0.75)/2.0, decimal=16) + a = f.avrg([1.0, 3.5]) + assert_almost_equal(a, (-0.5*1.0+0.5*1.5+1.0*0.75)/2.5, decimal=16) + a = f.avrg([1.0, 4.0]) + assert_almost_equal(a, (-0.5*1.0+0.5*1.5+1.5*0.75)/3.0, decimal=16) def test_pwc_add(): @@ -105,6 +110,18 @@ def test_pwl(): avrg_expected = (1.25 - 0.45 + 0.75 + 1.5*0.5) / 4.0 assert_almost_equal(f.avrg(), avrg_expected, decimal=16) + # interval averaging + a = f.avrg([0.5, 2.5]) + assert_almost_equal(a, (1.375*0.5 - 0.45 + 0.75)/2.0, decimal=16) + a = f.avrg([1.5, 3.5]) + assert_almost_equal(a, (-0.425*0.5 + 0.75 + (0.75+0.75-0.5/1.5)/2) / 2.0, + decimal=16) + a = f.avrg([1.0, 3.5]) + assert_almost_equal(a, (-0.45 + 0.75 + (0.75+0.75-0.5/1.5)/2) / 2.5, + decimal=16) + a = f.avrg([1.0, 4.0]) + assert_almost_equal(a, (-0.45 + 0.75 + 1.5*0.5) / 3.0, decimal=16) + def test_pwl_add(): x = [0.0, 1.0, 2.0, 2.5, 4.0] -- cgit v1.2.3