summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-05-17 17:50:39 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-05-17 17:50:39 +0200
commita61a14295e28e6e95fa510693a11ae8c78a552ab (patch)
tree9d0194163e2ce6727603fde241c03b470339dad1
parent8841138b74242ed9eb77c972c76e9a617778a79a (diff)
return correct values at exact spike times
pwc and pwl function object return the average of the left and right limit as function value at the exact spike times.
-rw-r--r--pyspike/PieceWiseConstFunc.py15
-rw-r--r--pyspike/PieceWiseLinFunc.py51
-rw-r--r--test/test_function.py13
3 files changed, 56 insertions, 23 deletions
diff --git a/pyspike/PieceWiseConstFunc.py b/pyspike/PieceWiseConstFunc.py
index dea1a56..6d7a845 100644
--- a/pyspike/PieceWiseConstFunc.py
+++ b/pyspike/PieceWiseConstFunc.py
@@ -49,13 +49,16 @@ class PieceWiseConstFunc(object):
ind_l = np.searchsorted(self.x, t, side='left')
# if left and right side indices differ, the time t has to appear
# in self.x
- ind_at_spike = ind[np.logical_and(np.logical_and(ind != ind_l,
- ind > 1),
- ind < len(self.x))]
- value[ind_at_spike] = 0.5 * (self.y[ind_at_spike-1] +
- self.y[ind_at_spike-2])
+ ind_at_spike = np.logical_and(np.logical_and(ind != ind_l,
+ ind > 1),
+ ind < len(self.x))
+ # get the corresponding indices for the resulting value array
+ val_ind = np.arange(len(ind))[ind_at_spike]
+ # and for the arrays self.x, y1, y2
+ xy_ind = ind[ind_at_spike]
+ value[val_ind] = 0.5 * (self.y[xy_ind-1] + self.y[xy_ind-2])
return value
- else:
+ else: # t is a single value
# specific check for interval edges
if t == self.x[0]:
return self.y[0]
diff --git a/pyspike/PieceWiseLinFunc.py b/pyspike/PieceWiseLinFunc.py
index b9787eb..03c2da2 100644
--- a/pyspike/PieceWiseLinFunc.py
+++ b/pyspike/PieceWiseLinFunc.py
@@ -44,20 +44,47 @@ class PieceWiseLinFunc:
"Invalid time: " + str(t)
ind = np.searchsorted(self.x, t, side='right')
- # correct the cases t == x[0], t == x[-1]
- try:
+ if isinstance(t, collections.Sequence):
+ # t is a sequence of values
+ # correct the cases t == x[0], t == x[-1]
ind[ind == 0] = 1
ind[ind == len(self.x)] = len(self.x)-1
- except TypeError:
- if ind == 0:
- ind = 1
- if ind == len(self.x):
- ind = len(self.x)-1
- return intermediate_value(self.x[ind-1],
- self.x[ind],
- self.y1[ind-1],
- self.y2[ind-1],
- t)
+ value = intermediate_value(self.x[ind-1],
+ self.x[ind],
+ self.y1[ind-1],
+ self.y2[ind-1],
+ t)
+ # correct the values at exact spike times: there the value should
+ # be the at half of the step
+ # obtain the 'left' side indices for t
+ ind_l = np.searchsorted(self.x, t, side='left')
+ # if left and right side indices differ, the time t has to appear
+ # in self.x
+ ind_at_spike = np.logical_and(np.logical_and(ind != ind_l,
+ ind > 1),
+ ind < len(self.x))
+ # get the corresponding indices for the resulting value array
+ val_ind = np.arange(len(ind))[ind_at_spike]
+ # and for the values in self.x, y1, y2
+ xy_ind = ind[ind_at_spike]
+ # the values are defined as the average of the left and right limit
+ value[val_ind] = 0.5 * (self.y1[xy_ind-1] + self.y2[xy_ind-2])
+ return value
+ else: # t is a single value
+ # specific check for interval edges
+ if t == self.x[0]:
+ return self.y1[0]
+ if t == self.x[-1]:
+ return self.y2[-1]
+ # check if we are on any other exact spike time
+ if sum(self.x == t) > 0:
+ # use the middle of the left and right Spike value
+ return 0.5 * (self.y1[ind-1] + self.y2[ind-2])
+ return intermediate_value(self.x[ind-1],
+ self.x[ind],
+ self.y1[ind-1],
+ self.y2[ind-1],
+ t)
def copy(self):
""" Returns a copy of itself
diff --git a/test/test_function.py b/test/test_function.py
index 8ad4b17..92d378d 100644
--- a/test/test_function.py
+++ b/test/test_function.py
@@ -26,13 +26,14 @@ def test_pwc():
assert_equal(f(0.0), 1.0)
assert_equal(f(0.5), 1.0)
assert_equal(f(1.0), 0.25)
+ assert_equal(f(2.0), 0.5)
assert_equal(f(2.25), 1.5)
assert_equal(f(2.5), 2.25/2)
assert_equal(f(3.5), 0.75)
assert_equal(f(4.0), 0.75)
- assert_array_equal(f([0.0, 0.5, 1.0, 2.25, 2.5, 3.5, 4.0]),
- [1.0, 1.0, 0.25, 1.5, 2.25/2, 0.75, 0.75])
+ assert_array_equal(f([0.0, 0.5, 1.0, 2.0, 2.25, 2.5, 3.5, 4.0]),
+ [1.0, 1.0, 0.25, 0.5, 1.5, 2.25/2, 0.75, 0.75])
xp, yp = f.get_plottable_data()
@@ -129,13 +130,15 @@ def test_pwl():
# function values
assert_equal(f(0.0), 1.0)
assert_equal(f(0.5), 1.25)
+ assert_equal(f(1.0), 0.5)
+ assert_equal(f(2.0), 1.1/2)
assert_equal(f(2.25), 1.5)
- assert_equal(f(2.5), 0.75)
+ assert_equal(f(2.5), 2.25/2)
assert_equal(f(3.5), 0.75-0.5*1.0/1.5)
assert_equal(f(4.0), 0.25)
- assert_array_equal(f([0.0, 0.5, 2.25, 2.5, 3.5, 4.0]),
- [1.0, 1.25, 1.5, 0.75, 0.75-0.5*1.0/1.5, 0.25])
+ assert_array_equal(f([0.0, 0.5, 1.0, 2.0, 2.25, 2.5, 3.5, 4.0]),
+ [1.0, 1.25, 0.5, 0.55, 1.5, 2.25/2, 0.75-0.5/1.5, 0.25])
xp, yp = f.get_plottable_data()