summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyspike/function.py43
-rw-r--r--test/test_function.py53
2 files changed, 82 insertions, 14 deletions
diff --git a/pyspike/function.py b/pyspike/function.py
index 7177a3d..ef6b0f1 100644
--- a/pyspike/function.py
+++ b/pyspike/function.py
@@ -20,8 +20,8 @@ class PieceWiseConstFunc:
function.
- y: array of length N defining the function values at the intervals.
"""
- self.x = x
- self.y = y
+ self.x = np.array(x)
+ self.y = np.array(y)
def get_plottable_data(self):
""" Returns two arrays containing x- and y-coordinates for immeditate
@@ -66,29 +66,44 @@ class PieceWiseConstFunc:
assert self.x[0] == f.x[0], "The functions have different intervals"
assert self.x[-1] == f.x[-1], "The functions have different intervals"
x_new = np.empty(len(self.x) + len(f.x))
- y_new = np.empty_like(x_new)
+ y_new = np.empty(len(x_new)-1)
x_new[0] = self.x[0]
y_new[0] = self.y[0] + f.y[0]
- index1 = 1
- index2 = 1
- index = 1
- while (index1+1 < len(self.x)) and (index2+1 < len(f.x)):
+ index1 = 0
+ index2 = 0
+ index = 0
+ while (index1+1 < len(self.y)) and (index2+1 < len(f.y)):
+ index += 1
+ # print(index1+1, self.x[index1+1], self.y[index1+1], x_new[index])
if self.x[index1+1] < f.x[index2+1]:
- x_new[index] = self.x[index1]
index1 += 1
+ x_new[index] = self.x[index1]
elif self.x[index1+1] > f.x[index2+1]:
- x_new[index] = f.x[index2+1]
index2 += 1
+ x_new[index] = f.x[index2]
else: # self.x[index1+1] == f.x[index2+1]:
- x_new[index] = self.x[index1]
index1 += 1
index2 += 1
- index += 1
+ x_new[index] = self.x[index1]
y_new[index] = self.y[index1] + f.y[index2]
- # both indices should have reached the maximum simultaneously
- assert (index1+1 == len(self.x)) and (index2+1 == len(f.x))
+ # one array reached the end -> copy the contents of the other to the end
+ if index1+1 < len(self.y):
+ x_new[index+1:index+1+len(self.x)-index1-1] = self.x[index1+1:]
+ y_new[index+1:index+1+len(self.y)-index1-1] = self.y[index1+1:] + \
+ f.y[-1]
+ index += len(self.x)-index1-2
+ elif index2+1 < len(f.y):
+ x_new[index+1:index+1+len(f.x)-index2-1] = f.x[index2+1:]
+ y_new[index+1:index+1+len(f.y)-index2-1] = f.y[index2+1:] + \
+ self.y[-1]
+ index += len(f.x)-index2-2
+ else: # both arrays reached the end simultaneously
+ # only the last x-value missing
+ x_new[index+1] = self.x[-1]
+ # the last value is again the end of the interval
+ # x_new[index+1] = self.x[-1]
# only use the data that was actually filled
- self.x = x_new[:index+1]
+ self.x = x_new[:index+2]
self.y = y_new[:index+1]
class PieceWiseLinFunc:
diff --git a/test/test_function.py b/test/test_function.py
new file mode 100644
index 0000000..386b999
--- /dev/null
+++ b/test/test_function.py
@@ -0,0 +1,53 @@
+""" test_function.py
+
+Tests the PieceWiseConst and PieceWiseLinear functions
+
+Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>
+"""
+
+from __future__ import print_function
+import numpy as np
+from copy import copy
+from numpy.testing import assert_equal, assert_almost_equal, \
+ assert_array_almost_equal
+
+import pyspike as spk
+
+def test_pwc():
+ # some random data
+ x = [0.0, 1.0, 2.0, 2.5, 4.0]
+ y = [1.0, -0.5, 1.5, 0.75]
+ f = spk.PieceWiseConstFunc(x, y)
+ xp, yp = f.get_plottable_data()
+
+ xp_expected = [0.0, 1.0, 1.0, 2.0, 2.0, 2.5, 2.5, 4.0]
+ yp_expected = [1.0, 1.0, -0.5, -0.5, 1.5, 1.5, 0.75, 0.75]
+ assert_array_almost_equal(xp, xp_expected)
+ assert_array_almost_equal(yp, yp_expected)
+
+ assert_almost_equal(f.avrg(), (1.0-0.5+0.5*1.5+1.5*0.75)/4.0, decimal=16)
+ assert_almost_equal(f.abs_avrg(), (1.0+0.5+0.5*1.5+1.5*0.75)/4.0,
+ decimal=16)
+
+ f1 = copy(f)
+ x = [0.0, 0.75, 2.0, 2.5, 2.7, 4.0]
+ y = [0.5, 1.0, -0.25, 0.0, 1.5]
+ f2 = spk.PieceWiseConstFunc(x, y)
+ f1.add(f2)
+ x_expected = [0.0, 0.75, 1.0, 2.0, 2.5, 2.7, 4.0]
+ y_expected = [1.5, 2.0, 0.5, 1.25, 0.75, 2.25]
+ assert_array_almost_equal(f1.x, x_expected, decimal=16)
+ assert_array_almost_equal(f1.y, y_expected, decimal=16)
+
+ f2.add(f)
+ assert_array_almost_equal(f2.x, x_expected, decimal=16)
+ assert_array_almost_equal(f2.y, y_expected, decimal=16)
+
+ f1.add(f2)
+ # same x, but y doubled
+ assert_array_almost_equal(f1.x, f2.x, decimal=16)
+ assert_array_almost_equal(f1.y, 2*f2.y, decimal=16)
+
+
+if __name__ == "__main__":
+ test_pwc()