summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyspike/PieceWiseConstFunc.py15
-rw-r--r--pyspike/PieceWiseLinFunc.py51
-rw-r--r--test/test_function.py13
3 files changed, 56 insertions, 23 deletions
diff --git a/pyspike/PieceWiseConstFunc.py b/pyspike/PieceWiseConstFunc.py
index dea1a56..6d7a845 100644
--- a/pyspike/PieceWiseConstFunc.py
+++ b/pyspike/PieceWiseConstFunc.py
@@ -49,13 +49,16 @@ class PieceWiseConstFunc(object):
ind_l = np.searchsorted(self.x, t, side='left')
# if left and right side indices differ, the time t has to appear
# in self.x
- ind_at_spike = ind[np.logical_and(np.logical_and(ind != ind_l,
- ind > 1),
- ind < len(self.x))]
- value[ind_at_spike] = 0.5 * (self.y[ind_at_spike-1] +
- self.y[ind_at_spike-2])
+ ind_at_spike = np.logical_and(np.logical_and(ind != ind_l,
+ ind > 1),
+ ind < len(self.x))
+ # get the corresponding indices for the resulting value array
+ val_ind = np.arange(len(ind))[ind_at_spike]
+ # and for the arrays self.x, y1, y2
+ xy_ind = ind[ind_at_spike]
+ value[val_ind] = 0.5 * (self.y[xy_ind-1] + self.y[xy_ind-2])
return value
- else:
+ else: # t is a single value
# specific check for interval edges
if t == self.x[0]:
return self.y[0]
diff --git a/pyspike/PieceWiseLinFunc.py b/pyspike/PieceWiseLinFunc.py
index b9787eb..03c2da2 100644
--- a/pyspike/PieceWiseLinFunc.py
+++ b/pyspike/PieceWiseLinFunc.py
@@ -44,20 +44,47 @@ class PieceWiseLinFunc:
"Invalid time: " + str(t)
ind = np.searchsorted(self.x, t, side='right')
- # correct the cases t == x[0], t == x[-1]
- try:
+ if isinstance(t, collections.Sequence):
+ # t is a sequence of values
+ # correct the cases t == x[0], t == x[-1]
ind[ind == 0] = 1
ind[ind == len(self.x)] = len(self.x)-1
- except TypeError:
- if ind == 0:
- ind = 1
- if ind == len(self.x):
- ind = len(self.x)-1
- return intermediate_value(self.x[ind-1],
- self.x[ind],
- self.y1[ind-1],
- self.y2[ind-1],
- t)
+ value = intermediate_value(self.x[ind-1],
+ self.x[ind],
+ self.y1[ind-1],
+ self.y2[ind-1],
+ t)
+ # correct the values at exact spike times: there the value should
+ # be the at half of the step
+ # obtain the 'left' side indices for t
+ ind_l = np.searchsorted(self.x, t, side='left')
+ # if left and right side indices differ, the time t has to appear
+ # in self.x
+ ind_at_spike = np.logical_and(np.logical_and(ind != ind_l,
+ ind > 1),
+ ind < len(self.x))
+ # get the corresponding indices for the resulting value array
+ val_ind = np.arange(len(ind))[ind_at_spike]
+ # and for the values in self.x, y1, y2
+ xy_ind = ind[ind_at_spike]
+ # the values are defined as the average of the left and right limit
+ value[val_ind] = 0.5 * (self.y1[xy_ind-1] + self.y2[xy_ind-2])
+ return value
+ else: # t is a single value
+ # specific check for interval edges
+ if t == self.x[0]:
+ return self.y1[0]
+ if t == self.x[-1]:
+ return self.y2[-1]
+ # check if we are on any other exact spike time
+ if sum(self.x == t) > 0:
+ # use the middle of the left and right Spike value
+ return 0.5 * (self.y1[ind-1] + self.y2[ind-2])
+ return intermediate_value(self.x[ind-1],
+ self.x[ind],
+ self.y1[ind-1],
+ self.y2[ind-1],
+ t)
def copy(self):
""" Returns a copy of itself
diff --git a/test/test_function.py b/test/test_function.py
index 8ad4b17..92d378d 100644
--- a/test/test_function.py
+++ b/test/test_function.py
@@ -26,13 +26,14 @@ def test_pwc():
assert_equal(f(0.0), 1.0)
assert_equal(f(0.5), 1.0)
assert_equal(f(1.0), 0.25)
+ assert_equal(f(2.0), 0.5)
assert_equal(f(2.25), 1.5)
assert_equal(f(2.5), 2.25/2)
assert_equal(f(3.5), 0.75)
assert_equal(f(4.0), 0.75)
- assert_array_equal(f([0.0, 0.5, 1.0, 2.25, 2.5, 3.5, 4.0]),
- [1.0, 1.0, 0.25, 1.5, 2.25/2, 0.75, 0.75])
+ assert_array_equal(f([0.0, 0.5, 1.0, 2.0, 2.25, 2.5, 3.5, 4.0]),
+ [1.0, 1.0, 0.25, 0.5, 1.5, 2.25/2, 0.75, 0.75])
xp, yp = f.get_plottable_data()
@@ -129,13 +130,15 @@ def test_pwl():
# function values
assert_equal(f(0.0), 1.0)
assert_equal(f(0.5), 1.25)
+ assert_equal(f(1.0), 0.5)
+ assert_equal(f(2.0), 1.1/2)
assert_equal(f(2.25), 1.5)
- assert_equal(f(2.5), 0.75)
+ assert_equal(f(2.5), 2.25/2)
assert_equal(f(3.5), 0.75-0.5*1.0/1.5)
assert_equal(f(4.0), 0.25)
- assert_array_equal(f([0.0, 0.5, 2.25, 2.5, 3.5, 4.0]),
- [1.0, 1.25, 1.5, 0.75, 0.75-0.5*1.0/1.5, 0.25])
+ assert_array_equal(f([0.0, 0.5, 1.0, 2.0, 2.25, 2.5, 3.5, 4.0]),
+ [1.0, 1.25, 0.5, 0.55, 1.5, 2.25/2, 0.75-0.5/1.5, 0.25])
xp, yp = f.get_plottable_data()