From a0262fc04e4b084f4dd270a75938d4ad029783d4 Mon Sep 17 00:00:00 2001 From: Mario Mulansky Date: Thu, 7 May 2015 23:08:12 +0200 Subject: performance improvements use recursive approach to compute average profile for average multivariate distances, dont compute average multivariate profile, but average distances directly. --- test/test_distance.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'test/test_distance.py') diff --git a/test/test_distance.py b/test/test_distance.py index 19da35f..e45ac16 100644 --- a/test/test_distance.py +++ b/test/test_distance.py @@ -196,7 +196,7 @@ def test_spike_sync(): 0.4, decimal=16) -def check_multi_profile(profile_func, profile_func_multi): +def check_multi_profile(profile_func, profile_func_multi, dist_func_multi): # generate spike trains: t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0) t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0) @@ -213,10 +213,14 @@ def check_multi_profile(profile_func, profile_func_multi): f_multi = profile_func_multi(spike_trains, [0, 1]) assert f_multi.almost_equal(f12, decimal=14) + d = dist_func_multi(spike_trains, [0, 1]) + assert_equal(f_multi.avrg(), d) f_multi1 = profile_func_multi(spike_trains, [1, 2, 3]) f_multi2 = profile_func_multi(spike_trains[1:]) assert f_multi1.almost_equal(f_multi2, decimal=14) + d = dist_func_multi(spike_trains, [1, 2, 3]) + assert_almost_equal(f_multi1.avrg(), d, decimal=14) f = copy(f12) f.add(f13) @@ -224,6 +228,8 @@ def check_multi_profile(profile_func, profile_func_multi): f.mul_scalar(1.0/3) f_multi = profile_func_multi(spike_trains, [0, 1, 2]) assert f_multi.almost_equal(f, decimal=14) + d = dist_func_multi(spike_trains, [0, 1, 2]) + assert_almost_equal(f_multi.avrg(), d, decimal=14) f.mul_scalar(3) # revert above normalization f.add(f14) @@ -235,11 +241,13 @@ def check_multi_profile(profile_func, profile_func_multi): def test_multi_isi(): - check_multi_profile(spk.isi_profile, spk.isi_profile_multi) + check_multi_profile(spk.isi_profile, spk.isi_profile_multi, + spk.isi_distance_multi) def test_multi_spike(): - check_multi_profile(spk.spike_profile, spk.spike_profile_multi) + check_multi_profile(spk.spike_profile, spk.spike_profile_multi, + spk.spike_distance_multi) def test_multi_spike_sync(): -- cgit v1.2.3