summaryrefslogtreecommitdiff
path: root/test/test_sync_filter.py
blob: e25990334482312b72ed496cac9e765e290fddd0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
""" test_sync_filter.py

Tests the spike sync based filtering

Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>

Distributed under the BSD License

"""

from __future__ import print_function
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal, \
    assert_array_almost_equal

import pyspike as spk
from pyspike import SpikeTrain


def test_single_prof():
    """Check the single-train coincidence backend against the sync profile.

    For several pairs of spike trains the per-spike values returned by the
    coincidence implementation are compared with the ``y`` values of the
    profile from ``spk.spike_sync_profile`` at the matching spike times.
    In the final case the profile value is additionally divided by the
    profile's ``mp`` entry at that time before comparing.
    """
    # Use the compiled backend when it is importable, otherwise the pure
    # Python one; both are called with the same arguments below.
    try:
        from pyspike.cython.cython_profiles import \
            coincidence_single_profile_cython as coincidence_impl
    except ImportError:
        from pyspike.cython.python_backend import \
            coincidence_single_python as coincidence_impl

    edge = 5.0  # edge value handed to every SpikeTrain in this test

    def single_coincidences(reference, other):
        # NOTE(review): the trailing arguments (0, 5.0, 0.0) are presumably
        # start time, end time and max_tau -- confirm against the backend.
        return np.array(coincidence_impl(reference, other, 0, 5.0, 0.0))

    def check_against_profile(values, spike_times, profile):
        # Each spike of the reference train must carry the profile value
        # found at exactly that time.
        for idx, spike_time in enumerate(spike_times):
            at_spike = profile.x == spike_time
            assert_equal(values[idx], profile.y[at_spike],
                         "At index %d" % idx)

    spikes_a = np.array([1.0, 2.0, 3.0, 4.0])
    spikes_b = np.array([1.1, 2.1, 3.8])
    spikes_c = np.array([0.9, 3.1, 4.1])

    # Trains a and b, checked from both points of view.
    profile = spk.spike_sync_profile(SpikeTrain(spikes_a, edge),
                                     SpikeTrain(spikes_b, edge))

    values = single_coincidences(spikes_a, spikes_b)
    print(values)
    check_against_profile(values, spikes_a, profile)

    values = single_coincidences(spikes_b, spikes_a)
    check_against_profile(values, spikes_b, profile)

    # Trains a and c, checked from the point of view of a only.
    profile = spk.spike_sync_profile(SpikeTrain(spikes_a, edge),
                                     SpikeTrain(spikes_c, edge))

    values = single_coincidences(spikes_a, spikes_c)
    check_against_profile(values, spikes_a, profile)

    # Spikes of the second train fall exactly on spikes of the first one.
    spikes_a = np.array([1.0, 2.0, 3.0, 4.0])
    spikes_b = np.array([1.0, 2.0, 4.0])

    profile = spk.spike_sync_profile(SpikeTrain(spikes_a, edge),
                                     SpikeTrain(spikes_b, edge))

    values = single_coincidences(spikes_a, spikes_b)
    for idx, spike_time in enumerate(spikes_a):
        at_spike = profile.x == spike_time
        # The profile value is normalised by ``mp`` here (presumably the
        # multiplicity of the profile point -- confirm in pyspike).
        expected = profile.y[at_spike] / profile.mp[at_spike]
        assert_equal(values[idx], expected,
                     "At index %d" % idx)


def test_filter():
    """Check ``spk.filter_by_spike_sync`` on three short spike trains.

    The three trains are filtered together with a threshold argument of
    0.75 and the remaining spikes of each train are compared with the
    expected values.
    """
    spike_times = (
        [1.0, 2.0, 3.0, 4.0],
        [1.1, 2.1, 3.8],
        [0.9, 3.1, 4.1],
    )
    trains = [SpikeTrain(np.array(times), 5.0) for times in spike_times]

    # Two-train checks, currently disabled (reason not recorded here):
    #
    # filtered_spike_trains = spk.filter_by_spike_sync([st1, st2], 0.5)
    # assert_equal(filtered_spike_trains[0].spikes, [1.0, 2.0, 4.0])
    # assert_equal(filtered_spike_trains[1].spikes, [1.1, 2.1, 3.8])
    #
    # filtered_spike_trains = spk.filter_by_spike_sync([st2, st1], 0.5)
    # assert_equal(filtered_spike_trains[0].spikes, [1.1, 2.1, 3.8])
    # assert_equal(filtered_spike_trains[1].spikes, [1.0, 2.0, 4.0])

    filtered = spk.filter_by_spike_sync(trains, 0.75)

    # Show what survived the filter before asserting on it.
    for train in filtered:
        print(train.spikes)

    expected_spikes = (
        [1.0, 4.0],
        [1.1, 3.8],
        [0.9, 4.1],
    )
    for index, expected in enumerate(expected_spikes):
        assert_equal(filtered[index].spikes, expected)


# Run both tests when this file is executed directly as a script.  A file
# run this way has ``__name__`` set to "__main__"; comparing against "main"
# never matches, so the tests would silently not run.
if __name__ == "__main__":
    test_single_prof()
    test_filter()