summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2014-09-21 14:09:55 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2014-09-21 14:09:55 +0200
commitab22a4d07ec803f1d52a505442989c19f343aa35 (patch)
tree5b25a1ac464093b26229ea50548e23eff77fc534
parentb20d416c4765b2280526c633ca62f43677b1d26a (diff)
added spike-distance test + bugfix
-rw-r--r--examples/test_data.py4
-rw-r--r--pyspike/distances.py22
-rw-r--r--test/test_distance.py66
3 files changed, 83 insertions, 9 deletions
diff --git a/examples/test_data.py b/examples/test_data.py
index 94e7d51..afafe7a 100644
--- a/examples/test_data.py
+++ b/examples/test_data.py
@@ -16,7 +16,7 @@ for line in spike_file:
for (i,spikes) in enumerate(spike_trains):
plt.plot(spikes, i*np.ones_like(spikes), 'o')
-f = spk.isi_distance(spike_trains[0], spike_trains[10], 4000)
+f = spk.isi_distance(spike_trains[0], spike_trains[1], 4000)
x, y = f.get_plottable_data()
plt.figure()
@@ -26,7 +26,7 @@ print("Average: %.8f" % f.avrg())
print("Absolute average: %.8f" % f.abs_avrg())
-f = spk.spike_distance(spike_trains[0], spike_trains[10], 4000)
+f = spk.spike_distance(spike_trains[0], spike_trains[1], 4000)
x, y = f.get_plottable_data()
print(x)
print(y)
diff --git a/pyspike/distances.py b/pyspike/distances.py
index 2ea80e7..766c75a 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -36,7 +36,6 @@ def isi_distance(spikes1, spikes2, T_end, T_start=0.0):
# compute the isi-distance
spike_events = np.empty(len(nu1)+len(nu2))
spike_events[0] = T_start
- spike_events[-1] = T_end
# the values have one entry less - the number of intervals between events
isi_values = np.empty(len(spike_events)-1)
# add the distance of the first events
@@ -48,22 +47,32 @@ def isi_distance(spikes1, spikes2, T_end, T_start=0.0):
index = 1
while True:
# check which spike is next - from s1 or s2
- if s1[index1+1] <= s2[index2+1]:
+ if s1[index1+1] < s2[index2+1]:
index1 += 1
# break condition relies on existence of spikes at T_end
if index1 >= len(nu1):
break
spike_events[index] = s1[index1]
- else:
+ elif s1[index1+1] > s2[index2+1]:
index2 += 1
if index2 >= len(nu2):
break
spike_events[index] = s2[index2]
+ else: # s1[index1+1] == s2[index2+1]
+ index1 += 1
+ index2 += 1
+ if (index1 >= len(nu1)) or (index2 >= len(nu2)):
+ break
+ spike_events[index] = s1[index1]
# compute the corresponding isi-distance
isi_values[index] = (nu1[index1]-nu2[index2]) / \
max(nu1[index1], nu2[index2])
index += 1
- return PieceWiseConstFunc(spike_events, isi_values)
+ # the last event is the interval end
+ spike_events[index] = T_end
+ # use only the data added above
+ # could be less than original length due to equal spike times
+ return PieceWiseConstFunc(spike_events[:index+1], isi_values[:index])
def get_min_dist(spike_time, spike_train, start_index=0):
@@ -119,7 +128,7 @@ def spike_distance(spikes1, spikes2, T_end, T_start=0.0):
isi1 = t1[1]-t1[0]
isi2 = t2[1]-t2[0]
while True:
- print(index, index1, index2)
+ # print(index, index1, index2)
if t1[index1+1] < t2[index2+1]:
index1 += 1
# break condition relies on existence of spikes at T_end
@@ -133,7 +142,6 @@ def spike_distance(spikes1, spikes2, T_end, T_start=0.0):
y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
# now the next interval start value
dt_f1 = get_min_dist(t1[index1+1], t2, index2)
- s1 = dt_f1
isi1 = t1[index1+1]-t1[index1]
# s2 is the same as above, thus we can compute y2 immediately
y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
@@ -149,7 +157,7 @@ def spike_distance(spikes1, spikes2, T_end, T_start=0.0):
y_ends[index-1] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
# now the next interval start value
dt_f2 = get_min_dist(t2[index2+1], t1, index1)
- s2 = dt_f2
+ #s2 = dt_f2
isi2 = t2[index2+1]-t2[index2]
# s2 is the same as above, thus we can compute y2 immediately
y_starts[index] = (s1*isi2 + s2*isi1) / ((isi1+isi2)**2/2)
diff --git a/test/test_distance.py b/test/test_distance.py
new file mode 100644
index 0000000..93053e7
--- /dev/null
+++ b/test/test_distance.py
@@ -0,0 +1,66 @@
+""" test_distance.py
+
+Tests the isi- and spike-distance computation
+
+Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>
+"""
+
+from __future__ import print_function
+import numpy as np
+from numpy.testing import assert_equal, assert_array_almost_equal
+
+import pyspike as spk
+
+def test_isi():
+ # generate two spike trains:
+ t1 = np.array([0.2, 0.4, 0.6, 0.7])
+ t2 = np.array([0.3, 0.45, 0.8, 0.9, 0.95])
+
+ # pen&paper calculation of the isi distance
+ expected_times = [0.0,0.2,0.3,0.4,0.45,0.6,0.7,0.8,0.9,0.95,1.0]
+ expected_isi = [-0.1/0.3, -0.1/0.3, 0.05/0.2, 0.05/0.2, -0.15/0.35,
+ -0.25/0.35, -0.05/0.35, 0.2/0.3, 0.25/0.3, 0.25/0.3]
+
+ f = spk.isi_distance(t1, t2, 1.0)
+
+ assert_equal(f.x, expected_times)
+ assert_array_almost_equal(f.y, expected_isi, decimal=14)
+
+ # check with some equal spike times
+ t1 = np.array([0.2,0.4,0.6])
+ t2 = np.array([0.1,0.4,0.5,0.6])
+
+ expected_times = [0.0,0.1,0.2,0.4,0.5,0.6,1.0]
+ expected_isi = [0.1/0.2, -0.1/0.3, -0.1/0.3, 0.1/0.2, 0.1/0.2, -0.0/0.5]
+
+ f = spk.isi_distance(t1, t2, 1.0)
+
+ assert_equal(f.x, expected_times)
+ assert_array_almost_equal(f.y, expected_isi, decimal=14)
+
+
+def test_spike():
+ # generate two spike trains:
+ t1 = np.array([0.2, 0.4, 0.6, 0.7])
+ t2 = np.array([0.3, 0.45, 0.8, 0.9, 0.95])
+
+ # pen&paper calculation of the spike distance
+ expected_times = [0.0,0.2,0.3,0.4,0.45,0.6,0.7,0.8,0.9,0.95,1.0]
+ s1 = np.array([0.0, 0.1, (0.1*0.1+0.05*0.1)/0.2, 0.05, (0.05*0.15 * 2)/0.2,
+ 0.15, 0.1, 0.1*0.2/0.3, 0.1**2/0.3, 0.1*0.05/0.3, 0.0])
+ s2 = np.array([0.0, 0.1*0.2/0.3, 0.1, (0.1*0.05 * 2)/.15, 0.05,
+ (0.05*0.2+0.1*0.15)/0.35, (0.05*0.1+0.1*0.25)/0.35, 0.1,0.1,0.05,0.0])
+ isi1 = np.array([0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.3, 0.3, 0.3, 0.3])
+ isi2 = np.array([0.3, 0.3, 0.15, 0.15, 0.35, 0.35, 0.35, 0.1, 0.05, 0.05])
+ expected_y1 = (s1[:-1]*isi2+s2[:-1]*isi1) / (0.5*(isi1+isi2)**2)
+ expected_y2 = (s1[1:]*isi2+s2[1:]*isi1) / (0.5*(isi1+isi2)**2)
+
+ f = spk.spike_distance(t1, t2, 1.0)
+
+ assert_equal(f.x, expected_times)
+ assert_array_almost_equal(f.y1, expected_y1, decimal=14)
+ assert_array_almost_equal(f.y2, expected_y2, decimal=14)
+
+if __name__ == "main":
+ test_isi()
+ test_spike()