summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2015-04-24 23:29:05 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2015-04-24 23:29:05 +0200
commit795e16ffe7afb469ef07a548c1f6a31d924196b3 (patch)
treedccc43cb074d4e5457e2462fe6e441e3d9cf8d7b
parentf7ad8e6b23f706a2371e2bc25b533b59f8dea137 (diff)
bugfixes for spike distance
-rw-r--r--pyspike/cython/cython_distance.pyx24
-rw-r--r--pyspike/cython/python_backend.py15
-rw-r--r--test/test_distance.py50
3 files changed, 58 insertions, 31 deletions
diff --git a/pyspike/cython/cython_distance.pyx b/pyspike/cython/cython_distance.pyx
index dc2557f..a41d8e8 100644
--- a/pyspike/cython/cython_distance.pyx
+++ b/pyspike/cython/cython_distance.pyx
@@ -194,31 +194,31 @@ def spike_distance_cython(double[:] t1, double[:] t2,
t_p2 = t_start
if t1[0] > t_start:
# dt_p1 = t2[0]-t_start
- dt_p1 = 0.0
t_f1 = t1[0]
dt_f1 = get_min_dist_cython(t_f1, t2, N2, 0, t_start, t_end)
isi1 = fmax(t_f1-t_start, t1[1]-t1[0])
- s1 = dt_f1*(t_f1-t_start)/isi1
+ dt_p1 = dt_f1
+ s1 = dt_p1*(t_f1-t_start)/isi1
index1 = -1
else:
- dt_p1 = 0.0
t_f1 = t1[1]
dt_f1 = get_min_dist_cython(t_f1, t2, N2, 0, t_start, t_end)
+ dt_p1 = 0.0
isi1 = t1[1]-t1[0]
s1 = dt_p1
index1 = 0
if t2[0] > t_start:
# dt_p1 = t2[0]-t_start
- dt_p2 = 0.0
t_f2 = t2[0]
dt_f2 = get_min_dist_cython(t_f2, t1, N1, 0, t_start, t_end)
+ dt_p2 = dt_f2
isi2 = fmax(t_f2-t_start, t2[1]-t2[0])
- s2 = dt_f2*(t_f2-t_start)/isi2
+ s2 = dt_p2*(t_f2-t_start)/isi2
index2 = -1
else:
- dt_p2 = 0.0
t_f2 = t2[1]
dt_f2 = get_min_dist_cython(t_f2, t1, N1, 0, t_start, t_end)
+ dt_p2 = 0.0
isi2 = t2[1]-t2[0]
s2 = dt_p2
index2 = 0
@@ -231,16 +231,16 @@ def spike_distance_cython(double[:] t1, double[:] t2,
if (index1 < N1-1) and (t_f1 < t_f2 or index2 == N2-1):
index1 += 1
# first calculate the previous interval end value
+ s1 = dt_f1*(t_f1-t_p1) / isi1
# the previous time now was the following time before:
- dt_p1 = dt_f1
+ dt_p1 = dt_f1
t_p1 = t_f1 # t_p1 contains the current time point
- # get the next time
+ # get the next time
if index1 < N1-1:
t_f1 = t1[index1+1]
else:
t_f1 = t_end
spike_events[index] = t_p1
- s1 = dt_p1
s2 = (dt_p2*(t_f2-t_p1) + dt_f2*(t_p1-t_p2)) / isi2
y_ends[index-1] = (s1*isi2 + s2*isi1)/isi_avrg_cython(isi1,
isi2)
@@ -249,6 +249,7 @@ def spike_distance_cython(double[:] t1, double[:] t2,
dt_f1 = get_min_dist_cython(t_f1, t2, N2, index2,
t_start, t_end)
isi1 = t_f1-t_p1
+ s1 = dt_p1
else:
dt_f1 = dt_p1
isi1 = fmax(t_end-t1[N1-1], t1[N1-1]-t1[N1-2])
@@ -260,9 +261,10 @@ def spike_distance_cython(double[:] t1, double[:] t2,
elif (index2 < N2-1) and (t_f1 > t_f2 or index1 == N1-1):
index2 += 1
# first calculate the previous interval end value
+ s2 = dt_f2*(t_f2-t_p2) / isi2
# the previous time now was the following time before:
dt_p2 = dt_f2
- t_p2 = t_f2 # t_p1 contains the current time point
+ t_p2 = t_f2 # t_p2 contains the current time point
# get the next time
if index2 < N2-1:
t_f2 = t2[index2+1]
@@ -270,7 +272,6 @@ def spike_distance_cython(double[:] t1, double[:] t2,
t_f2 = t_end
spike_events[index] = t_p2
s1 = (dt_p1*(t_f1-t_p2) + dt_f1*(t_p2-t_p1)) / isi1
- s2 = dt_p2
y_ends[index-1] = (s1*isi2 + s2*isi1) / isi_avrg_cython(isi1,
isi2)
# now the next interval start value
@@ -278,6 +279,7 @@ def spike_distance_cython(double[:] t1, double[:] t2,
dt_f2 = get_min_dist_cython(t_f2, t1, N1, index1,
t_start, t_end)
isi2 = t_f2-t_p2
+ s2 = dt_p2
else:
dt_f2 = dt_p2
isi2 = fmax(t_end-t2[N2-1], t2[N2-1]-t2[N2-2])
diff --git a/pyspike/cython/python_backend.py b/pyspike/cython/python_backend.py
index c65bfb0..317b568 100644
--- a/pyspike/cython/python_backend.py
+++ b/pyspike/cython/python_backend.py
@@ -142,12 +142,11 @@ def spike_distance_python(spikes1, spikes2, t_start, t_end):
t_p1 = t_start
t_p2 = t_start
if t1[0] > t_start:
- # dt_p1 = t2[0]-t_start
- dt_p1 = 0.0
t_f1 = t1[0]
dt_f1 = get_min_dist(t_f1, t2, 0, t_start, t_end)
+ dt_p1 = dt_f1
isi1 = max(t_f1-t_start, t1[1]-t1[0])
- s1 = dt_f1*(t_f1-t_start)/isi1
+ s1 = dt_p1*(t_f1-t_start)/isi1
index1 = -1
else:
dt_p1 = 0.0
@@ -158,11 +157,11 @@ def spike_distance_python(spikes1, spikes2, t_start, t_end):
index1 = 0
if t2[0] > t_start:
# dt_p1 = t2[0]-t_start
- dt_p2 = 0.0
t_f2 = t2[0]
dt_f2 = get_min_dist(t_f2, t1, 0, t_start, t_end)
+ dt_p2 = dt_f2
isi2 = max(t_f2-t_start, t2[1]-t2[0])
- s2 = dt_f2*(t_f2-t_start)/isi2
+ s2 = dt_p2*(t_f2-t_start)/isi2
index2 = -1
else:
dt_p2 = 0.0
@@ -180,6 +179,7 @@ def spike_distance_python(spikes1, spikes2, t_start, t_end):
if (index1 < N1-1) and (t_f1 < t_f2 or index2 == N2-1):
index1 += 1
# first calculate the previous interval end value
+ s1 = dt_f1*(t_f1-t_p1) / isi1
# the previous time now was the following time before:
dt_p1 = dt_f1
t_p1 = t_f1 # t_p1 contains the current time point
@@ -189,13 +189,13 @@ def spike_distance_python(spikes1, spikes2, t_start, t_end):
else:
t_f1 = t_end
spike_events[index] = t_p1
- s1 = dt_p1
s2 = (dt_p2*(t_f2-t_p1) + dt_f2*(t_p1-t_p2)) / isi2
y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2)
# now the next interval start value
if index1 < N1-1:
dt_f1 = get_min_dist(t_f1, t2, index2, t_start, t_end)
isi1 = t_f1-t_p1
+ s1 = dt_p1
else:
dt_f1 = dt_p1
isi1 = max(t_end-t1[N1-1], t1[N1-1]-t1[N1-2])
@@ -206,6 +206,7 @@ def spike_distance_python(spikes1, spikes2, t_start, t_end):
elif (index2 < N2-1) and (t_f1 > t_f2 or index1 == N1-1):
index2 += 1
# first calculate the previous interval end value
+ s2 = dt_f2*(t_f2-t_p2) / isi2
# the previous time now was the following time before:
dt_p2 = dt_f2
t_p2 = t_f2 # t_p1 contains the current time point
@@ -216,12 +217,12 @@ def spike_distance_python(spikes1, spikes2, t_start, t_end):
t_f2 = t_end
spike_events[index] = t_p2
s1 = (dt_p1*(t_f1-t_p2) + dt_f1*(t_p2-t_p1)) / isi1
- s2 = dt_p2
y_ends[index-1] = (s1*isi2 + s2*isi1) / (0.5*(isi1+isi2)**2)
# now the next interval start value
if index2 < N2-1:
dt_f2 = get_min_dist(t_f2, t1, index1, t_start, t_end)
isi2 = t_f2-t_p2
+ s2 = dt_p2
else:
dt_f2 = dt_p2
isi2 = max(t_end-t2[N2-1], t2[N2-1]-t2[N2-2])
diff --git a/test/test_distance.py b/test/test_distance.py
index 0059001..20b52e8 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -73,7 +73,7 @@ def test_spike():
assert_equal(f.x, expected_times)
- assert_almost_equal(f.avrg(), 0.1662415, decimal=6)
+ assert_almost_equal(f.avrg(), 1.6624149659863946e-01, decimal=15)
assert_almost_equal(f.y2[-1], 0.1394558, decimal=6)
t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0)
@@ -84,7 +84,7 @@ def test_spike():
s1 = np.array([0.1, 0.1, (0.1*0.1+0.05*0.1)/0.2, 0.05, (0.05*0.15 * 2)/0.2,
0.15, 0.1, (0.1*0.1+0.1*0.2)/0.3, (0.1*0.2+0.1*0.1)/0.3,
(0.1*0.05+0.1*0.25)/0.3, 0.1])
- s2 = np.array([0.1, 0.1*0.2/0.3, 0.1, (0.1*0.05 * 2)/.15, 0.05,
+ s2 = np.array([0.1, (0.1*0.2+0.1*0.1)/0.3, 0.1, (0.1*0.05 * 2)/.15, 0.05,
(0.05*0.2+0.1*0.15)/0.35, (0.05*0.1+0.1*0.25)/0.35,
0.1, 0.1, 0.05, 0.05])
isi1 = np.array([0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.3, 0.3, 0.3, 0.3])
@@ -113,12 +113,18 @@ def test_spike():
t2 = SpikeTrain([0.1, 0.4, 0.5, 0.6], [0.0, 1.0])
expected_times = [0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 1.0]
- s1 = np.array([0.1, 0.1*0.1/0.2, 0.1, 0.0, 0.0, 0.0, 0.0])
- s2 = np.array([0.1*0.1/0.3, 0.1, 0.1*0.2/0.3, 0.0, 0.1, 0.0, 0.0])
+ # due to the edge correction in the beginning, s1 and s2 are different
+ # for left and right values
+ s1_r = np.array([0.1, (0.1*0.1+0.1*0.1)/0.2, 0.1, 0.0, 0.0, 0.0, 0.0])
+ s1_l = np.array([0.1, (0.1*0.1+0.1*0.1)/0.2, 0.1, 0.0, 0.0, 0.0, 0.0])
+ s2_r = np.array([0.1*0.1/0.3, 0.1*0.3/0.3, 0.1*0.2/0.3,
+ 0.0, 0.1, 0.0, 0.0])
+ s2_l = np.array([0.1*0.1/0.3, 0.1*0.1/0.3, 0.1*0.2/0.3, 0.0,
+ 0.1, 0.0, 0.0])
isi1 = np.array([0.2, 0.2, 0.2, 0.2, 0.2, 0.4])
isi2 = np.array([0.3, 0.3, 0.3, 0.1, 0.1, 0.4])
- expected_y1 = (s1[:-1]*isi2+s2[:-1]*isi1) / (0.5*(isi1+isi2)**2)
- expected_y2 = (s1[1:]*isi2+s2[1:]*isi1) / (0.5*(isi1+isi2)**2)
+ expected_y1 = (s1_r[:-1]*isi2+s2_r[:-1]*isi1) / (0.5*(isi1+isi2)**2)
+ expected_y2 = (s1_l[1:]*isi2+s2_l[1:]*isi1) / (0.5*(isi1+isi2)**2)
expected_times = np.array(expected_times)
expected_y1 = np.array(expected_y1)
@@ -321,19 +327,37 @@ def test_spike_sync_matrix():
def test_regression_spiky():
+ # standard example
+ st1 = SpikeTrain(np.arange(100, 1201, 100), 1300)
+ st2 = SpikeTrain(np.arange(100, 1201, 110), 1300)
+
+ isi_dist = spk.isi_distance(st1, st2)
+ assert_almost_equal(isi_dist, 7.6923076923076941e-02, decimal=15)
+
+ spike_dist = spk.spike_distance(st1, st2)
+ assert_equal(spike_dist, 2.1105878248735391e-01)
+
+ spike_sync = spk.spike_sync(st1, st2)
+ assert_equal(spike_sync, 8.6956521739130432e-01)
+
+ # multivariate check
+
spike_trains = spk.load_spike_trains_from_txt("test/PySpike_testdata.txt",
(0.0, 4000.0))
- isi_profile = spk.isi_profile_multi(spike_trains)
- isi_dist = isi_profile.avrg()
- print(isi_dist)
+ isi_dist = spk.isi_distance_multi(spike_trains)
# get the full precision from SPIKY
- # assert_equal(isi_dist, 0.1832)
+ assert_almost_equal(isi_dist, 1.8318789829845508e-01, decimal=15)
spike_profile = spk.spike_profile_multi(spike_trains)
- spike_dist = spike_profile.avrg()
- print(spike_dist)
+ assert_equal(len(spike_profile.y1)+len(spike_profile.y2), 1252)
+
+ spike_dist = spk.spike_distance_multi(spike_trains)
+ # get the full precision from SPIKY
+ assert_almost_equal(spike_dist, 2.4432433330596512e-01, decimal=15)
+
+ spike_sync = spk.spike_sync_multi(spike_trains)
# get the full precision from SPIKY
- # assert_equal(spike_dist, 0.2445)
+ assert_equal(spike_sync, 0.7183531505298066)
def test_multi_variate_subsets():