From 165b1e7c707324115ab2d21a78716ea2e243fc60 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Thu, 16 Oct 2014 13:00:00 +0200 Subject: added distance tests --- test/test_distance.py | 46 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/test/test_distance.py b/test/test_distance.py index 81ffe09..2a6bf4e 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -11,7 +11,8 @@ Distributed under the BSD License from __future__ import print_function import numpy as np from copy import copy -from numpy.testing import assert_equal, assert_array_almost_equal +from numpy.testing import assert_equal, assert_almost_equal, \ + assert_array_almost_equal import pyspike as spk @@ -25,7 +26,12 @@ def test_isi(): expected_times = [0.0, 0.2, 0.3, 0.4, 0.45, 0.6, 0.7, 0.8, 0.9, 0.95, 1.0] expected_isi = [0.1/0.3, 0.1/0.3, 0.05/0.2, 0.05/0.2, 0.15/0.35, 0.25/0.35, 0.05/0.35, 0.2/0.3, 0.25/0.3, 0.25/0.3] - + expected_times = np.array(expected_times) + expected_isi = np.array(expected_isi) + + expected_isi_val = sum((expected_times[1:] - expected_times[:-1]) * + expected_isi)/(expected_times[-1]-expected_times[0]) + t1 = spk.add_auxiliary_spikes(t1, 1.0) t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.isi_profile(t1, t2) @@ -33,7 +39,9 @@ def test_isi(): # print("ISI: ", f.y) assert_equal(f.x, expected_times) - assert_array_almost_equal(f.y, expected_isi, decimal=14) + assert_array_almost_equal(f.y, expected_isi, decimal=15) + assert_equal(f.avrg(), expected_isi_val) + assert_equal(spk.isi_distance(t1, t2), expected_isi_val) # check with some equal spike times t1 = np.array([0.2, 0.4, 0.6]) @@ -41,13 +49,20 @@ def test_isi(): expected_times = [0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 1.0] expected_isi = [0.1/0.2, 0.1/0.3, 0.1/0.3, 0.1/0.2, 0.1/0.2, 0.0/0.5] + expected_times = np.array(expected_times) + expected_isi = np.array(expected_isi) + + expected_isi_val = sum((expected_times[1:] - expected_times[:-1]) * + expected_isi)/(expected_times[-1]-expected_times[0]) t1 = spk.add_auxiliary_spikes(t1, 1.0) t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.isi_profile(t1, t2) assert_equal(f.x, expected_times) - assert_array_almost_equal(f.y, expected_isi, decimal=14) + assert_array_almost_equal(f.y, expected_isi, decimal=15) + assert_equal(f.avrg(), expected_isi_val) + assert_equal(spk.isi_distance(t1, t2), expected_isi_val) def test_spike(): @@ -67,13 +82,22 @@ def test_spike(): expected_y1 = (s1[:-1]*isi2+s2[:-1]*isi1) / (0.5*(isi1+isi2)**2) expected_y2 = (s1[1:]*isi2+s2[1:]*isi1) / (0.5*(isi1+isi2)**2) + expected_times = np.array(expected_times) + expected_y1 = np.array(expected_y1) + expected_y2 = np.array(expected_y2) + expected_spike_val = sum((expected_times[1:] - expected_times[:-1]) * + (expected_y1+expected_y2)/2) + expected_spike_val /= (expected_times[-1]-expected_times[0]) + t1 = spk.add_auxiliary_spikes(t1, 1.0) t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.spike_profile(t1, t2) assert_equal(f.x, expected_times) - assert_array_almost_equal(f.y1, expected_y1, decimal=14) - assert_array_almost_equal(f.y2, expected_y2, decimal=14) + assert_array_almost_equal(f.y1, expected_y1, decimal=15) + assert_array_almost_equal(f.y2, expected_y2, decimal=15) + assert_equal(f.avrg(), expected_spike_val) + assert_equal(spk.spike_distance(t1, t2), expected_spike_val) # check with some equal spike times t1 = np.array([0.2, 0.4, 0.6]) @@ -87,6 +111,13 @@ def test_spike(): expected_y1 = (s1[:-1]*isi2+s2[:-1]*isi1) / (0.5*(isi1+isi2)**2) expected_y2 = (s1[1:]*isi2+s2[1:]*isi1) / (0.5*(isi1+isi2)**2) + expected_times = np.array(expected_times) + expected_y1 = np.array(expected_y1) + expected_y2 = np.array(expected_y2) + expected_spike_val = sum((expected_times[1:] - expected_times[:-1]) * + (expected_y1+expected_y2)/2) + expected_spike_val /= (expected_times[-1]-expected_times[0]) + t1 = spk.add_auxiliary_spikes(t1, 1.0) t2 = spk.add_auxiliary_spikes(t2, 1.0) f = spk.spike_profile(t1, t2) @@ -94,6 +125,9 @@ def test_spike(): assert_equal(f.x, expected_times) assert_array_almost_equal(f.y1, expected_y1, decimal=14) assert_array_almost_equal(f.y2, expected_y2, decimal=14) + assert_almost_equal(f.avrg(), expected_spike_val, decimal=16) + assert_almost_equal(spk.spike_distance(t1, t2), expected_spike_val, + decimal=16) def check_multi_distance(dist_func, dist_func_multi): -- cgit v1.2.3