|
13 | 13 | # --- |
14 | 14 |
|
15 | 15 | # + |
16 | | -from pathlib import Path |
17 | 16 | import datajoint as dj |
| 17 | +from pathlib import Path |
18 | 18 |
|
19 | 19 | dj.config.load( |
20 | 20 | Path("../dj_local_conf.json").absolute() |
21 | 21 | ) # load config for database connection info |
| 22 | + |
| 23 | +from spyglass.mua.v1.mua import MuaEventsV1, MuaEventsParameters |
| 24 | + |
22 | 25 | # - |
23 | 26 |
|
24 | | -# # MUA Analysis and Detection |
25 | | -# |
26 | | -# NOTE: This notebook is a work in progress. It is not yet complete and may contain errors. |
| 27 | +MuaEventsParameters() |
| 28 | + |
| 29 | +MuaEventsV1() |
27 | 30 |
|
28 | 31 | # + |
29 | | -from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput |
30 | | -import spyglass.spikesorting.v1 as sgs |
31 | | - |
| 32 | +from spyglass.position import PositionOutput |
32 | 33 |
|
33 | 34 | nwb_copy_file_name = "mediumnwb20230802_.nwb" |
34 | 35 |
|
35 | | -sorter_keys = { |
| 36 | +trodes_s_key = { |
36 | 37 | "nwb_file_name": nwb_copy_file_name, |
37 | | - "sorter": "clusterless_thresholder", |
38 | | - "sorter_param_name": "default_clusterless", |
| 38 | + "interval_list_name": "pos 0 valid times", |
| 39 | + "trodes_pos_params_name": "single_led_upsampled", |
39 | 40 | } |
40 | 41 |
|
41 | | -(sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1 |
42 | | - |
43 | | -# + |
44 | | -spikesorting_merge_ids = ( |
45 | | - (sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1 |
46 | | -).fetch("merge_id") |
47 | | - |
48 | | -spikesorting_merge_ids |
49 | | - |
50 | | -# + |
51 | | -from spyglass.spikesorting.unit_inclusion_merge import ( |
52 | | - ImportedUnitInclusionV1, |
53 | | - UnitInclusionOutput, |
54 | | -) |
55 | | - |
56 | | -ImportedUnitInclusionV1().insert_all_units(spikesorting_merge_ids) |
57 | | - |
58 | | -UnitInclusionOutput.ImportedUnitInclusionV1() & [ |
59 | | - {"spikesorting_merge_id": id} for id in spikesorting_merge_ids |
60 | | -] |
61 | | - |
62 | | -# + |
63 | | -from spyglass.spikesorting.unit_inclusion_merge import ( |
64 | | - ImportedUnitInclusionV1, |
65 | | - UnitInclusionOutput, |
66 | | -) |
67 | | - |
68 | | -ImportedUnitInclusionV1().insert_all_units(spikesorting_merge_ids) |
69 | | - |
70 | | -UnitInclusionOutput.ImportedUnitInclusionV1() & [ |
71 | | - {"spikesorting_merge_id": id} for id in spikesorting_merge_ids |
72 | | -] |
| 42 | +pos_merge_id = (PositionOutput.TrodesPosV1 & trodes_s_key).fetch1("merge_id") |
| 43 | +pos_merge_id |
73 | 44 |
|
74 | 45 | # + |
75 | | -from spyglass.spikesorting.unit_inclusion_merge import SortedSpikesGroup |
76 | | - |
77 | | -unit_inclusion_merge_ids = ( |
78 | | - UnitInclusionOutput.ImportedUnitInclusionV1 |
79 | | - & [{"spikesorting_merge_id": id} for id in spikesorting_merge_ids] |
80 | | -).fetch("merge_id") |
81 | | - |
82 | | -SortedSpikesGroup().create_group( |
83 | | - group_name="test_group", |
84 | | - nwb_file_name=nwb_copy_file_name, |
85 | | - unit_inclusion_merge_ids=unit_inclusion_merge_ids, |
| 46 | +from spyglass.spikesorting.analysis.v1.group import ( |
| 47 | + SortedSpikesGroup, |
86 | 48 | ) |
87 | 49 |
|
88 | | -group_key = { |
| 50 | +sorted_spikes_group_key = { |
89 | 51 | "nwb_file_name": nwb_copy_file_name, |
90 | 52 | "sorted_spikes_group_name": "test_group", |
| 53 | + "unit_filter_params_name": "default_exclusion", |
91 | 54 | } |
92 | 55 |
|
93 | | -SortedSpikesGroup & group_key |
94 | | -# - |
95 | | - |
96 | | -SortedSpikesGroup.Units() & group_key |
97 | | - |
98 | | -# An example of how to get spike times |
99 | | - |
100 | | -spike_times = SortedSpikesGroup.fetch_spike_data(group_key) |
101 | | -spike_times[0] |
| 56 | +SortedSpikesGroup & sorted_spikes_group_key |
102 | 57 |
|
103 | 58 | # + |
104 | | -from spyglass.position import PositionOutput |
105 | | - |
106 | | -position_merge_id = ( |
107 | | - PositionOutput.TrodesPosV1 |
108 | | - & { |
109 | | - "nwb_file_name": nwb_copy_file_name, |
110 | | - "interval_list_name": "pos 0 valid times", |
111 | | - "trodes_pos_params_name": "default_decoding", |
112 | | - } |
113 | | -).fetch1("merge_id") |
114 | | - |
115 | | -position_info = ( |
116 | | - (PositionOutput & {"merge_id": position_merge_id}) |
117 | | - .fetch1_dataframe() |
118 | | - .dropna() |
119 | | -) |
120 | | -position_info |
121 | | - |
122 | | -# + |
123 | | -time_ind_slice = slice(63_000, 70_000) |
124 | | -time = position_info.index[time_ind_slice] |
125 | | - |
126 | | -SortedSpikesGroup.get_spike_indicator(group_key, time) |
127 | | - |
128 | | -# + |
129 | | -import matplotlib.pyplot as plt |
130 | | - |
131 | | -fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4)) |
132 | | -multiunit_firing_rate = SortedSpikesGroup.get_firing_rate( |
133 | | - group_key, time, multiunit=True |
134 | | -) |
135 | | -axes[0].plot( |
136 | | - time, |
137 | | - multiunit_firing_rate, |
138 | | -) |
139 | | -axes[0].set_ylabel("firing rate (Hz)") |
140 | | -axes[0].set_title("multiunit") |
141 | | -axes[1].fill_between( |
142 | | - time, position_info["speed"].iloc[time_ind_slice], color="lightgrey" |
143 | | -) |
144 | | -axes[1].set_ylabel("speed (cm/s)") |
145 | | -axes[1].set_xlabel("time (s)") |
146 | | - |
147 | | -# + |
148 | | -from spyglass.mua.v1.mua import MuaEventsParameters, MuaEventsV1 |
149 | | - |
150 | | -MuaEventsParameters().insert_default() |
151 | | -MuaEventsParameters() |
152 | | - |
153 | | -# + |
154 | | -selection_key = { |
| 59 | +mua_key = { |
155 | 60 | "mua_param_name": "default", |
156 | | - "nwb_file_name": nwb_copy_file_name, |
157 | | - "sorted_spikes_group_name": "test_group", |
158 | | - "pos_merge_id": position_merge_id, |
159 | | - "artifact_interval_list_name": "test_artifact_times", |
| 61 | + **sorted_spikes_group_key, |
| 62 | + "pos_merge_id": pos_merge_id, |
| 63 | + "detection_interval": "pos 0 valid times", |
160 | 64 | } |
161 | 65 |
|
162 | | -MuaEventsV1.populate(selection_key) |
| 66 | +MuaEventsV1().populate(mua_key) |
| 67 | +MuaEventsV1 & mua_key |
163 | 68 | # - |
164 | 69 |
|
165 | | -MuaEventsV1 & selection_key |
166 | | - |
167 | | -mua_times = (MuaEventsV1 & selection_key).fetch1_dataframe() |
| 70 | +mua_times = (MuaEventsV1 & mua_key).fetch1_dataframe() |
168 | 71 | mua_times |
169 | 72 |
|
170 | 73 | # + |
171 | 74 | import matplotlib.pyplot as plt |
172 | 75 | import numpy as np |
173 | 76 |
|
174 | 77 | fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4)) |
| 78 | +speed = MuaEventsV1.get_speed(mua_key).to_numpy() |
| 79 | +time = speed.index.to_numpy() |
| 80 | +multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time) |
| 81 | + |
| 82 | +time_slice = slice( |
| 83 | + np.searchsorted(time, mua_times.loc[10].start_time) - 1_000, |
| 84 | + np.searchsorted(time, mua_times.loc[10].start_time) + 5_000, |
| 85 | +) |
| 86 | + |
175 | 87 | axes[0].plot( |
176 | | - time, |
177 | | - multiunit_firing_rate, |
| 88 | + time[time_slice], |
| 89 | + multiunit_firing_rate[time_slice], |
| 90 | + color="black", |
178 | 91 | ) |
179 | 92 | axes[0].set_ylabel("firing rate (Hz)") |
180 | 93 | axes[0].set_title("multiunit") |
181 | | -axes[1].fill_between( |
182 | | - time, position_info["speed"].iloc[time_ind_slice], color="lightgrey" |
183 | | -) |
| 94 | +axes[1].fill_between(time[time_slice], speed[time_slice], color="lightgrey") |
184 | 95 | axes[1].set_ylabel("speed (cm/s)") |
185 | 96 | axes[1].set_xlabel("time (s)") |
186 | 97 |
|
187 | | -in_bounds = np.logical_and( |
188 | | - mua_times.start_time >= time[0], mua_times.end_time <= time[-1] |
189 | | -) |
190 | | - |
191 | | -for mua_time in mua_times.loc[in_bounds].itertuples(): |
192 | | - axes[0].axvspan( |
193 | | - mua_time.start_time, mua_time.end_time, color="red", alpha=0.3 |
| 98 | +for id, mua_time in mua_times.loc[ |
| 99 | + np.logical_and( |
| 100 | + mua_times["start_time"] > time[time_slice].min(), |
| 101 | + mua_times["end_time"] < time[time_slice].max(), |
194 | 102 | ) |
195 | | - axes[1].axvspan( |
196 | | - mua_time.start_time, mua_time.end_time, color="red", alpha=0.3 |
| 103 | +].iterrows(): |
| 104 | + axes[0].axvspan( |
| 105 | + mua_time["start_time"], mua_time["end_time"], color="red", alpha=0.5 |
197 | 106 | ) |
198 | | -axes[1].set_ylim((0, 80)) |
199 | | -axes[1].axhline(4, color="black", linestyle="--") |
200 | | -axes[1].set_xlim((time[0], time[-1])) |
201 | | - |
202 | | -# + |
203 | | -from spyglass.common import IntervalList |
204 | | - |
205 | | -IntervalList() & { |
206 | | - "nwb_file_name": nwb_copy_file_name, |
207 | | - "pipeline": "spikesorting_artifact_v1", |
208 | | -} |
209 | 107 | # - |
210 | 108 |
|
211 | | -( |
212 | | - sgs.ArtifactDetectionParameters |
213 | | - * sgs.SpikeSortingRecording |
214 | | - * sgs.ArtifactDetectionSelection |
215 | | -) |
216 | | - |
217 | | -SpikeSortingOutput.CurationV1() * ( |
218 | | - sgs.ArtifactDetectionParameters |
219 | | - * sgs.SpikeSortingRecording |
220 | | - * sgs.ArtifactDetectionSelection |
221 | | -) |
222 | | - |
223 | | -( |
224 | | - IntervalList() |
225 | | - & { |
226 | | - "nwb_file_name": nwb_copy_file_name, |
227 | | - "pipeline": "spikesorting_artifact_v1", |
228 | | - } |
229 | | -).proj(artifact_id="interval_list_name") |
230 | | - |
231 | | -sgs.SpikeSortingRecording() * sgs.ArtifactDetectionSelection() |
232 | | - |
233 | | -SpikeSortingOutput.CurationV1() * sgs.SpikeSortingRecording() |
234 | | - |
235 | | -IntervalList.insert1( |
236 | | - { |
237 | | - "nwb_file_name": nwb_copy_file_name, |
238 | | - "interval_list_name": "test_artifact_times", |
239 | | - "valid_times": [], |
240 | | - } |
| 109 | +(MuaEventsV1 & mua_key).create_figurl( |
| 110 | + zscore_mua=True, |
241 | 111 | ) |
0 commit comments