summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMario Mulansky <mario.mulansky@gmx.net>2014-09-29 16:08:45 +0200
committerMario Mulansky <mario.mulansky@gmx.net>2014-09-29 16:08:45 +0200
commitb726773a29f85d465ff71867fab4fa5b8e5bcfe1 (patch)
treee9a0add62dfc3f1a27beeaa7de000b8d7614aa72
parente4f1c09672068e4778f7b5f3e27b47ff8986863c (diff)
+ multivariate distances
-rw-r--r--pyspike/__init__.py3
-rw-r--r--pyspike/distances.py76
-rw-r--r--pyspike/function.py40
-rw-r--r--test/test_distance.py47
4 files changed, 158 insertions, 8 deletions
diff --git a/pyspike/__init__.py b/pyspike/__init__.py
index 1784037..2143bdc 100644
--- a/pyspike/__init__.py
+++ b/pyspike/__init__.py
@@ -1,5 +1,6 @@
__all__ = ["function", "distances", "spikes"]
from function import PieceWiseConstFunc, PieceWiseLinFunc
-from distances import add_auxiliary_spikes, isi_distance, spike_distance
+from distances import add_auxiliary_spikes, isi_distance, spike_distance, \
+ isi_distance_multi, spike_distance_multi
from spikes import spike_train_from_string, merge_spike_trains
diff --git a/pyspike/distances.py b/pyspike/distances.py
index f4be625..52c6640 100644
--- a/pyspike/distances.py
+++ b/pyspike/distances.py
@@ -224,3 +224,79 @@ def spike_distance(spikes1, spikes2):
# could be less than original length due to equal spike times
return PieceWiseLinFunc(spike_events[:index+1],
y_starts[:index], y_ends[:index])
+
+
+
+
+############################################################
+# multi_distance
+############################################################
def multi_distance(spike_trains, pair_distance_func, indices=None):
    """ Internal implementation detail, use isi_distance_multi or
    spike_distance_multi.

    Computes the multi-variate distance for a set of spike-trains using the
    pair_distance_func to compute pair-wise distances. That is, it computes
    the average distance of all pairs of spike-trains:
    S(t) = 2/(N(N-1)) sum_{<i,j>} S_{i,j},
    where the sum goes over all pairs <i,j>.
    Args:
    - spike_trains: list of spike trains
    - pair_distance_func: function computing the distance of two spike trains
    - indices: list of indices defining which spike trains to use,
      if None all given spike trains are used (default=None)
    Returns:
    - The averaged multi-variate distance of all pairs
    """
    if indices is None:
        indices = np.arange(len(spike_trains))
    indices = np.array(indices)
    # check validity of indices
    assert (indices < len(spike_trains)).all() and (indices >= 0).all(), \
        "Invalid index list."
    # generate a list of all unordered index pairs; enumerate positions via
    # range so the slice works for non-contiguous index lists as well
    # (slicing by the index *value* would drop pairs, e.g. for indices=[1,3])
    pairs = [(indices[k], j) for k in range(len(indices))
             for j in indices[k+1:]]
    assert len(pairs) > 0, "At least two spike trains are required."
    # start with the distance of the first pair
    (i, j) = pairs[0]
    average_dist = pair_distance_func(spike_trains[i], spike_trains[j])
    for (i, j) in pairs[1:]:
        current_dist = pair_distance_func(spike_trains[i], spike_trains[j])
        average_dist.add(current_dist)  # accumulate the pair distances
    average_dist.mul_scalar(1.0/len(pairs))  # normalize to the average
    return average_dist
+
+
+############################################################
+# isi_distance_multi
+############################################################
def isi_distance_multi(spike_trains, indices=None):
    """ Computes the multi-variate isi-distance for a set of spike trains,
    i.e. the average isi-distance over all pairs of spike trains:
    S(t) = 2/(N(N-1)) sum_{<i,j>} S_{i,j},
    where the sum runs over all pairs <i,j>.
    Args:
    - spike_trains: list of spike trains
    - indices: list of indices defining which spike trains to use,
      if None all given spike trains are used (default=None)
    Returns:
    - A PieceWiseConstFunc representing the averaged isi distance S
    """
    # delegate the pair-wise averaging to the generic implementation
    return multi_distance(spike_trains, isi_distance, indices)
+
+
+############################################################
+# spike_distance_multi
+############################################################
def spike_distance_multi(spike_trains, indices=None):
    """ Computes the multi-variate spike-distance for a set of spike trains,
    i.e. the average spike-distance over all pairs of spike trains:
    S(t) = 2/(N(N-1)) sum_{<i,j>} S_{i,j},
    where the sum runs over all pairs <i,j>.
    Args:
    - spike_trains: list of spike trains
    - indices: list of indices defining which spike trains to use,
      if None all given spike trains are used (default=None)
    Returns:
    - A PieceWiseLinFunc representing the averaged spike distance S
    """
    # delegate the pair-wise averaging to the generic implementation
    return multi_distance(spike_trains, spike_distance, indices)
diff --git a/pyspike/function.py b/pyspike/function.py
index 3a5a01c..26ca4b2 100644
--- a/pyspike/function.py
+++ b/pyspike/function.py
@@ -10,6 +10,7 @@ from __future__ import print_function
import numpy as np
+
##############################################################
# PieceWiseConstFunc
##############################################################
@@ -18,7 +19,7 @@ class PieceWiseConstFunc:
def __init__(self, x, y):
""" Constructs the piece-wise const function.
- Params:
+ Args:
- x: array of length N+1 defining the edges of the intervals of the pwc
function.
- y: array of length N defining the function values at the intervals.
@@ -26,6 +27,19 @@ class PieceWiseConstFunc:
self.x = np.array(x)
self.y = np.array(y)
+ def almost_equal(self, other, decimal=14):
+ """ Checks if the function is equal to another function up to `decimal`
+ precision.
+ Args:
+ - other: another PieceWiseConstFunc object
+ Returns:
+ True if the two functions are equal up to `decimal` decimals,
+ False otherwise
+ """
+ eps = 10.0**(-decimal)
+ return np.allclose(self.x, other.x, atol=eps, rtol=0.0) and \
+ np.allclose(self.y, other.y, atol=eps, rtol=0.0)
+
def get_plottable_data(self):
""" Returns two arrays containing x- and y-coordinates for immeditate
plotting of the piece-wise function.
@@ -63,7 +77,7 @@ class PieceWiseConstFunc:
def add(self, f):
""" Adds another PieceWiseConst function to this function.
Note: only functions defined on the same interval can be summed.
- Params:
+ Args:
- f: PieceWiseConst function to be added.
"""
assert self.x[0] == f.x[0], "The functions have different intervals"
@@ -111,7 +125,7 @@ class PieceWiseConstFunc:
def mul_scalar(self, fac):
""" Multiplies the function with a scalar value
- Params:
+ Args:
- fac: Value to multiply
"""
self.y *= fac
@@ -125,7 +139,7 @@ class PieceWiseLinFunc:
def __init__(self, x, y1, y2):
""" Constructs the piece-wise linear function.
- Params:
+ Args:
- x: array of length N+1 defining the edges of the intervals of the pwc
function.
- y1: array of length N defining the function values at the left of the
@@ -137,6 +151,20 @@ class PieceWiseLinFunc:
self.y1 = np.array(y1)
self.y2 = np.array(y2)
+ def almost_equal(self, other, decimal=14):
+ """ Checks if the function is equal to another function up to `decimal`
+ precision.
+ Args:
+ - other: another PieceWiseLinFunc object
+ Returns:
+ True if the two functions are equal up to `decimal` decimals,
+ False otherwise
+ """
+ eps = 10.0**(-decimal)
+ return np.allclose(self.x, other.x, atol=eps, rtol=0.0) and \
+ np.allclose(self.y1, other.y1, atol=eps, rtol=0.0) and \
+ np.allclose(self.y2, other.y2, atol=eps, rtol=0.0)
+
def get_plottable_data(self):
""" Returns two arrays containing x- and y-coordinates for immeditate
plotting of the piece-wise function.
@@ -171,7 +199,7 @@ class PieceWiseLinFunc:
def add(self, f):
""" Adds another PieceWiseLin function to this function.
Note: only functions defined on the same interval can be summed.
- Params:
+ Args:
- f: PieceWiseLin function to be added.
"""
assert self.x[0] == f.x[0], "The functions have different intervals"
@@ -246,7 +274,7 @@ class PieceWiseLinFunc:
def mul_scalar(self, fac):
""" Multiplies the function with a scalar value
- Params:
+ Args:
- fac: Value to multiply
"""
self.y1 *= fac
diff --git a/test/test_distance.py b/test/test_distance.py
index 35bdf85..c43f0b3 100644
--- a/test/test_distance.py
+++ b/test/test_distance.py
@@ -7,10 +7,12 @@ Copyright 2014, Mario Mulansky <mario.mulansky@gmx.net>
from __future__ import print_function
import numpy as np
+from copy import copy
from numpy.testing import assert_equal, assert_array_almost_equal
import pyspike as spk
+
def test_auxiliary_spikes():
t = np.array([0.2, 0.4, 0.6, 0.7])
t_aux = spk.add_auxiliary_spikes(t, T_end=1.0, T_start=0.1)
@@ -18,6 +20,7 @@ def test_auxiliary_spikes():
t_aux = spk.add_auxiliary_spikes(t_aux, 1.0)
assert_equal(t_aux, [0.0, 0.1, 0.2, 0.4, 0.6, 0.7, 1.0])
+
def test_isi():
# generate two spike trains:
t1 = np.array([0.2, 0.4, 0.6, 0.7])
@@ -32,7 +35,7 @@ def test_isi():
t2 = spk.add_auxiliary_spikes(t2, 1.0)
f = spk.isi_distance(t1, t2)
- print("ISI: ", f.y)
+ # print("ISI: ", f.y)
assert_equal(f.x, expected_times)
assert_array_almost_equal(f.y, expected_isi, decimal=14)
@@ -98,6 +101,48 @@ def test_spike():
assert_array_almost_equal(f.y2, expected_y2, decimal=14)
def check_multi_distance(dist_func, dist_func_multi):
    """ Checks that the multi-variate distance equals the average of the
    pair-wise distances, for a subset of trains and for all trains. """
    # generate spike trains:
    spike_times = [[0.2, 0.4, 0.6, 0.7],
                   [0.3, 0.45, 0.8, 0.9, 0.95],
                   [0.2, 0.4, 0.6],
                   [0.1, 0.4, 0.5, 0.6]]
    spike_trains = [spk.add_auxiliary_spikes(np.array(st), 1.0)
                    for st in spike_times]
    t1, t2, t3, t4 = spike_trains

    f12 = dist_func(t1, t2)
    f13 = dist_func(t1, t3)
    f14 = dist_func(t1, t4)
    f23 = dist_func(t2, t3)
    f24 = dist_func(t2, t4)
    f34 = dist_func(t3, t4)

    # two trains: multi-variate result must equal the single pair distance
    f_multi = dist_func_multi(spike_trains, [0, 1])
    assert f_multi.almost_equal(f12, decimal=14)

    # three trains: average of the three pair distances
    f = copy(f12)
    f.add(f13)
    f.add(f23)
    f.mul_scalar(1.0/3)
    f_multi = dist_func_multi(spike_trains, [0, 1, 2])
    assert f_multi.almost_equal(f, decimal=14)

    # all four trains: average over all six pairs
    f.mul_scalar(3)  # revert the previous normalization
    f.add(f14)
    f.add(f24)
    f.add(f34)
    f.mul_scalar(1.0/6)
    f_multi = dist_func_multi(spike_trains)
    assert f_multi.almost_equal(f, decimal=14)
+
+
def test_multi_isi():
    """ Runs the multi-variate consistency check for the isi distance. """
    check_multi_distance(spk.isi_distance, spk.isi_distance_multi)
+
+
def test_multi_spike():
    """ Runs the multi-variate consistency check for the spike distance. """
    check_multi_distance(spk.spike_distance, spk.spike_distance_multi)
+
+
if __name__ == "__main__":
test_auxiliary_spikes()
test_isi()