@@ -91,6 +91,8 @@ def _stat(self):
9191 self ._data [base ] = (np .reshape (d , d .shape [0 ])
9292 if d .ndim == 2 and d .shape [1 ] == 1 else d )
9393
94+ self ._data ['channel_map' ] = self ._data ['channel_map' ].flatten ()
95+
9496 # Read the Cluster Groups
9597 for cluster_pattern , cluster_col_name in zip (['cluster_groups.*' , 'cluster_KSLabel.*' ],
9698 ['group' , 'KSLabel' ]):
@@ -127,19 +129,23 @@ def get_best_channel(self, unit):
127129
128130 def extract_spike_depths (self ):
129131 """ Reimplemented from https://github.com/cortex-lab/spikes/blob/master/analysis/ksDriftmap.m """
130- ycoords = self .data ['channel_positions' ][:, 1 ]
131- pc_features = self .data ['pc_features' ][:, 0 , :] # 1st PC only
132- pc_features = np .where (pc_features < 0 , 0 , pc_features )
133-
134- # ---- compute center of mass of these features (spike depths) ----
135-
136- # which channels for each spike?
137- spk_feature_ind = self .data ['pc_feature_ind' ][self .data ['spike_templates' ], :]
138- # ycoords of those channels?
139- spk_feature_ycoord = ycoords [spk_feature_ind ]
140- # center of mass is sum(coords.*features)/sum(features)
141- self ._data ['spike_depths' ] = (np .sum (spk_feature_ycoord * pc_features ** 2 , axis = 1 )
142- / np .sum (pc_features ** 2 , axis = 1 ))
132+
133+ if 'pc_features' in self .data :
134+ ycoords = self .data ['channel_positions' ][:, 1 ]
135+ pc_features = self .data ['pc_features' ][:, 0 , :] # 1st PC only
136+ pc_features = np .where (pc_features < 0 , 0 , pc_features )
137+
138+ # ---- compute center of mass of these features (spike depths) ----
139+
140+ # which channels for each spike?
141+ spk_feature_ind = self .data ['pc_feature_ind' ][self .data ['spike_templates' ], :]
142+ # ycoords of those channels?
143+ spk_feature_ycoord = ycoords [spk_feature_ind ]
144+ # center of mass is sum(coords.*features)/sum(features)
145+ self ._data ['spike_depths' ] = (np .sum (spk_feature_ycoord * pc_features ** 2 , axis = 1 )
146+ / np .sum (pc_features ** 2 , axis = 1 ))
147+ else :
148+ self ._data ['spike_depths' ] = None
143149
144150 # ---- extract spike sites ----
145151 max_site_ind = np .argmax (np .abs (self .data ['templates' ]).max (axis = 1 ), axis = 1 )
0 commit comments