Skip to content

Commit 68b3243

Browse files
authored
BinnedSpikeTrain optional CSC format (#402)
* BinnedSpikeTrain optional CSC format * BinnedSpikeTrain.binarize() 'copy' arg is back
1 parent 5e95f77 commit 68b3243

File tree

3 files changed

+130
-83
lines changed

3 files changed

+130
-83
lines changed

elephant/conversion.py

Lines changed: 75 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ class BinnedSpikeTrain(object):
274274
Tolerance for rounding errors in the binning process and in the input
275275
data
276276
Default: 1e-8
277+
sparse_format : {'csr', 'csc'}, optional
278+
The sparse matrix format. By default, CSR format is used to perform
279+
slicing and computations efficiently.
280+
Default: 'csr'
277281
278282
Raises
279283
------
@@ -323,7 +327,11 @@ class BinnedSpikeTrain(object):
323327

324328
@deprecated_alias(binsize='bin_size', num_bins='n_bins')
325329
def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None,
326-
t_stop=None, tolerance=1e-8):
330+
t_stop=None, tolerance=1e-8, sparse_format="csr"):
331+
if sparse_format not in ("csr", "csc"):
332+
raise ValueError(f"Invalid 'sparse_format': {sparse_format}. "
333+
"Available: 'csr' and 'csc'")
334+
327335
# Converting spiketrains to a list, if spiketrains is one
328336
# SpikeTrain object
329337
if isinstance(spiketrains, neo.SpikeTrain):
@@ -339,7 +347,8 @@ def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None,
339347
# Check all parameter, set also missing values
340348
self._resolve_input_parameters(spiketrains)
341349
# Now create the sparse matrix
342-
self.sparse_matrix = self._create_sparse_matrix(spiketrains)
350+
self.sparse_matrix = self._create_sparse_matrix(
351+
spiketrains, sparse_format=sparse_format)
343352

344353
@property
345354
def shape(self):
@@ -369,13 +378,10 @@ def num_bins(self):
369378
return self.n_bins
370379

371380
def __repr__(self):
372-
return "{klass}(t_start={t_start}, t_stop={t_stop}, " \
373-
"bin_size={bin_size}; shape={shape})".format(
374-
klass=type(self).__name__,
375-
t_start=self.t_start,
376-
t_stop=self.t_stop,
377-
bin_size=self.bin_size,
378-
shape=self.shape)
381+
return f"{type(self).__name__}(t_start={self.t_start}, " \
382+
f"t_stop={self.t_stop}, bin_size={self.bin_size}; " \
383+
f"shape={self.shape}, " \
384+
f"format={self.sparse_matrix.__class__.__name__})"
379385

380386
def rescale(self, units):
381387
"""
@@ -590,7 +596,7 @@ def to_sparse_array(self):
590596
591597
Returns
592598
-------
593-
scipy.sparse.csr_matrix
599+
scipy.sparse.csr_matrix or scipy.sparse.csc_matrix
594600
Sparse matrix, version with spike counts.
595601
596602
See also
@@ -611,7 +617,7 @@ def to_sparse_bool_array(self):
611617
612618
Returns
613619
-------
614-
scipy.sparse.csr_matrix
620+
scipy.sparse.csr_matrix or scipy.sparse.csc_matrix
615621
Sparse matrix, binary, boolean version.
616622
617623
See also
@@ -638,7 +644,8 @@ def __eq__(self, other):
638644
return False
639645
sp1 = self.sparse_matrix
640646
sp2 = other.sparse_matrix
641-
if sp1.shape != sp2.shape or sp1.data.shape != sp2.data.shape:
647+
if sp1.__class__ is not sp2.__class__ or sp1.shape != sp2.shape \
648+
or sp1.data.shape != sp2.data.shape:
642649
return False
643650
return (sp1.data == sp2.data).all() and \
644651
(sp1.indptr == sp2.indptr).all() and \
@@ -662,11 +669,18 @@ def copy(self):
662669
tolerance=self.tolerance)
663670

664671
def __iter_sparse_matrix(self):
672+
spmat = self.sparse_matrix
673+
if isinstance(spmat, sps.csc_matrix):
674+
warnings.warn("The sparse matrix format is CSC. For better "
675+
"performance, specify the CSR format while "
676+
"constructing a "
677+
"BinnedSpikeTrain(sparse_format='csr')")
678+
spmat = spmat.tocsr()
665679
# taken from csr_matrix.__iter__()
666680
i0 = 0
667-
for i1 in self.sparse_matrix.indptr[1:]:
668-
indices = self.sparse_matrix.indices[i0:i1]
669-
data = self.sparse_matrix.data[i0:i1]
681+
for i1 in spmat.indptr[1:]:
682+
indices = spmat.indices[i0:i1]
683+
data = spmat.data[i0:i1]
670684
yield indices, data
671685
i0 = i1
672686

@@ -1000,45 +1014,51 @@ def to_array(self, dtype=None):
10001014
scipy.sparse.csr_matrix.toarray
10011015
10021016
"""
1003-
spmat = self.sparse_matrix
1004-
if dtype is not None and dtype != spmat.data.dtype:
1005-
# avoid a copy
1006-
spmat = sps.csr_matrix(
1007-
(spmat.data.astype(dtype), spmat.indices, spmat.indptr),
1008-
shape=spmat.shape)
1009-
return spmat.toarray()
1010-
1011-
def binarize(self, copy=None):
1017+
array = self.sparse_matrix.toarray()
1018+
if dtype is not None:
1019+
array = array.astype(dtype)
1020+
return array
1021+
1022+
def binarize(self, copy=True):
10121023
"""
10131024
Clip the internal array (no. of spikes in a bin) to `0` (no spikes) or
10141025
`1` (at least one spike) values only.
10151026
10161027
Parameters
10171028
----------
10181029
copy : bool, optional
1019-
Deprecated parameter. It has no effect.
1030+
If True, a **shallow** copy - a view of `BinnedSpikeTrain` - is
1031+
returned with the data array filled with zeros and ones. Otherwise,
1032+
the binarization (clipping) is done in-place. A shallow copy
1033+
means that :attr:`indices` and :attr:`indptr` of a sparse matrix
1034+
is shared with the original sparse matrix. Only the data is copied.
1035+
If you want to perform a deep copy, call
1036+
:func:`BinnedSpikeTrain.copy` prior to binarizing.
1037+
Default: True
10201038
10211039
Returns
10221040
-------
1023-
bst : BinnedSpikeTrainView
1024-
A view of `BinnedSpikeTrain` with a sparse matrix containing
1025-
data clipped to `0`s and `1`s.
1041+
bst : BinnedSpikeTrain or BinnedSpikeTrainView
1042+
A (view of) `BinnedSpikeTrain` with the sparse matrix data clipped
1043+
to zeros and ones.
10261044
10271045
"""
1028-
if copy is not None:
1029-
warnings.warn("'copy' parameter is deprecated - a view is always "
1030-
"returned; set this parameter to None.",
1031-
DeprecationWarning)
10321046
spmat = self.sparse_matrix
1033-
spmat = sps.csr_matrix(
1034-
(spmat.data.clip(max=1), spmat.indices, spmat.indptr),
1035-
shape=spmat.shape, copy=False)
1036-
bst = BinnedSpikeTrainView(t_start=self._t_start,
1037-
t_stop=self._t_stop,
1038-
bin_size=self._bin_size,
1039-
units=self.units,
1040-
sparse_matrix=spmat,
1041-
tolerance=self.tolerance)
1047+
if copy:
1048+
data = np.ones(len(spmat.data), dtype=spmat.data.dtype)
1049+
spmat = spmat.__class__(
1050+
(data, spmat.indices, spmat.indptr),
1051+
shape=spmat.shape, copy=False)
1052+
bst = BinnedSpikeTrainView(t_start=self._t_start,
1053+
t_stop=self._t_stop,
1054+
bin_size=self._bin_size,
1055+
units=self.units,
1056+
sparse_matrix=spmat,
1057+
tolerance=self.tolerance)
1058+
else:
1059+
spmat.data[:] = 1
1060+
bst = self
1061+
10421062
return bst
10431063

10441064
@property
@@ -1053,11 +1073,11 @@ def sparsity(self):
10531073
num_nonzero = self.sparse_matrix.data.shape[0]
10541074
return num_nonzero / np.prod(self.sparse_matrix.shape)
10551075

1056-
def _create_sparse_matrix(self, spiketrains):
1076+
def _create_sparse_matrix(self, spiketrains, sparse_format):
10571077
"""
1058-
Converts `neo.SpikeTrain` objects to a sparse matrix
1059-
(`scipy.sparse.csr_matrix`), which contains the binned spike times, and
1060-
stores it in :attr:`_sparse_mat_u`.
1078+
Converts `neo.SpikeTrain` objects to a scipy sparse matrix, which
1079+
contains the binned spike times, and
1080+
stores it in :attr:`sparse_matrix`.
10611081
10621082
Parameters
10631083
----------
@@ -1069,9 +1089,15 @@ def _create_sparse_matrix(self, spiketrains):
10691089
# The data type for numeric values
10701090
data_dtype = np.int32
10711091

1092+
if sparse_format == 'csr':
1093+
sparse_format = sps.csr_matrix
1094+
else:
1095+
# csc
1096+
sparse_format = sps.csc_matrix
1097+
10721098
if not _check_neo_spiketrain(spiketrains):
10731099
# a binned numpy array
1074-
sparse_matrix = sps.csr_matrix(spiketrains, dtype=data_dtype)
1100+
sparse_matrix = sparse_format(spiketrains, dtype=data_dtype)
10751101
return sparse_matrix
10761102

10771103
# Get index dtype that can accomodate the largest index
@@ -1120,9 +1146,9 @@ def _create_sparse_matrix(self, spiketrains):
11201146
column_ids = np.hstack(column_ids)
11211147
row_ids = np.hstack(row_ids)
11221148

1123-
sparse_matrix = sps.csr_matrix((counts, (row_ids, column_ids)),
1124-
shape=shape, dtype=data_dtype,
1125-
copy=False)
1149+
sparse_matrix = sparse_format((counts, (row_ids, column_ids)),
1150+
shape=shape, dtype=data_dtype,
1151+
copy=False)
11261152

11271153
return sparse_matrix
11281154

elephant/statistics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ def time_histogram(spiketrains, bin_size, t_start=None, t_stop=None,
918918
bin_size=bin_size)
919919

920920
if binary:
921-
bs = bs.binarize()
921+
bs = bs.binarize(copy=False)
922922
bin_hist = bs.get_num_of_spikes(axis=0)
923923
# Flatten array
924924
bin_hist = np.ravel(bin_hist)
@@ -1309,7 +1309,7 @@ def _epoch_with_spread(self):
13091309
tolerance=self.tolerance)
13101310

13111311
if self.binary:
1312-
bst = bst.binarize()
1312+
bst = bst.binarize(copy=False)
13131313
bincount = bst.get_num_of_spikes(axis=0)
13141314

13151315
nonzero_indices = np.nonzero(bincount)[0]

elephant/test/test_conversion.py

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,19 @@ def setUp(self):
195195
self.bin_size = 1 * pq.s
196196
self.tolerance = 1e-8
197197

198+
def test_binarize(self):
199+
spiketrains = [self.spiketrain_a, self.spiketrain_b,
200+
self.spiketrain_a, self.spiketrain_b]
201+
for sparse_format in ("csr", "csc"):
202+
bst = cv.BinnedSpikeTrain(spiketrains=spiketrains,
203+
bin_size=self.bin_size,
204+
sparse_format=sparse_format)
205+
bst_bin = bst.binarize(copy=True)
206+
bst_copy = bst.copy()
207+
assert_array_equal(bst_bin.to_array(), bst.to_bool_array())
208+
bst_copy.sparse_matrix.data[:] = 1
209+
self.assertEqual(bst_bin, bst_copy)
210+
198211
def test_slice(self):
199212
spiketrains = [self.spiketrain_a, self.spiketrain_b,
200213
self.spiketrain_a, self.spiketrain_b]
@@ -254,32 +267,38 @@ def test_time_slice(self):
254267

255268
def test_to_spike_trains(self):
256269
np.random.seed(1)
257-
bst1 = cv.BinnedSpikeTrain(
258-
spiketrains=[self.spiketrain_a, self.spiketrain_b],
259-
bin_size=self.bin_size
260-
)
261270
spiketrains = [homogeneous_poisson_process(rate=10 * pq.Hz,
262271
t_start=-1 * pq.s,
263272
t_stop=10 * pq.s)]
264-
bst2 = cv.BinnedSpikeTrain(spiketrains=spiketrains,
265-
bin_size=300 * pq.ms)
266-
for bst in (bst1, bst2):
267-
for spikes in ("random", "left", "center"):
268-
spiketrains_gen = bst.to_spike_trains(spikes=spikes,
269-
annotate_bins=True)
270-
for st, indices in zip(spiketrains_gen, bst.spike_indices):
271-
# check sorted
272-
self.assertTrue((np.diff(st.magnitude) > 0).all())
273-
assert_array_equal(st.array_annotations['bins'], indices)
274-
self.assertEqual(st.annotations['bin_size'], bst.bin_size)
275-
self.assertEqual(st.t_start, bst.t_start)
276-
self.assertEqual(st.t_stop, bst.t_stop)
277-
bst_same = cv.BinnedSpikeTrain(spiketrains_gen,
278-
bin_size=bst.bin_size)
279-
self.assertEqual(bst_same, bst)
280-
281-
# invalid mode
282-
self.assertRaises(ValueError, bst.to_spike_trains, spikes='right')
273+
for sparse_format in ("csr", "csc"):
274+
bst1 = cv.BinnedSpikeTrain(
275+
spiketrains=[self.spiketrain_a, self.spiketrain_b],
276+
bin_size=self.bin_size, sparse_format=sparse_format
277+
)
278+
bst2 = cv.BinnedSpikeTrain(spiketrains=spiketrains,
279+
bin_size=300 * pq.ms,
280+
sparse_format=sparse_format)
281+
for bst in (bst1, bst2):
282+
for spikes in ("random", "left", "center"):
283+
spiketrains_gen = bst.to_spike_trains(spikes=spikes,
284+
annotate_bins=True)
285+
for st, indices in zip(spiketrains_gen, bst.spike_indices):
286+
# check sorted
287+
self.assertTrue((np.diff(st.magnitude) > 0).all())
288+
assert_array_equal(st.array_annotations['bins'],
289+
indices)
290+
self.assertEqual(st.annotations['bin_size'],
291+
bst.bin_size)
292+
self.assertEqual(st.t_start, bst.t_start)
293+
self.assertEqual(st.t_stop, bst.t_stop)
294+
bst_same = cv.BinnedSpikeTrain(spiketrains_gen,
295+
bin_size=bst.bin_size,
296+
sparse_format=sparse_format)
297+
self.assertEqual(bst_same, bst)
298+
299+
# invalid mode
300+
self.assertRaises(ValueError, bst.to_spike_trains,
301+
spikes='right')
283302

284303
def test_get_num_of_spikes(self):
285304
spiketrains = [self.spiketrain_a, self.spiketrain_b]
@@ -288,14 +307,16 @@ def test_get_num_of_spikes(self):
288307
bin_size=1 * pq.s, t_start=0 * pq.s)
289308
self.assertEqual(binned.get_num_of_spikes(),
290309
len(binned.spike_indices[0]))
291-
binned_matrix = cv.BinnedSpikeTrain(spiketrains, n_bins=10,
292-
bin_size=1 * pq.s)
293-
n_spikes_per_row = binned_matrix.get_num_of_spikes(axis=1)
294-
n_spikes_per_row_from_indices = list(map(len,
295-
binned_matrix.spike_indices))
296-
assert_array_equal(n_spikes_per_row, n_spikes_per_row_from_indices)
297-
self.assertEqual(binned_matrix.get_num_of_spikes(),
298-
sum(n_spikes_per_row_from_indices))
310+
for sparse_format in ("csr", "csc"):
311+
binned_matrix = cv.BinnedSpikeTrain(spiketrains, n_bins=10,
312+
bin_size=1 * pq.s,
313+
sparse_format=sparse_format)
314+
n_spikes_per_row = binned_matrix.get_num_of_spikes(axis=1)
315+
n_spikes_per_row_from_indices = list(
316+
map(len, binned_matrix.spike_indices))
317+
assert_array_equal(n_spikes_per_row, n_spikes_per_row_from_indices)
318+
self.assertEqual(binned_matrix.get_num_of_spikes(),
319+
sum(n_spikes_per_row_from_indices))
299320

300321
def test_binned_spiketrain_sparse(self):
301322
a = neo.SpikeTrain([1.7, 1.8, 4.3] * pq.s, t_stop=10.0 * pq.s)
@@ -662,7 +683,7 @@ def test_repr(self):
662683
bin_size=1 * pq.ms)
663684
self.assertEqual(repr(bst), "BinnedSpikeTrain(t_start=1.0 s, "
664685
"t_stop=1.01 s, bin_size=0.001 s; "
665-
"shape=(1, 10))")
686+
"shape=(1, 10), format=csr_matrix)")
666687

667688
def test_binned_sparsity(self):
668689
train = neo.SpikeTrain(np.arange(10), t_stop=10 * pq.s, units=pq.s)

0 commit comments

Comments
 (0)