summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2018-09-19 16:53:50 -0700
committerMario Mulansky <mario.mulansky@gmx.net>2018-09-19 16:58:11 -0700
commit50b85d976f2f5ec7e40faec1ede047cf45b10bc1 (patch)
tree98f7f1b61cc2ee68c8c6dab1a5a0d651d6fd38b4
parent5178df8d74bcdde310f8007ac891dd8a1bf4a9d2 (diff)
Fix incorrect integrals in PieceWiseLinFunc (#38)
Integrals of piece-wise linear functions were incorrect if the requested interval lies completely between two support points. This has been fixed, and a unit test exercising this behavior was added. Fixes #38
-rw-r--r--pyspike/PieceWiseLinFunc.py42
-rw-r--r--test/test_function.py14
2 files changed, 42 insertions, 14 deletions
diff --git a/pyspike/PieceWiseLinFunc.py b/pyspike/PieceWiseLinFunc.py
index 8145e63..8faaec4 100644
--- a/pyspike/PieceWiseLinFunc.py
+++ b/pyspike/PieceWiseLinFunc.py
@@ -146,31 +146,47 @@ class PieceWiseLinFunc:
if interval is None:
# no interval given, integrate over the whole spike train
- integral = np.sum((self.x[1:]-self.x[:-1]) * 0.5*(self.y1+self.y2))
+ return np.sum((self.x[1:]-self.x[:-1]) * 0.5*(self.y1+self.y2))
+
+ # find the indices corresponding to the interval
+ start_ind = np.searchsorted(self.x, interval[0], side='right')
+ end_ind = np.searchsorted(self.x, interval[1], side='left')-1
+ assert start_ind > 0 and end_ind < len(self.x), \
+ "Invalid averaging interval"
+ if start_ind > end_ind:
+ print(start_ind, end_ind, self.x[start_ind])
+ # contribution from between two closest edges
+ y_x0 = intermediate_value(self.x[start_ind-1],
+ self.x[start_ind],
+ self.y1[start_ind-1],
+ self.y2[start_ind-1],
+ interval[0])
+ y_x1 = intermediate_value(self.x[start_ind-1],
+ self.x[start_ind],
+ self.y1[start_ind-1],
+ self.y2[start_ind-1],
+ interval[1])
+ print(y_x0, y_x1, interval[1] - interval[0])
+ integral = (y_x0 + y_x1) * 0.5 * (interval[1] - interval[0])
+ print(integral)
else:
- # find the indices corresponding to the interval
- start_ind = np.searchsorted(self.x, interval[0], side='right')
- end_ind = np.searchsorted(self.x, interval[1], side='left')-1
- assert start_ind > 0 and end_ind < len(self.x), \
- "Invalid averaging interval"
# first the contribution from between the indices
integral = np.sum((self.x[start_ind+1:end_ind+1] -
- self.x[start_ind:end_ind]) *
- 0.5*(self.y1[start_ind:end_ind] +
- self.y2[start_ind:end_ind]))
+ self.x[start_ind:end_ind]) *
+ 0.5*(self.y1[start_ind:end_ind] +
+ self.y2[start_ind:end_ind]))
# correction from start to first index
integral += (self.x[start_ind]-interval[0]) * 0.5 * \
(self.y2[start_ind-1] +
- intermediate_value(self.x[start_ind-1],
+ intermediate_value(self.x[start_ind-1],
self.x[start_ind],
self.y1[start_ind-1],
self.y2[start_ind-1],
- interval[0]
- ))
+ interval[0]))
# correction from last index to end
integral += (interval[1]-self.x[end_ind]) * 0.5 * \
(self.y1[end_ind] +
- intermediate_value(self.x[end_ind], self.x[end_ind+1],
+ intermediate_value(self.x[end_ind], self.x[end_ind+1],
self.y1[end_ind], self.y2[end_ind],
interval[1]
))
diff --git a/test/test_function.py b/test/test_function.py
index ddeb4b1..6c04839 100644
--- a/test/test_function.py
+++ b/test/test_function.py
@@ -136,7 +136,7 @@ def test_pwc_integral():
# test part interval, spanning an edge
assert_equal(f1.integral((0.5,1.5)), 0.5*1.0 + 0.5*-0.5)
# test part interval, just over two edges
- assert_almost_equal(f1.integral((1.0-1e-16,2+1e-16)), 1.0*-0.5, decimal=16)
+ assert_almost_equal(f1.integral((1.0-1e-16,2+1e-16)), 1.0*-0.5, decimal=14)
# test part interval, between two edges
assert_equal(f1.integral((1.0,2.0)), 1.0*-0.5)
assert_equal(f1.integral((1.2,1.7)), (1.7-1.2)*-0.5)
@@ -212,6 +212,18 @@ def test_pwl():
a = f.avrg([1.0, 4.0])
assert_almost_equal(a, (-0.45 + 0.75 + 1.5*0.5) / 3.0, decimal=16)
+ # interval between support points
+ a = f.avrg([1.1, 1.5])
+ assert_almost_equal(a, (-0.5+0.1*0.1 - 0.45) * 0.5, decimal=14)
+
+ # starting at a support point
+ a = f.avrg([1.0, 1.5])
+ assert_almost_equal(a, (-0.5 - 0.45) * 0.5, decimal=14)
+
+ # start and end at support point
+ a = f.avrg([1.0, 2.0])
+ assert_almost_equal(a, (-0.5 - 0.4) * 0.5, decimal=14)
+
# averaging over multiple intervals
a = f.avrg([(0.5, 1.5), (1.5, 2.5)])
assert_almost_equal(a, (1.375*0.5 - 0.45 + 0.75)/2.0, decimal=16)