summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-05-12 18:45:03 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-05-12 18:45:03 +0200
commit6f418a5a837939b132967bcdb3ff0ede6d899bd2 (patch)
tree6a90d51c1b091903c0a32c6f745ffd0be1a02534
parenta35402c208bd0ad31e5e60b6ddc55a3470e7bdde (diff)
+functions to obtain values of the pwc/pwl profile
Added __call__ operators to PieceWiseConst and PieceWiseLin class for obtaining function values at certain points in time.
-rw-r--r--pyspike/PieceWiseConstFunc.py22
-rw-r--r--pyspike/PieceWiseLinFunc.py30
-rw-r--r--test/test_function.py32
3 files changed, 83 insertions, 1 deletions
diff --git a/pyspike/PieceWiseConstFunc.py b/pyspike/PieceWiseConstFunc.py
index 41998ef..cf64e58 100644
--- a/pyspike/PieceWiseConstFunc.py
+++ b/pyspike/PieceWiseConstFunc.py
@@ -26,6 +26,28 @@ class PieceWiseConstFunc(object):
self.x = np.array(x)
self.y = np.array(y)
+ def __call__(self, t):
+ """ Returns the function value for the given time t. If t is a list of
+ times, the corresponding list of values is returned.
+
+ :param: time t, or list of times
+ :returns: function value(s) at that time(s).
+ """
+ assert np.all(t >= self.x[0]) and np.all(t <= self.x[-1]), \
+ "Invalid time: " + str(t)
+
+ ind = np.searchsorted(self.x, t, side='right')
+ # correct the cases t == x[0], t == x[-1]
+ try:
+ ind[ind == 0] = 1
+ ind[ind == len(self.x)] = len(self.x)-1
+ except TypeError:
+ if ind == 0:
+ ind = 1
+ if ind == len(self.x):
+ ind = len(self.x)-1
+ return self.y[ind-1]
+
def copy(self):
""" Returns a copy of itself
diff --git a/pyspike/PieceWiseLinFunc.py b/pyspike/PieceWiseLinFunc.py
index f2442be..b9787eb 100644
--- a/pyspike/PieceWiseLinFunc.py
+++ b/pyspike/PieceWiseLinFunc.py
@@ -29,6 +29,36 @@ class PieceWiseLinFunc:
self.y1 = np.array(y1)
self.y2 = np.array(y2)
+ def __call__(self, t):
+ """ Returns the function value for the given time t. If t is a list of
+ times, the corresponding list of values is returned.
+
+ :param: time t, or list of times
+ :returns: function value(s) at that time(s).
+ """
+ def intermediate_value(x0, x1, y0, y1, x):
+ """ computes the intermediate value of a linear function """
+ return y0 + (y1-y0)*(x-x0)/(x1-x0)
+
+ assert np.all(t >= self.x[0]) and np.all(t <= self.x[-1]), \
+ "Invalid time: " + str(t)
+
+ ind = np.searchsorted(self.x, t, side='right')
+ # correct the cases t == x[0], t == x[-1]
+ try:
+ ind[ind == 0] = 1
+ ind[ind == len(self.x)] = len(self.x)-1
+ except TypeError:
+ if ind == 0:
+ ind = 1
+ if ind == len(self.x):
+ ind = len(self.x)-1
+ return intermediate_value(self.x[ind-1],
+ self.x[ind],
+ self.y1[ind-1],
+ self.y2[ind-1],
+ t)
+
def copy(self):
""" Returns a copy of itself
diff --git a/test/test_function.py b/test/test_function.py
index d81b03a..c56a295 100644
--- a/test/test_function.py
+++ b/test/test_function.py
@@ -10,7 +10,8 @@ Distributed under the BSD License
from __future__ import print_function
import numpy as np
from copy import copy
-from numpy.testing import assert_almost_equal, assert_array_almost_equal
+from numpy.testing import assert_equal, assert_almost_equal, \
+ assert_array_equal, assert_array_almost_equal
import pyspike as spk
@@ -20,6 +21,17 @@ def test_pwc():
x = [0.0, 1.0, 2.0, 2.5, 4.0]
y = [1.0, -0.5, 1.5, 0.75]
f = spk.PieceWiseConstFunc(x, y)
+
+ # function values
+ assert_equal(f(0.0), 1.0)
+ assert_equal(f(0.5), 1.0)
+ assert_equal(f(2.25), 1.5)
+ assert_equal(f(3.5), 0.75)
+ assert_equal(f(4.0), 0.75)
+
+ assert_array_equal(f([0.0, 0.5, 2.25, 3.5, 4.0]),
+ [1.0, 1.0, 1.5, 0.75, 0.75])
+
xp, yp = f.get_plottable_data()
xp_expected = [0.0, 1.0, 1.0, 2.0, 2.0, 2.5, 2.5, 4.0]
@@ -38,11 +50,17 @@ def test_pwc():
assert_almost_equal(a, (-0.5*1.0+0.5*1.5+1.0*0.75)/2.5, decimal=16)
a = f.avrg([1.0, 4.0])
assert_almost_equal(a, (-0.5*1.0+0.5*1.5+1.5*0.75)/3.0, decimal=16)
+ a = f.avrg([0.0, 2.2])
+ assert_almost_equal(a, (1.0*1.0-0.5*1.0+0.2*1.5)/2.2, decimal=15)
# averaging over multiple intervals
a = f.avrg([(0.5, 1.5), (1.5, 3.5)])
assert_almost_equal(a, (0.5-0.5+0.5*1.5+1.0*0.75)/3.0, decimal=16)
+ # averaging over multiple intervals
+ a = f.avrg([(0.5, 1.5), (2.2, 3.5)])
+ assert_almost_equal(a, (0.5*1.0-0.5*0.5+0.3*1.5+1.0*0.75)/2.3, decimal=15)
+
def test_pwc_add():
# some random data
@@ -105,6 +123,18 @@ def test_pwl():
y1 = [1.0, -0.5, 1.5, 0.75]
y2 = [1.5, -0.4, 1.5, 0.25]
f = spk.PieceWiseLinFunc(x, y1, y2)
+
+ # function values
+ assert_equal(f(0.0), 1.0)
+ assert_equal(f(0.5), 1.25)
+ assert_equal(f(2.25), 1.5)
+ assert_equal(f(2.5), 0.75)
+ assert_equal(f(3.5), 0.75-0.5*1.0/1.5)
+ assert_equal(f(4.0), 0.25)
+
+ assert_array_equal(f([0.0, 0.5, 2.25, 2.5, 3.5, 4.0]),
+ [1.0, 1.25, 1.5, 0.75, 0.75-0.5*1.0/1.5, 0.25])
+
xp, yp = f.get_plottable_data()
xp_expected = [0.0, 1.0, 1.0, 2.0, 2.0, 2.5, 2.5, 4.0]