author    Mario Mulansky <mario.mulansky@gmx.net>  2015-05-07 23:08:12 +0200
committer Mario Mulansky <mario.mulansky@gmx.net>  2015-05-07 23:08:12 +0200
commit    a0262fc04e4b084f4dd270a75938d4ad029783d4 (patch)
tree      09eea9b88d8a8e08ae035e36eb071979c4ce73ab
parent    a7df02c0dc064dbb0586a394bb74200c7d6d67df (diff)
performance improvements
Use a recursive approach to compute the average profile; for the average multivariate distances, don't compute the average multivariate profile, but average the pairwise distances directly.
-rw-r--r--  examples/performance.py        42
-rw-r--r--  pyspike/cython/cython_add.pyx  20
-rw-r--r--  pyspike/generic.py             77
-rw-r--r--  pyspike/isi_distance.py         6
-rw-r--r--  pyspike/spike_distance.py       6
-rw-r--r--  test/test_distance.py          14
6 files changed, 139 insertions, 26 deletions
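
The recursive averaging described in the commit message splits the list of spike-train pairs in half, sums each half recursively, and merges the two partial results. Below is a standalone sketch of that idea; it is purely illustrative, with pairwise_profile standing in for the pair distance function and add() for the profile merge seen in the pyspike/generic.py diff further down. The presumed benefit is that the two profiles being merged stay of comparable size, instead of one accumulator growing with every addition.

# Illustrative divide-and-conquer sketch (not PySpike code); pairwise_profile(i, j)
# is assumed to return a profile object with an add() method, as in the diff below.
def sum_profiles(pairs, pairwise_profile):
    if len(pairs) == 1:
        i, j = pairs[0]
        return pairwise_profile(i, j)
    mid = len(pairs) // 2
    left = sum_profiles(pairs[:mid], pairwise_profile)
    right = sum_profiles(pairs[mid:], pairwise_profile)
    left.add(right)  # merge two partial sums of comparable size
    return left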
diff --git a/examples/performance.py b/examples/performance.py
new file mode 100644
index 0000000..469b5ab
--- /dev/null
+++ b/examples/performance.py
@@ -0,0 +1,42 @@
+""" Compute distances of large sets of spike trains for performance tests
+"""
+
+from __future__ import print_function
+
+import pyspike as spk
+from datetime import datetime
+import cProfile
+
+M = 100 # number of spike trains
+r = 1.0 # rate of Poisson spike times
+T = 1E3 # length of spike trains
+
+print("%d spike trains with %d spikes" % (M, int(r*T)))
+
+spike_trains = []
+
+t_start = datetime.now()
+for i in xrange(M):
+ spike_trains.append(spk.generate_poisson_spikes(r, T))
+t_end = datetime.now()
+runtime = (t_end-t_start).total_seconds()
+
+print("Spike generation runtime: %.3fs" % runtime)
+
+print("================ ISI COMPUTATIONS ================")
+print(" MULTIVARIATE DISTANCE")
+cProfile.run('spk.isi_distance_multi(spike_trains)')
+print(" MULTIVARIATE PROFILE")
+cProfile.run('spk.isi_profile_multi(spike_trains)')
+
+print("================ SPIKE COMPUTATIONS ================")
+print(" MULTIVARIATE DISTANCE")
+cProfile.run('spk.spike_distance_multi(spike_trains)')
+print(" MULTIVARIATE PROFILE")
+cProfile.run('spk.spike_profile_multi(spike_trains)')
+
+print("================ SPIKE-SYNC COMPUTATIONS ================")
+print(" MULTIVARIATE DISTANCE")
+cProfile.run('spk.spike_sync_multi(spike_trains)')
+print(" MULTIVARIATE PROFILE")
+cProfile.run('spk.spike_sync_profile_multi(spike_trains)')
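
The script above uses xrange and is therefore Python 2 only; a Python 3 compatible variant of the generation and timing part might look like this (same library calls as above, nothing new assumed):

# Python 3 variant of the spike-train generation timing above (sketch only).
from datetime import datetime
import pyspike as spk

M, r, T = 100, 1.0, 1E3
t_start = datetime.now()
spike_trains = [spk.generate_poisson_spikes(r, T) for _ in range(M)]
runtime = (datetime.now() - t_start).total_seconds()
print("Spike generation runtime: %.3fs" % runtime)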
diff --git a/pyspike/cython/cython_add.pyx b/pyspike/cython/cython_add.pyx
index ac64005..8da1e53 100644
--- a/pyspike/cython/cython_add.pyx
+++ b/pyspike/cython/cython_add.pyx
@@ -83,13 +83,9 @@ def add_piece_wise_const_cython(double[:] x1, double[:] y1,
else: # both arrays reached the end simultaneously
# only the last x-value missing
x_new[index+1] = x1[N1-1]
- # the last value is again the end of the interval
- # x_new[index+1] = x1[-1]
- # only use the data that was actually filled
- x1 = x_new[:index+2]
- y1 = y_new[:index+1]
# end nogil
- return np.array(x_new[:index+2]), np.array(y_new[:index+1])
+ # return np.asarray(x_new[:index+2]), np.asarray(y_new[:index+1])
+ return np.asarray(x_new[:index+2]), np.asarray(y_new[:index+1])
############################################################
@@ -169,9 +165,9 @@ def add_piece_wise_lin_cython(double[:] x1, double[:] y11, double[:] y12,
y2_new[index] = y12[N1-2]+y22[N2-2]
# only use the data that was actually filled
# end nogil
- return (np.array(x_new[:index+2]),
- np.array(y1_new[:index+1]),
- np.array(y2_new[:index+1]))
+ return (np.asarray(x_new[:index+2]),
+ np.asarray(y1_new[:index+1]),
+ np.asarray(y2_new[:index+1]))
############################################################
@@ -230,6 +226,6 @@ def add_discrete_function_cython(double[:] x1, double[:] y1, double[:] mp1,
# the last value is again the end of the interval
# only use the data that was actually filled
- return (np.array(x_new[:index+1]),
- np.array(y_new[:index+1]),
- np.array(mp_new[:index+1]))
+ return (np.asarray(x_new[:index+1]),
+ np.asarray(y_new[:index+1]),
+ np.asarray(mp_new[:index+1]))
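
The np.array to np.asarray change above looks like a copy-avoidance fix: the returned slices are Cython typed memoryviews, which np.asarray wraps in place while np.array copies into a fresh buffer. A rough pure-Python illustration, using a plain memoryview as a stand-in for the typed memoryview:

import numpy as np

buf = np.zeros(10)
mv = memoryview(buf)   # stand-in for a Cython typed memoryview over buf
a = np.asarray(mv)     # wraps the existing buffer, no copy
b = np.array(mv)       # always allocates and copies

a[0] = 1.0
print(buf[0])          # 1.0 -> a shares memory with buf
b[1] = 2.0
print(buf[1])          # 0.0 -> b is an independent copy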
diff --git a/pyspike/generic.py b/pyspike/generic.py
index 4f278d2..41affcb 100644
--- a/pyspike/generic.py
+++ b/pyspike/generic.py
@@ -31,6 +31,69 @@ def _generic_profile_multi(spike_trains, pair_distance_func, indices=None):
Returns:
- The averaged multi-variate distance of all pairs
"""
+
+ def divide_and_conquer(pairs1, pairs2):
+ """ recursive calls by splitting the two lists in half.
+ """
+ L1 = len(pairs1)
+ if L1 > 1:
+ dist_prof1 = divide_and_conquer(pairs1[:L1/2], pairs1[L1/2:])
+ else:
+ dist_prof1 = pair_distance_func(spike_trains[pairs1[0][0]],
+ spike_trains[pairs1[0][1]])
+ L2 = len(pairs2)
+ if L2 > 1:
+ dist_prof2 = divide_and_conquer(pairs2[:L2/2], pairs2[L2/2:])
+ else:
+ dist_prof2 = pair_distance_func(spike_trains[pairs2[0][0]],
+ spike_trains[pairs2[0][1]])
+ dist_prof1.add(dist_prof2)
+ return dist_prof1
+
+ if indices is None:
+ indices = np.arange(len(spike_trains))
+ indices = np.array(indices)
+ # check validity of indices
+ assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
+ "Invalid index list."
+ # generate a list of possible index pairs
+ pairs = [(indices[i], j) for i in range(len(indices))
+ for j in indices[i+1:]]
+
+ L = len(pairs)
+ if L > 1:
+ # recursive iteration through the list of pairs to get average profile
+ avrg_dist = divide_and_conquer(pairs[:len(pairs)/2],
+ pairs[len(pairs)/2:])
+ else:
+ avrg_dist = pair_distance_func(spike_trains[pairs[0][0]],
+ spike_trains[pairs[0][1]])
+
+ return avrg_dist, L
+
+
+############################################################
+# _generic_distance_multi
+############################################################
+def _generic_distance_multi(spike_trains, pair_distance_func,
+ indices=None, interval=None):
+ """ Internal implementation detail, don't call this function directly,
+ use isi_distance_multi or spike_distance_multi instead.
+
+ Computes the multi-variate distance for a set of spike-trains using the
+ pair_distance_func to compute pair-wise distances. That is, it computes the
+ average distance of all pairs of spike-trains:
+ :math:`D = 2/(N(N-1)) sum_{<i,j>} D_{i,j}`,
+ where the sum goes over all pairs <i,j>.
+ Args:
+ - spike_trains: list of spike trains
+ - pair_distance_func: function computing the distance of two spike trains
+ - indices: list of indices defining which spike trains to use,
+ if None all given spike trains are used (default=None)
+ Returns:
+ - The averaged multi-variate distance of all pairs
+ """
+
if indices is None:
indices = np.arange(len(spike_trains))
indices = np.array(indices)
@@ -40,13 +103,13 @@ def _generic_profile_multi(spike_trains, pair_distance_func, indices=None):
# generate a list of possible index pairs
pairs = [(indices[i], j) for i in range(len(indices))
for j in indices[i+1:]]
- # start with first pair
- (i, j) = pairs[0]
- average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
- for (i, j) in pairs[1:]:
- current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
- average_dist.add(current_dist) # add to the average
- return average_dist, len(pairs)
+
+ avrg_dist = 0.0
+ for (i, j) in pairs:
+ avrg_dist += pair_distance_func(spike_trains[i], spike_trains[j],
+ interval)
+
+ return avrg_dist/len(pairs)
############################################################
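
The shortcut taken by _generic_distance_multi gives the same number as averaging the multivariate profile first because averaging over time and averaging over pairs commute. A small numerical check of that commutativity on plain discretized arrays (illustrative only, no PySpike involved):

import numpy as np

# profiles[k] plays the role of a discretized pair profile D_{i,j}(t):
# averaging over pairs first and over time second equals the reverse order.
profiles = np.random.rand(6, 1000)            # 6 pairs, 1000 time samples
time_avg_of_mean_profile = profiles.mean(axis=0).mean()
mean_of_time_avgs = profiles.mean(axis=1).mean()
assert np.allclose(time_avg_of_mean_profile, mean_of_time_avgs)

For the piecewise-constant and piecewise-linear profiles in PySpike the same argument applies, since the time average is an integral and integration is linear.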
diff --git a/pyspike/isi_distance.py b/pyspike/isi_distance.py
index aeab0df..21df561 100644
--- a/pyspike/isi_distance.py
+++ b/pyspike/isi_distance.py
@@ -3,7 +3,8 @@
# Distributed under the BSD License
from pyspike import PieceWiseConstFunc
-from pyspike.generic import _generic_profile_multi, _generic_distance_matrix
+from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \
+ _generic_distance_matrix
############################################################
@@ -112,7 +113,8 @@ def isi_distance_multi(spike_trains, indices=None, interval=None):
:returns: The time-averaged multivariate ISI distance :math:`D_I`
:rtype: double
"""
- return isi_profile_multi(spike_trains, indices).avrg(interval)
+ return _generic_distance_multi(spike_trains, isi_distance, indices,
+ interval)
############################################################
diff --git a/pyspike/spike_distance.py b/pyspike/spike_distance.py
index cc620d4..75b3b0e 100644
--- a/pyspike/spike_distance.py
+++ b/pyspike/spike_distance.py
@@ -3,7 +3,8 @@
# Distributed under the BSD License
from pyspike import PieceWiseLinFunc
-from pyspike.generic import _generic_profile_multi, _generic_distance_matrix
+from pyspike.generic import _generic_profile_multi, _generic_distance_multi, \
+ _generic_distance_matrix
############################################################
@@ -117,7 +118,8 @@ def spike_distance_multi(spike_trains, indices=None, interval=None):
:returns: The averaged multi-variate spike distance :math:`D_S`.
:rtype: double
"""
- return spike_profile_multi(spike_trains, indices).avrg(interval)
+ return _generic_distance_multi(spike_trains, spike_distance, indices,
+ interval)
############################################################
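
With the two changes above, isi_distance_multi and spike_distance_multi average the pairwise distances directly instead of averaging the multivariate profile. The call signatures are unchanged, so usage stays as before; a short sketch, with the spike-train generation borrowed from examples/performance.py above:

import pyspike as spk

spike_trains = [spk.generate_poisson_spikes(1.0, 1000.0) for _ in range(10)]

d_isi = spk.isi_distance_multi(spike_trains)      # averages pairwise ISI distances directly
d_spike = spk.spike_distance_multi(spike_trains)  # same for the SPIKE distance
# As the updated tests below assert, these agree (up to rounding) with
# spk.isi_profile_multi(spike_trains).avrg() and
# spk.spike_profile_multi(spike_trains).avrg().
print(d_isi, d_spike)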
diff --git a/test/test_distance.py b/test/test_distance.py
index 19da35f..e45ac16 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -196,7 +196,7 @@ def test_spike_sync():
0.4, decimal=16)
-def check_multi_profile(profile_func, profile_func_multi):
+def check_multi_profile(profile_func, profile_func_multi, dist_func_multi):
# generate spike trains:
t1 = SpikeTrain([0.2, 0.4, 0.6, 0.7], 1.0)
t2 = SpikeTrain([0.3, 0.45, 0.8, 0.9, 0.95], 1.0)
@@ -213,10 +213,14 @@ def check_multi_profile(profile_func, profile_func_multi):
f_multi = profile_func_multi(spike_trains, [0, 1])
assert f_multi.almost_equal(f12, decimal=14)
+ d = dist_func_multi(spike_trains, [0, 1])
+ assert_equal(f_multi.avrg(), d)
f_multi1 = profile_func_multi(spike_trains, [1, 2, 3])
f_multi2 = profile_func_multi(spike_trains[1:])
assert f_multi1.almost_equal(f_multi2, decimal=14)
+ d = dist_func_multi(spike_trains, [1, 2, 3])
+ assert_almost_equal(f_multi1.avrg(), d, decimal=14)
f = copy(f12)
f.add(f13)
@@ -224,6 +228,8 @@ def check_multi_profile(profile_func, profile_func_multi):
f.mul_scalar(1.0/3)
f_multi = profile_func_multi(spike_trains, [0, 1, 2])
assert f_multi.almost_equal(f, decimal=14)
+ d = dist_func_multi(spike_trains, [0, 1, 2])
+ assert_almost_equal(f_multi.avrg(), d, decimal=14)
f.mul_scalar(3) # revert above normalization
f.add(f14)
@@ -235,11 +241,13 @@ def check_multi_profile(profile_func, profile_func_multi):
def test_multi_isi():
- check_multi_profile(spk.isi_profile, spk.isi_profile_multi)
+ check_multi_profile(spk.isi_profile, spk.isi_profile_multi,
+ spk.isi_distance_multi)
def test_multi_spike():
- check_multi_profile(spk.spike_profile, spk.spike_profile_multi)
+ check_multi_profile(spk.spike_profile, spk.spike_profile_multi,
+ spk.spike_distance_multi)
def test_multi_spike_sync():