@@ -274,6 +274,10 @@ class BinnedSpikeTrain(object):
274274 Tolerance for rounding errors in the binning process and in the input
275275 data
276276 Default: 1e-8
277+ sparse_format : {'csr', 'csc'}, optional
278+ The sparse matrix format. By default, CSR format is used to perform
279+ slicing and computations efficiently.
280+ Default: 'csr'
277281
278282 Raises
279283 ------
@@ -323,7 +327,11 @@ class BinnedSpikeTrain(object):
323327
324328 @deprecated_alias (binsize = 'bin_size' , num_bins = 'n_bins' )
325329 def __init__ (self , spiketrains , bin_size = None , n_bins = None , t_start = None ,
326- t_stop = None , tolerance = 1e-8 ):
330+ t_stop = None , tolerance = 1e-8 , sparse_format = "csr" ):
331+ if sparse_format not in ("csr" , "csc" ):
332+ raise ValueError (f"Invalid 'sparse_format': { sparse_format } . "
333+ "Available: 'csr' and 'csc'" )
334+
327335 # Converting spiketrains to a list, if spiketrains is one
328336 # SpikeTrain object
329337 if isinstance (spiketrains , neo .SpikeTrain ):
@@ -339,7 +347,8 @@ def __init__(self, spiketrains, bin_size=None, n_bins=None, t_start=None,
339347 # Check all parameter, set also missing values
340348 self ._resolve_input_parameters (spiketrains )
341349 # Now create the sparse matrix
342- self .sparse_matrix = self ._create_sparse_matrix (spiketrains )
350+ self .sparse_matrix = self ._create_sparse_matrix (
351+ spiketrains , sparse_format = sparse_format )
343352
344353 @property
345354 def shape (self ):
@@ -369,13 +378,10 @@ def num_bins(self):
369378 return self .n_bins
370379
371380 def __repr__ (self ):
372- return "{klass}(t_start={t_start}, t_stop={t_stop}, " \
373- "bin_size={bin_size}; shape={shape})" .format (
374- klass = type (self ).__name__ ,
375- t_start = self .t_start ,
376- t_stop = self .t_stop ,
377- bin_size = self .bin_size ,
378- shape = self .shape )
381+ return f"{ type (self ).__name__ } (t_start={ self .t_start } , " \
382+ f"t_stop={ self .t_stop } , bin_size={ self .bin_size } ; " \
383+ f"shape={ self .shape } , " \
384+ f"format={ self .sparse_matrix .__class__ .__name__ } )"
379385
380386 def rescale (self , units ):
381387 """
@@ -590,7 +596,7 @@ def to_sparse_array(self):
590596
591597 Returns
592598 -------
593- scipy.sparse.csr_matrix
599+ scipy.sparse.csr_matrix or scipy.sparse.csc_matrix
594600 Sparse matrix, version with spike counts.
595601
596602 See also
@@ -611,7 +617,7 @@ def to_sparse_bool_array(self):
611617
612618 Returns
613619 -------
614- scipy.sparse.csr_matrix
620+ scipy.sparse.csr_matrix or scipy.sparse.csc_matrix
615621 Sparse matrix, binary, boolean version.
616622
617623 See also
@@ -638,7 +644,8 @@ def __eq__(self, other):
638644 return False
639645 sp1 = self .sparse_matrix
640646 sp2 = other .sparse_matrix
641- if sp1 .shape != sp2 .shape or sp1 .data .shape != sp2 .data .shape :
647+ if sp1 .__class__ is not sp2 .__class__ or sp1 .shape != sp2 .shape \
648+ or sp1 .data .shape != sp2 .data .shape :
642649 return False
643650 return (sp1 .data == sp2 .data ).all () and \
644651 (sp1 .indptr == sp2 .indptr ).all () and \
@@ -662,11 +669,18 @@ def copy(self):
662669 tolerance = self .tolerance )
663670
664671 def __iter_sparse_matrix (self ):
672+ spmat = self .sparse_matrix
673+ if isinstance (spmat , sps .csc_matrix ):
674+ warnings .warn ("The sparse matrix format is CSC. For better "
675+ "performance, specify the CSR format while "
676+ "constructing a "
677+ "BinnedSpikeTrain(sparse_format='csr')" )
678+ spmat = spmat .tocsr ()
665679 # taken from csr_matrix.__iter__()
666680 i0 = 0
667- for i1 in self . sparse_matrix .indptr [1 :]:
668- indices = self . sparse_matrix .indices [i0 :i1 ]
669- data = self . sparse_matrix .data [i0 :i1 ]
681+ for i1 in spmat .indptr [1 :]:
682+ indices = spmat .indices [i0 :i1 ]
683+ data = spmat .data [i0 :i1 ]
670684 yield indices , data
671685 i0 = i1
672686
@@ -1000,45 +1014,51 @@ def to_array(self, dtype=None):
10001014 scipy.sparse.csr_matrix.toarray
10011015
10021016 """
1003- spmat = self .sparse_matrix
1004- if dtype is not None and dtype != spmat .data .dtype :
1005- # avoid a copy
1006- spmat = sps .csr_matrix (
1007- (spmat .data .astype (dtype ), spmat .indices , spmat .indptr ),
1008- shape = spmat .shape )
1009- return spmat .toarray ()
1010-
1011- def binarize (self , copy = None ):
1017+ array = self .sparse_matrix .toarray ()
1018+ if dtype is not None :
1019+ array = array .astype (dtype )
1020+ return array
1021+
1022+ def binarize (self , copy = True ):
10121023 """
10131024 Clip the internal array (no. of spikes in a bin) to `0` (no spikes) or
10141025 `1` (at least one spike) values only.
10151026
10161027 Parameters
10171028 ----------
10181029 copy : bool, optional
1019- Deprecated parameter. It has no effect.
1030+ If True, a **shallow** copy - a view of `BinnedSpikeTrain` - is
1031+ returned with the data array filled with zeros and ones. Otherwise,
1032+ the binarization (clipping) is done in-place. A shallow copy
1033+ means that :attr:`indices` and :attr:`indptr` of a sparse matrix
1034+ is shared with the original sparse matrix. Only the data is copied.
1035+ If you want to perform a deep copy, call
1036+ :func:`BinnedSpikeTrain.copy` prior to binarizing.
1037+ Default: True
10201038
10211039 Returns
10221040 -------
1023- bst : BinnedSpikeTrainView
1024- A view of `BinnedSpikeTrain` with a sparse matrix containing
1025- data clipped to `0`s and `1`s .
1041+ bst : BinnedSpikeTrain or BinnedSpikeTrainView
1042+ A ( view of) `BinnedSpikeTrain` with the sparse matrix data clipped
1043+ to zeros and ones .
10261044
10271045 """
1028- if copy is not None :
1029- warnings .warn ("'copy' parameter is deprecated - a view is always "
1030- "returned; set this parameter to None." ,
1031- DeprecationWarning )
10321046 spmat = self .sparse_matrix
1033- spmat = sps .csr_matrix (
1034- (spmat .data .clip (max = 1 ), spmat .indices , spmat .indptr ),
1035- shape = spmat .shape , copy = False )
1036- bst = BinnedSpikeTrainView (t_start = self ._t_start ,
1037- t_stop = self ._t_stop ,
1038- bin_size = self ._bin_size ,
1039- units = self .units ,
1040- sparse_matrix = spmat ,
1041- tolerance = self .tolerance )
1047+ if copy :
1048+ data = np .ones (len (spmat .data ), dtype = spmat .data .dtype )
1049+ spmat = spmat .__class__ (
1050+ (data , spmat .indices , spmat .indptr ),
1051+ shape = spmat .shape , copy = False )
1052+ bst = BinnedSpikeTrainView (t_start = self ._t_start ,
1053+ t_stop = self ._t_stop ,
1054+ bin_size = self ._bin_size ,
1055+ units = self .units ,
1056+ sparse_matrix = spmat ,
1057+ tolerance = self .tolerance )
1058+ else :
1059+ spmat .data [:] = 1
1060+ bst = self
1061+
10421062 return bst
10431063
10441064 @property
@@ -1053,11 +1073,11 @@ def sparsity(self):
10531073 num_nonzero = self .sparse_matrix .data .shape [0 ]
10541074 return num_nonzero / np .prod (self .sparse_matrix .shape )
10551075
1056- def _create_sparse_matrix (self , spiketrains ):
1076+ def _create_sparse_matrix (self , spiketrains , sparse_format ):
10571077 """
1058- Converts `neo.SpikeTrain` objects to a sparse matrix
1059- (`scipy.sparse.csr_matrix`), which contains the binned spike times, and
1060- stores it in :attr:`_sparse_mat_u `.
1078+ Converts `neo.SpikeTrain` objects to a scipy sparse matrix, which
1079+ contains the binned spike times, and
1080+ stores it in :attr:`sparse_matrix `.
10611081
10621082 Parameters
10631083 ----------
@@ -1069,9 +1089,15 @@ def _create_sparse_matrix(self, spiketrains):
10691089 # The data type for numeric values
10701090 data_dtype = np .int32
10711091
1092+ if sparse_format == 'csr' :
1093+ sparse_format = sps .csr_matrix
1094+ else :
1095+ # csc
1096+ sparse_format = sps .csc_matrix
1097+
10721098 if not _check_neo_spiketrain (spiketrains ):
10731099 # a binned numpy array
1074- sparse_matrix = sps . csr_matrix (spiketrains , dtype = data_dtype )
1100+ sparse_matrix = sparse_format (spiketrains , dtype = data_dtype )
10751101 return sparse_matrix
10761102
10771103 # Get index dtype that can accomodate the largest index
@@ -1120,9 +1146,9 @@ def _create_sparse_matrix(self, spiketrains):
11201146 column_ids = np .hstack (column_ids )
11211147 row_ids = np .hstack (row_ids )
11221148
1123- sparse_matrix = sps . csr_matrix ((counts , (row_ids , column_ids )),
1124- shape = shape , dtype = data_dtype ,
1125- copy = False )
1149+ sparse_matrix = sparse_format ((counts , (row_ids , column_ids )),
1150+ shape = shape , dtype = data_dtype ,
1151+ copy = False )
11261152
11271153 return sparse_matrix
11281154
0 commit comments