summaryrefslogtreecommitdiff
path: root/test/test_distance.py
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2014-09-29 16:08:45 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2014-09-29 16:08:45 +0200
commitb726773a29f85d465ff71867fab4fa5b8e5bcfe1 (patch)
treee9a0add62dfc3f1a27beeaa7de000b8d7614aa72 /test/test_distance.py
parente4f1c09672068e4778f7b5f3e27b47ff8986863c (diff)
+ multivariate distances
Diffstat (limited to 'test/test_distance.py')
-rw-r--r--test/test_distance.py47
1 files changed, 46 insertions, 1 deletions
diff --git a/test/test_distance.py b/test/test_distance.py
index 35bdf85..c43f0b3 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -7,10 +7,12 @@ Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>
from __future__ import print_function
import numpy as np
+from copy import copy
from numpy.testing import assert_equal, assert_array_almost_equal
import pyspike as spk
+
def test_auxiliary_spikes():
t = np.array([0.2, 0.4, 0.6, 0.7])
t_aux = spk.add_auxiliary_spikes(t, T_end=1.0, T_start=0.1)
@@ -18,6 +20,7 @@ def test_auxiliary_spikes():
t_aux = spk.add_auxiliary_spikes(t_aux, 1.0)
assert_equal(t_aux, [0.0, 0.1, 0.2, 0.4, 0.6, 0.7, 1.0])
+
def test_isi():
# generate two spike trains:
t1 = np.array([0.2, 0.4, 0.6, 0.7])
@@ -32,7 +35,7 @@ def test_isi():
t2 = spk.add_auxiliary_spikes(t2, 1.0)
f = spk.isi_distance(t1, t2)
- print("ISI: ", f.y)
+ # print("ISI: ", f.y)
assert_equal(f.x, expected_times)
assert_array_almost_equal(f.y, expected_isi, decimal=14)
@@ -98,6 +101,48 @@ def test_spike():
assert_array_almost_equal(f.y2, expected_y2, decimal=14)
+def check_multi_distance(dist_func, dist_func_multi):
+ # generate spike trains:
+ t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0)
+ t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0)
+ t3 = spk.add_auxiliary_spikes(np.array([0.2,0.4,0.6]), 1.0)
+ t4 = spk.add_auxiliary_spikes(np.array([0.1,0.4,0.5,0.6]), 1.0)
+ spike_trains = [t1, t2, t3, t4]
+
+ f12 = dist_func(t1, t2)
+ f13 = dist_func(t1, t3)
+ f14 = dist_func(t1, t4)
+ f23 = dist_func(t2, t3)
+ f24 = dist_func(t2, t4)
+ f34 = dist_func(t3, t4)
+
+ f_multi = dist_func_multi(spike_trains, [0,1])
+ assert f_multi.almost_equal(f12, decimal=14)
+
+ f = copy(f12)
+ f.add(f13)
+ f.add(f23)
+ f.mul_scalar(1.0/3)
+ f_multi = dist_func_multi(spike_trains, [0,1,2])
+ assert f_multi.almost_equal(f, decimal=14)
+
+ f.mul_scalar(3) # revert above normalization
+ f.add(f14)
+ f.add(f24)
+ f.add(f34)
+ f.mul_scalar(1.0/6)
+ f_multi = dist_func_multi(spike_trains)
+ assert f_multi.almost_equal(f, decimal=14)
+
+
+def test_multi_isi():
+ check_multi_distance(spk.isi_distance, spk.isi_distance_multi)
+
+
+def test_multi_spike():
+ check_multi_distance(spk.spike_distance, spk.spike_distance_multi)
+
+
if __name__ == "__main__":
test_auxiliary_spikes()
test_isi()