From a61a14295e28e6e95fa510693a11ae8c78a552ab Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Sun, 17 May 2015 17:50:39 +0200 Subject: return correct values at exact spike times pwc and pwl function object return the average of the left and right limit as function value at the exact spike times. --- pyspike/PieceWiseConstFunc.py | 15 ++++++++----- pyspike/PieceWiseLinFunc.py | 51 +++++++++++++++++++++++++++++++++---------- test/test_function.py | 13 ++++++----- 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/pyspike/PieceWiseConstFunc.py b/pyspike/PieceWiseConstFunc.py index dea1a56..6d7a845 100644 --- a/pyspike/PieceWiseConstFunc.py +++ b/pyspike/PieceWiseConstFunc.py @@ -49,13 +49,16 @@ class PieceWiseConstFunc(object): ind_l = np.searchsorted(self.x, t, side='left') # if left and right side indices differ, the time t has to appear # in self.x - ind_at_spike = ind[np.logical_and(np.logical_and(ind != ind_l, - ind > 1), - ind < len(self.x))] - value[ind_at_spike] = 0.5 * (self.y[ind_at_spike-1] + - self.y[ind_at_spike-2]) + ind_at_spike = np.logical_and(np.logical_and(ind != ind_l, + ind > 1), + ind < len(self.x)) + # get the corresponding indices for the resulting value array + val_ind = np.arange(len(ind))[ind_at_spike] + # and for the arrays self.x, y1, y2 + xy_ind = ind[ind_at_spike] + value[val_ind] = 0.5 * (self.y[xy_ind-1] + self.y[xy_ind-2]) return value - else: + else: # t is a single value # specific check for interval edges if t == self.x[0]: return self.y[0] diff --git a/pyspike/PieceWiseLinFunc.py b/pyspike/PieceWiseLinFunc.py index b9787eb..03c2da2 100644 --- a/pyspike/PieceWiseLinFunc.py +++ b/pyspike/PieceWiseLinFunc.py @@ -44,20 +44,47 @@ class PieceWiseLinFunc: "Invalid time: " + str(t) ind = np.searchsorted(self.x, t, side='right') - # correct the cases t == x[0], t == x[-1] - try: + if isinstance(t, collections.Sequence): + # t is a sequence of values + # correct the cases t == x[0], t == x[-1] ind[ind == 0] = 1 ind[ind == len(self.x)] = len(self.x)-1 - except TypeError: - if ind == 0: - ind = 1 - if ind == len(self.x): - ind = len(self.x)-1 - return intermediate_value(self.x[ind-1], - self.x[ind], - self.y1[ind-1], - self.y2[ind-1], - t) + value = intermediate_value(self.x[ind-1], + self.x[ind], + self.y1[ind-1], + self.y2[ind-1], + t) + # correct the values at exact spike times: there the value should + # be the at half of the step + # obtain the 'left' side indices for t + ind_l = np.searchsorted(self.x, t, side='left') + # if left and right side indices differ, the time t has to appear + # in self.x + ind_at_spike = np.logical_and(np.logical_and(ind != ind_l, + ind > 1), + ind < len(self.x)) + # get the corresponding indices for the resulting value array + val_ind = np.arange(len(ind))[ind_at_spike] + # and for the values in self.x, y1, y2 + xy_ind = ind[ind_at_spike] + # the values are defined as the average of the left and right limit + value[val_ind] = 0.5 * (self.y1[xy_ind-1] + self.y2[xy_ind-2]) + return value + else: # t is a single value + # specific check for interval edges + if t == self.x[0]: + return self.y1[0] + if t == self.x[-1]: + return self.y2[-1] + # check if we are on any other exact spike time + if sum(self.x == t) > 0: + # use the middle of the left and right Spike value + return 0.5 * (self.y1[ind-1] + self.y2[ind-2]) + return intermediate_value(self.x[ind-1], + self.x[ind], + self.y1[ind-1], + self.y2[ind-1], + t) def copy(self): """ Returns a copy of itself diff --git a/test/test_function.py b/test/test_function.py index 8ad4b17..92d378d 100644 --- a/test/test_function.py +++ b/test/test_function.py @@ -26,13 +26,14 @@ def test_pwc(): assert_equal(f(0.0), 1.0) assert_equal(f(0.5), 1.0) assert_equal(f(1.0), 0.25) + assert_equal(f(2.0), 0.5) assert_equal(f(2.25), 1.5) assert_equal(f(2.5), 2.25/2) assert_equal(f(3.5), 0.75) assert_equal(f(4.0), 0.75) - assert_array_equal(f([0.0, 0.5, 1.0, 2.25, 2.5, 3.5, 4.0]), - [1.0, 1.0, 0.25, 1.5, 2.25/2, 0.75, 0.75]) + assert_array_equal(f([0.0, 0.5, 1.0, 2.0, 2.25, 2.5, 3.5, 4.0]), + [1.0, 1.0, 0.25, 0.5, 1.5, 2.25/2, 0.75, 0.75]) xp, yp = f.get_plottable_data() @@ -129,13 +130,15 @@ def test_pwl(): # function values assert_equal(f(0.0), 1.0) assert_equal(f(0.5), 1.25) + assert_equal(f(1.0), 0.5) + assert_equal(f(2.0), 1.1/2) assert_equal(f(2.25), 1.5) - assert_equal(f(2.5), 0.75) + assert_equal(f(2.5), 2.25/2) assert_equal(f(3.5), 0.75-0.5*1.0/1.5) assert_equal(f(4.0), 0.25) - assert_array_equal(f([0.0, 0.5, 2.25, 2.5, 3.5, 4.0]), - [1.0, 1.25, 1.5, 0.75, 0.75-0.5*1.0/1.5, 0.25]) + assert_array_equal(f([0.0, 0.5, 1.0, 2.0, 2.25, 2.5, 3.5, 4.0]), + [1.0, 1.25, 0.5, 0.55, 1.5, 2.25/2, 0.75-0.5/1.5, 0.25]) xp, yp = f.get_plottable_data() -- cgit v1.2.3