""" regression benchmark

Copyright 2015, Mario Mulansky <mario.mulansky@gmx.net>

Distributed under the BSD License
"""
from __future__ import print_function

import numpy as np
from scipy.io import loadmat
import pyspike as spk

from numpy.testing import assert_almost_equal

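# suppress the warning about falling back to the pure Python backend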
spk.disable_backend_warning = True


def test_regression_random():
    """Compare PySpike's multivariate ISI and SPIKE distances against
    reference results computed with cSPIKY.
    """
    spike_file = "test/numeric/regression_random_spikes.mat"
    spikes_name = "spikes"
    result_name = "Distances"
    result_file = "test/numeric/regression_random_results_cSPIKY.mat"

    spike_train_sets = loadmat(spike_file)[spikes_name][0]
    results_cSPIKY = loadmat(result_file)[result_name]

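    # each dataset is a cell array of spike trains on the interval [0, 100]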
    for i, spike_train_data in enumerate(spike_train_sets):
        spike_trains = []
        for spikes in spike_train_data[0]:
            spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))

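        # compute the multivariate distances both directly and via the
        # time-averaged profiles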
        isi = spk.isi_distance_multi(spike_trains)
        isi_prof = spk.isi_profile_multi(spike_trains).avrg()

        spike = spk.spike_distance_multi(spike_trains)
        spike_prof = spk.spike_profile_multi(spike_trains).avrg()
        # spike_sync = spk.spike_sync_multi(spike_trains)

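        # both code paths must match the cSPIKY reference
        # (column 0: ISI, column 1: SPIKE)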
        assert_almost_equal(isi, results_cSPIKY[i][0], decimal=14,
                            err_msg="Index: %d, ISI" % i)
        assert_almost_equal(isi_prof, results_cSPIKY[i][0], decimal=14,
                            err_msg="Index: %d, ISI" % i)

        assert_almost_equal(spike, results_cSPIKY[i][1], decimal=14,
                            err_msg="Index: %d, SPIKE" % i)
        assert_almost_equal(spike_prof, results_cSPIKY[i][1], decimal=14,
                            err_msg="Index: %d, SPIKE" % i)


def check_regression_dataset(spike_file="benchmark.mat",
                             spikes_name="spikes",
                             result_file="results_cSPIKY.mat",
                             result_name="Distances"):
    """ Debuging function """
    np.set_printoptions(precision=15)

    spike_train_sets = loadmat(spike_file)[spikes_name][0]

    results_cSPIKY = loadmat(result_file)[result_name]

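    # track the largest SPIKE-distance deviation and count the datasets
    # exceeding the tolerance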
    err_max = 0.0
    err_max_ind = -1
    err_count = 0

    for i, spike_train_data in enumerate(spike_train_sets):
        spike_trains = []
        for spikes in spike_train_data[0]:
            spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))

        isi = spk.isi_distance_multi(spike_trains)
        spike = spk.spike_distance_multi(spike_trains)
        # spike_sync = spk.spike_sync_multi(spike_trains)

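        # print ISI mismatches immediately, including the offending spike trains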
        if abs(isi - results_cSPIKY[i][0]) > 1E-14:
            print("Error in ISI:", i, isi, results_cSPIKY[i][0])
            print("Spike trains:")
            for st in spike_trains:
                print(st.spikes)

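        # for SPIKE, record the deviations and report a summary afterwards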
        err = abs(spike - results_cSPIKY[i][1])
        if err > 1E-14:
            err_count += 1
        if err > err_max:
            err_max = err
            err_max_ind = i

    print("Total Errors:", err_count)

    if err_max_ind > -1:
        print("Max SPIKE distance error:", err_max, "at index:", err_max_ind)
        spike_train_data = spike_train_sets[err_max_ind]
        for spikes in spike_train_data[0]:
            print(spikes.flatten())


def check_single_spike_train_set(index):
    """ Debuging function """
    np.set_printoptions(precision=15)
    spike_file = "regression_random_spikes.mat"
    spikes_name = "spikes"
    result_name = "Distances"
    result_file = "regression_random_results_cSPIKY.mat"

    spike_train_sets = loadmat(spike_file)[spikes_name][0]

    results_cSPIKY = loadmat(result_file)[result_name]

    spike_train_data = spike_train_sets[index]

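    # rebuild the spike trains of the selected dataset, printing each one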
    spike_trains = []
    for spikes in spike_train_data[0]:
        print("Spikes:", spikes.flatten())
        spike_trains.append(spk.SpikeTrain(spikes.flatten(), 100.0))

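    # compare PySpike's SPIKE distance with the stored cSPIKY reference value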
    print(spk.spike_distance_multi(spike_trains))

    print(results_cSPIKY[index][1])

    print(spike_trains[1].spikes)


if __name__ == "__main__":

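    # run the regression test; the check_* helpers are for manual debugging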
    test_regression_random()
    # check_regression_dataset()
    # check_single_spike_train_set(7633)