1010from pipeline import experiment , ephys , psth
1111
1212from pipeline .plot .util import (_plot_with_sem , _extract_one_stim_dur , _get_units_hemisphere ,
13- _plot_stacked_psth_diff , _plot_avg_psth ,
14- jointplot_w_hue )
13+ _get_trial_event_times , _get_clustering_method ,
14+ _plot_stacked_psth_diff , _plot_avg_psth , jointplot_w_hue )
1515
1616m_scale = 1200
1717_plt_xmin = - 3
1818_plt_xmax = 2
1919
2020
21- def plot_clustering_quality (probe_insertion , axs = None ):
21+ def plot_clustering_quality (probe_insertion , clustering_method = None , axs = None ):
2222 probe_insertion = probe_insertion .proj ()
23- amp , snr , spk_rate , isi_violation = (ephys .Unit * ephys .UnitStat
24- * ephys .ProbeInsertion .InsertionLocation & probe_insertion ).fetch (
23+
24+ if clustering_method is None :
25+ try :
26+ clustering_method = _get_clustering_method (probe_insertion )
27+ except ValueError as e :
28+ raise ValueError (str (e ) + '\n Please specify one with the kwarg "clustering_method"' )
29+
30+ amp , snr , spk_rate , isi_violation = (ephys .Unit * ephys .UnitStat * ephys .ProbeInsertion .InsertionLocation
31+ & probe_insertion & {'clustering_method' : clustering_method }).fetch (
2532 'unit_amp' , 'unit_snr' , 'avg_firing_rate' , 'isi_violation' )
2633
2734 metrics = {'amp' : amp ,
@@ -52,11 +59,18 @@ def plot_clustering_quality(probe_insertion, axs=None):
5259 return fig
5360
5461
55- def plot_unit_characteristic (probe_insertion , axs = None ):
62+ def plot_unit_characteristic (probe_insertion , clustering_method = None , axs = None ):
5663 probe_insertion = probe_insertion .proj ()
64+
65+ if clustering_method is None :
66+ try :
67+ clustering_method = _get_clustering_method (probe_insertion )
68+ except ValueError as e :
69+ raise ValueError (str (e ) + '\n Please specify one with the kwarg "clustering_method"' )
70+
5771 amp , snr , spk_rate , x , y , insertion_depth = (
5872 ephys .Unit * ephys .ProbeInsertion .InsertionLocation * ephys .UnitStat
59- & probe_insertion & 'unit_quality != "all"' ).fetch (
73+ & probe_insertion & { 'clustering_method' : clustering_method } & 'unit_quality != "all"' ).fetch (
6074 'unit_amp' , 'unit_snr' , 'avg_firing_rate' , 'unit_posx' , 'unit_posy' , 'dv_location' )
6175
6276 insertion_depth = np .where (np .isnan (insertion_depth ), 0 , insertion_depth )
@@ -102,12 +116,20 @@ def plot_unit_characteristic(probe_insertion, axs=None):
102116 return fig
103117
104118
105- def plot_unit_selectivity (probe_insertion , axs = None ):
119+ def plot_unit_selectivity (probe_insertion , clustering_method = None , axs = None ):
106120 probe_insertion = probe_insertion .proj ()
121+
122+ if clustering_method is None :
123+ try :
124+ clustering_method = _get_clustering_method (probe_insertion )
125+ except ValueError as e :
126+ raise ValueError (str (e ) + '\n Please specify one with the kwarg "clustering_method"' )
127+
107128 attr_names = ['unit' , 'period' , 'period_selectivity' , 'contra_firing_rate' ,
108- 'ipsi_firing_rate' , 'unit_posx' , 'unit_posy' , 'dv_location' ]
129+ 'ipsi_firing_rate' , 'unit_posx' , 'unit_posy' , 'dv_location' ]
109130 selective_units = (psth .PeriodSelectivity * ephys .Unit * ephys .ProbeInsertion .InsertionLocation
110- * experiment .Period & probe_insertion & 'period_selectivity != "non-selective"' ).fetch (* attr_names )
131+ * experiment .Period & probe_insertion & {'clustering_method' : clustering_method }
132+ & 'period_selectivity != "non-selective"' ).fetch (* attr_names )
111133 selective_units = pd .DataFrame (selective_units ).T
112134 selective_units .columns = attr_names
113135 selective_units .period_selectivity .astype ('category' )
@@ -162,10 +184,16 @@ def plot_unit_selectivity(probe_insertion, axs=None):
162184 return fig
163185
164186
165- def plot_unit_bilateral_photostim_effect (probe_insertion , axs = None ):
187+ def plot_unit_bilateral_photostim_effect (probe_insertion , clustering_method = None , axs = None ):
166188 probe_insertion = probe_insertion .proj ()
189+
190+ if clustering_method is None :
191+ try :
192+ clustering_method = _get_clustering_method (probe_insertion )
193+ except ValueError as e :
194+ raise ValueError (str (e ) + '\n Please specify one with the kwarg "clustering_method"' )
195+
167196 dv_loc = (ephys .ProbeInsertion .InsertionLocation & probe_insertion ).fetch1 ('dv_location' )
168- cue_onset = (experiment .Period & 'period = "delay"' ).fetch1 ('period_start' )
169197
170198 no_stim_cond = (psth .TrialCondition
171199 & {'trial_condition_name' :
@@ -181,24 +209,32 @@ def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
181209 & probe_insertion ).fetch ('duration' ))
182210 stim_dur = _extract_one_stim_dur (stim_durs )
183211
184- units = ephys .Unit & probe_insertion & 'unit_quality != "all"'
212+ units = ephys .Unit & probe_insertion & {'clustering_method' : clustering_method } & 'unit_quality != "all"'
213+
214+ metrics = pd .DataFrame (columns = ['unit' , 'x' , 'y' , 'frate_change' ])
185215
186- metrics = pd .DataFrame (columns = ['unit' , 'x' , 'y' , 'frate_change' ]) # TODO: account for dv_location
216+ _ , cue_onset = _get_trial_event_times (['delay' ], units , 'all_noearlylick_both_alm_nostim' )
217+ cue_onset = cue_onset [0 ]
187218
188219 # XXX: could be done with 1x fetch+join
189220 for u_idx , unit in enumerate (units .fetch ('KEY' , order_by = 'unit' )):
190221
191222 x , y = (ephys .Unit & unit ).fetch1 ('unit_posx' , 'unit_posy' )
192223
193- nostim_psth , nostim_edge = (
194- psth .UnitPsth & {** unit , ** no_stim_cond }).fetch1 ('unit_psth' )
224+ # obtain unit psth per trial, for all nostim and bistim trials
225+ nostim_trials = ephys .Unit .TrialSpikes & unit & psth .TrialCondition .get_trials (no_stim_cond ['trial_condition_name' ])
226+ bistim_trials = ephys .Unit .TrialSpikes & unit & psth .TrialCondition .get_trials (bi_stim_cond ['trial_condition_name' ])
195227
196- bistim_psth , bistim_edge = (
197- psth .UnitPsth & { ** unit , ** bi_stim_cond }). fetch1 ( 'unit_psth' )
228+ nostim_psths , nostim_edge = psth . compute_unit_psth ( unit , nostim_trials . fetch ( 'KEY' ), per_trial = True )
229+ bistim_psths , bistim_edge = psth .compute_unit_psth ( unit , bistim_trials . fetch ( 'KEY' ), per_trial = True )
198230
199231 # compute the firing rate difference between contra vs. ipsi within the stimulation duration
200- ctrl_frate = nostim_psth [np .logical_and (nostim_edge [1 :] >= cue_onset , nostim_edge [1 :] <= cue_onset + stim_dur )]
201- stim_frate = bistim_psth [np .logical_and (bistim_edge [1 :] >= cue_onset , bistim_edge [1 :] <= cue_onset + stim_dur )]
232+ ctrl_frate = np .array ([nostim_psth [np .logical_and (nostim_edge >= cue_onset ,
233+ nostim_edge <= cue_onset + stim_dur )].mean ()
234+ for nostim_psth in nostim_psths ])
235+ stim_frate = np .array ([bistim_psth [np .logical_and (bistim_edge >= cue_onset ,
236+ bistim_edge <= cue_onset + stim_dur )].mean ()
237+ for bistim_psth in bistim_psths ])
202238
203239 frate_change = (stim_frate .mean () - ctrl_frate .mean ()) / ctrl_frate .mean ()
204240 frate_change = abs (frate_change ) if frate_change < 0 else 0.0001
@@ -230,9 +266,8 @@ def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
230266def plot_stacked_contra_ipsi_psth (units , axs = None ):
231267 units = units .proj ()
232268
233- period_starts = (experiment .Period
234- & 'period in ("sample", "delay", "response")' ).fetch (
235- 'period_start' )
269+ # get event start times: sample, delay, response
270+ period_names , period_starts = _get_trial_event_times (['sample' , 'delay' , 'go' ], units , 'good_noearlylick_hit' )
236271
237272 hemi = _get_units_hemisphere (units )
238273
@@ -285,9 +320,8 @@ def plot_stacked_contra_ipsi_psth(units, axs=None):
285320def plot_avg_contra_ipsi_psth (units , axs = None ):
286321 units = units .proj ()
287322
288- period_starts = (experiment .Period
289- & 'period in ("sample", "delay", "response")' ).fetch (
290- 'period_start' )
323+ # get event start times: sample, delay, response
324+ period_names , period_starts = _get_trial_event_times (['sample' , 'delay' , 'go' ], units , 'good_noearlylick_hit' )
291325
292326 hemi = _get_units_hemisphere (units )
293327
@@ -349,10 +383,6 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
349383
350384 hemi = _get_units_hemisphere (units )
351385
352- period_starts = (experiment .Period
353- & 'period in ("sample", "delay", "response")' ).fetch (
354- 'period_start' )
355-
356386 psth_s_l = (psth .UnitPsth * psth .TrialCondition & units
357387 & {'trial_condition_name' :
358388 'all_noearlylick_both_alm_stim_left' }).fetch ('unit_psth' )
@@ -369,6 +399,9 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
369399 & {'trial_condition_name' :
370400 'all_noearlylick_both_alm_nostim_right' }).fetch ('unit_psth' )
371401
402+ # get event start times: sample, delay, response
403+ period_names , period_starts = _get_trial_event_times (['sample' , 'delay' , 'go' ], units , 'good_noearlylick_hit' )
404+
372405 # get photostim duration
373406 stim_durs = np .unique ((experiment .Photostim & experiment .PhotostimEvent
374407 * psth .TrialCondition ().get_trials ('all_noearlylick_both_alm_stim' )
@@ -402,9 +435,8 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
402435 ax .set_ylim ((0 , ymax ))
403436
404437 # add shaded bar for photostim
405- delay = (experiment .Period # TODO: use from period_starts
406- & 'period = "delay"' ).fetch1 ('period_start' )
407- axs [1 ].axvspan (delay , delay + stim_dur , alpha = 0.3 , color = 'royalblue' )
438+ stim_time = period_starts [np .where (period_names == 'delay' )[0 ][0 ]]
439+ axs [1 ].axvspan (stim_time , stim_time + stim_dur , alpha = 0.3 , color = 'royalblue' )
408440
409441 return fig
410442
@@ -423,10 +455,6 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
423455
424456 hemi = _get_units_hemisphere (units )
425457
426- period_starts = (experiment .Period
427- & 'period in ("sample", "delay", "response")' ).fetch (
428- 'period_start' )
429-
430458 # no photostim:
431459 psth_n_l = psth .TrialCondition .get_cond_name_from_keywords (['_nostim' , '_left' ])[0 ]
432460 psth_n_r = psth .TrialCondition .get_cond_name_from_keywords (['_nostim' , '_right' ])[0 ]
@@ -444,6 +472,9 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
444472 psth_s_r = (psth .UnitPsth * psth .TrialCondition & units
445473 & {'trial_condition_name' : psth_s_r } & 'unit_psth is not NULL' ).fetch ('unit_psth' )
446474
475+ # get event start times: sample, delay, response
476+ period_names , period_starts = _get_trial_event_times (['sample' , 'delay' , 'go' ], units , 'good_noearlylick_hit' )
477+
447478 # get photostim duration
448479 stim_trial_cond_name = psth .TrialCondition .get_cond_name_from_keywords (condition_name_kw + ['_stim' ])[0 ]
449480 stim_durs = np .unique ((experiment .Photostim & experiment .PhotostimEvent
@@ -474,7 +505,7 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
474505 ax .set_xlim ([_plt_xmin , _plt_xmax ])
475506
476507 # add shaded bar for photostim
477- stim_time = ( experiment . Period & 'period = " delay"' ). fetch1 ( 'period_start' )
508+ stim_time = period_starts [ np . where ( period_names == ' delay' )[ 0 ][ 0 ]]
478509 axs [1 ].axvspan (stim_time , stim_time + stim_dur , alpha = 0.3 , color = 'royalblue' )
479510
480511 return fig
@@ -484,7 +515,8 @@ def plot_coding_direction(units, time_period=None, axs=None):
484515 _ , proj_contra_trial , proj_ipsi_trial , time_stamps , _ = psth .compute_CD_projected_psth (
485516 units .fetch ('KEY' ), time_period = time_period )
486517
487- period_starts = (experiment .Period & 'period in ("sample", "delay", "response")' ).fetch ('period_start' )
518+ # get event start times: sample, delay, response
519+ period_names , period_starts = _get_trial_event_times (['sample' , 'delay' , 'go' ], units , 'good_noearlylick' )
488520
489521 fig = None
490522 if axs is None :
@@ -515,7 +547,8 @@ def plot_paired_coding_direction(unit_g1, unit_g2, labels=None, time_period=None
515547 _ , proj_contra_trial_g2 , proj_ipsi_trial_g2 , time_stamps , unit_g2_hemi = psth .compute_CD_projected_psth (
516548 unit_g2 .fetch ('KEY' ), time_period = time_period )
517549
518- period_starts = (experiment .Period & 'period in ("sample", "delay", "response")' ).fetch ('period_start' )
550+ # get event start times: sample, delay, response
551+ period_names , period_starts = _get_trial_event_times (['sample' , 'delay' , 'go' ], unit_g1 , 'good_noearlylick' )
519552
520553 if labels :
521554 assert len (labels ) == 2
0 commit comments