diff options
-rw-r--r-- | pyspike/function.py | 43 | ||||
-rw-r--r-- | test/test_function.py | 53 |
2 files changed, 82 insertions, 14 deletions
diff --git a/pyspike/function.py b/pyspike/function.py index 7177a3d..ef6b0f1 100644 --- a/pyspike/function.py +++ b/pyspike/function.py @@ -20,8 +20,8 @@ class PieceWiseConstFunc: function. - y: array of length N defining the function values at the intervals. """ - self.x = x - self.y = y + self.x = np.array(x) + self.y = np.array(y) def get_plottable_data(self): """ Returns two arrays containing x- and y-coordinates for immeditate @@ -66,29 +66,44 @@ class PieceWiseConstFunc: assert self.x[0] == f.x[0], "The functions have different intervals" assert self.x[-1] == f.x[-1], "The functions have different intervals" x_new = np.empty(len(self.x) + len(f.x)) - y_new = np.empty_like(x_new) + y_new = np.empty(len(x_new)-1) x_new[0] = self.x[0] y_new[0] = self.y[0] + f.y[0] - index1 = 1 - index2 = 1 - index = 1 - while (index1+1 < len(self.x)) and (index2+1 < len(f.x)): + index1 = 0 + index2 = 0 + index = 0 + while (index1+1 < len(self.y)) and (index2+1 < len(f.y)): + index += 1 + # print(index1+1, self.x[index1+1], self.y[index1+1], x_new[index]) if self.x[index1+1] < f.x[index2+1]: - x_new[index] = self.x[index1] index1 += 1 + x_new[index] = self.x[index1] elif self.x[index1+1] > f.x[index2+1]: - x_new[index] = f.x[index2+1] index2 += 1 + x_new[index] = f.x[index2] else: # self.x[index1+1] == f.x[index2+1]: - x_new[index] = self.x[index1] index1 += 1 index2 += 1 - index += 1 + x_new[index] = self.x[index1] y_new[index] = self.y[index1] + f.y[index2] - # both indices should have reached the maximum simultaneously - assert (index1+1 == len(self.x)) and (index2+1 == len(f.x)) + # one array reached the end -> copy the contents of the other to the end + if index1+1 < len(self.y): + x_new[index+1:index+1+len(self.x)-index1-1] = self.x[index1+1:] + y_new[index+1:index+1+len(self.y)-index1-1] = self.y[index1+1:] + \ + f.y[-1] + index += len(self.x)-index1-2 + elif index2+1 < len(f.y): + x_new[index+1:index+1+len(f.x)-index2-1] = f.x[index2+1:] + y_new[index+1:index+1+len(f.y)-index2-1] = f.y[index2+1:] + \ + self.y[-1] + index += len(f.x)-index2-2 + else: # both arrays reached the end simultaneously + # only the last x-value missing + x_new[index+1] = self.x[-1] + # the last value is again the end of the interval + # x_new[index+1] = self.x[-1] # only use the data that was actually filled - self.x = x_new[:index+1] + self.x = x_new[:index+2] self.y = y_new[:index+1] class PieceWiseLinFunc: diff --git a/test/test_function.py b/test/test_function.py new file mode 100644 index 0000000..386b999 --- /dev/null +++ b/test/test_function.py @@ -0,0 +1,53 @@ +""" test_function.py + +Tests the PieceWiseConst and PieceWiseLinear functions + +Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net> +""" + +from __future__ import print_function +import numpy as np +from copy import copy +from numpy.testing import assert_equal, assert_almost_equal, \ + assert_array_almost_equal + +import pyspike as spk + +def test_pwc(): + # some random data + x = [0.0, 1.0, 2.0, 2.5, 4.0] + y = [1.0, -0.5, 1.5, 0.75] + f = spk.PieceWiseConstFunc(x, y) + xp, yp = f.get_plottable_data() + + xp_expected = [0.0, 1.0, 1.0, 2.0, 2.0, 2.5, 2.5, 4.0] + yp_expected = [1.0, 1.0, -0.5, -0.5, 1.5, 1.5, 0.75, 0.75] + assert_array_almost_equal(xp, xp_expected) + assert_array_almost_equal(yp, yp_expected) + + assert_almost_equal(f.avrg(), (1.0-0.5+0.5*1.5+1.5*0.75)/4.0, decimal=16) + assert_almost_equal(f.abs_avrg(), (1.0+0.5+0.5*1.5+1.5*0.75)/4.0, + decimal=16) + + f1 = copy(f) + x = [0.0, 0.75, 2.0, 2.5, 2.7, 4.0] + y = [0.5, 1.0, -0.25, 0.0, 1.5] + f2 = spk.PieceWiseConstFunc(x, y) + f1.add(f2) + x_expected = [0.0, 0.75, 1.0, 2.0, 2.5, 2.7, 4.0] + y_expected = [1.5, 2.0, 0.5, 1.25, 0.75, 2.25] + assert_array_almost_equal(f1.x, x_expected, decimal=16) + assert_array_almost_equal(f1.y, y_expected, decimal=16) + + f2.add(f) + assert_array_almost_equal(f2.x, x_expected, decimal=16) + assert_array_almost_equal(f2.y, y_expected, decimal=16) + + f1.add(f2) + # same x, but y doubled + assert_array_almost_equal(f1.x, f2.x, decimal=16) + assert_array_almost_equal(f1.y, 2*f2.y, decimal=16) + + +if __name__ == "__main__": + test_pwc() |