From b726773a29f85d465ff71867fab4fa5b8e5bcfe1 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Mon, 29 Sep 2014 16:08:45 +0200 Subject: + multivariate distances --- test/test_distance.py | 47 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) (limited to 'test') diff --git a/test/test_distance.py b/test/test_distance.py index 35bdf85..c43f0b3 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -7,10 +7,12 @@ Copyright 2014, Mario Mulansky from __future__ import print_function import numpy as np +from copy import copy from numpy.testing import assert_equal, assert_array_almost_equal import pyspike as spk + def test_auxiliary_spikes(): t = np.array([0.2, 0.4, 0.6, 0.7]) t_aux = spk.add_auxiliary_spikes(t, T_end=1.0, T_start=0.1) @@ -18,6 +20,7 @@ def test_auxiliary_spikes(): t_aux = spk.add_auxiliary_spikes(t_aux, 1.0) assert_equal(t_aux, [0.0, 0.1, 0.2, 0.4, 0.6, 0.7, 1.0]) + def test_isi(): # generate two spike trains: t1 = np.array([0.2, 0.4, 0.6, 0.7]) @@ -32,7 +35,7 @@ def test_isi(): t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.isi_distance(t1, t2) - print("ISI: ", f.y) + # print("ISI: ", f.y) assert_equal(f.x, expected_times) assert_array_almost_equal(f.y, expected_isi, decimal=14) @@ -98,6 +101,48 @@ def test_spike(): assert_array_almost_equal(f.y2, expected_y2, decimal=14) +def check_multi_distance(dist_func, dist_func_multi): + # generate spike trains: + t1 = spk.add_auxiliary_spikes(np.array([0.2, 0.4, 0.6, 0.7]), 1.0) + t2 = spk.add_auxiliary_spikes(np.array([0.3, 0.45, 0.8, 0.9, 0.95]), 1.0) + t3 = spk.add_auxiliary_spikes(np.array([0.2,0.4,0.6]), 1.0) + t4 = spk.add_auxiliary_spikes(np.array([0.1,0.4,0.5,0.6]), 1.0) + spike_trains = [t1, t2, t3, t4] + + f12 = dist_func(t1, t2) + f13 = dist_func(t1, t3) + f14 = dist_func(t1, t4) + f23 = dist_func(t2, t3) + f24 = dist_func(t2, t4) + f34 = dist_func(t3, t4) + + f_multi = dist_func_multi(spike_trains, [0,1]) + assert f_multi.almost_equal(f12, decimal=14) + + f = copy(f12) + f.add(f13) + f.add(f23) + f.mul_scalar(1.0/3) + f_multi = dist_func_multi(spike_trains, [0,1,2]) + assert f_multi.almost_equal(f, decimal=14) + + f.mul_scalar(3) # revert above normalization + f.add(f14) + f.add(f24) + f.add(f34) + f.mul_scalar(1.0/6) + f_multi = dist_func_multi(spike_trains) + assert f_multi.almost_equal(f, decimal=14) + + +def test_multi_isi(): + check_multi_distance(spk.isi_distance, spk.isi_distance_multi) + + +def test_multi_spike(): + check_multi_distance(spk.spike_distance, spk.spike_distance_multi) + + if __name__ == "__main__": test_auxiliary_spikes() test_isi() -- cgit v1.2.3