Skip to content

Commit 5736578

Browse files
authored
Merge pull request #121 from vathes/master
report and publication fixes
2 parents 69f1d82 + 770b015 commit 5736578

File tree

8 files changed

+194
-122
lines changed

8 files changed

+194
-122
lines changed

pipeline/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
log = logging.getLogger(__name__)
88

99

10+
# safe-guard in case `custom` is not provided
11+
if 'custom' not in dj.config:
12+
dj.config['custom'] = {}
13+
14+
1015
def get_schema_name(name):
1116
try:
1217
return dj.config['custom']['{}.database'.format(name)]

pipeline/fixes/fix_0002_delay_events.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def fix_session(session_key):
5353
if len(filelist) != len(files):
5454
log.warning("behavior files missing in {} ({}/{}). skipping".format(
5555
session_key, len(filelist), len(files)))
56-
return
56+
return False
5757

5858
log.info('filelist: {}'.format(filelist))
5959

@@ -126,7 +126,7 @@ def fix_session(session_key):
126126
# all files were internally invalid or size < 100k
127127
if not trials:
128128
log.warning('skipping ., no valid files')
129-
return
129+
return False
130130

131131
key = session_key
132132
skey = (experiment.Session & key).fetch1()
@@ -538,6 +538,8 @@ def fix_session(session_key):
538538
rows['corrected_trial_event'], ignore_extra_fields=True,
539539
allow_direct_insert=True)
540540

541+
return True
542+
541543

542544
def verify_session(s):
543545
log.info('verifying_session {}'.format(s))
@@ -559,23 +561,23 @@ def note_prob(s, e, msg):
559561

560562
if newstate == 'presample':
561563
if state and state not in {'presample', 'trialend'}:
562-
note_prob(s, e)
564+
note_prob(s, e, 'trialend !-> presample')
563565
nerr += 1
564566
if newstate == 'sample':
565567
if state and state not in {'presample', 'sample'}:
566-
note_prob(s, e)
568+
note_prob(s, e, 'presaple !-> sample')
567569
nerr += 1
568570
if newstate == 'delay':
569571
if state and state not in {'sample', 'delay'}:
570-
note_prob(s, e)
572+
note_prob(s, e, 'sample !-> delay')
571573
nerr += 1
572574
if newstate == 'go':
573575
if state and state not in {'delay', 'go'}:
574-
note_prob(s, e)
576+
note_prob(s, e, 'delay !-> go')
575577
nerr += 1
576578
if newstate == 'trialend':
577579
if state and state not in {'go', 'trialend'}:
578-
note_prob(s, e)
580+
note_prob(s, e, 'go !-> trialend')
579581
nerr += 1
580582

581583
eid, state = neweid, newstate
@@ -585,6 +587,7 @@ def note_prob(s, e, msg):
585587
else:
586588
log.warning('session {} had {} verification errors.'.format(s, nerr))
587589

590+
588591
def fix_0002_delay_events():
589592
with dj.conn().transaction:
590593

@@ -596,8 +599,10 @@ def fix_0002_delay_events():
596599
q = (experiment.Session & behavior_ingest.BehaviorIngest)
597600

598601
for s in q.fetch('KEY'):
599-
fix_session(s)
600-
verify_session(s)
602+
if fix_session(s):
603+
verify_session(s)
604+
else:
605+
log.warning('session {} verify skipped - not fixed'.format(s))
601606

602607

603608
if __name__ == '__main__':

pipeline/plot/unit_characteristic_plot.py

Lines changed: 73 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,25 @@
1010
from pipeline import experiment, ephys, psth
1111

1212
from pipeline.plot.util import (_plot_with_sem, _extract_one_stim_dur, _get_units_hemisphere,
13-
_plot_stacked_psth_diff, _plot_avg_psth,
14-
jointplot_w_hue)
13+
_get_trial_event_times, _get_clustering_method,
14+
_plot_stacked_psth_diff, _plot_avg_psth, jointplot_w_hue)
1515

1616
m_scale = 1200
1717
_plt_xmin = -3
1818
_plt_xmax = 2
1919

2020

21-
def plot_clustering_quality(probe_insertion, axs=None):
21+
def plot_clustering_quality(probe_insertion, clustering_method=None, axs=None):
2222
probe_insertion = probe_insertion.proj()
23-
amp, snr, spk_rate, isi_violation = (ephys.Unit * ephys.UnitStat
24-
* ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch(
23+
24+
if clustering_method is None:
25+
try:
26+
clustering_method = _get_clustering_method(probe_insertion)
27+
except ValueError as e:
28+
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')
29+
30+
amp, snr, spk_rate, isi_violation = (ephys.Unit * ephys.UnitStat * ephys.ProbeInsertion.InsertionLocation
31+
& probe_insertion & {'clustering_method': clustering_method}).fetch(
2532
'unit_amp', 'unit_snr', 'avg_firing_rate', 'isi_violation')
2633

2734
metrics = {'amp': amp,
@@ -52,11 +59,18 @@ def plot_clustering_quality(probe_insertion, axs=None):
5259
return fig
5360

5461

55-
def plot_unit_characteristic(probe_insertion, axs=None):
62+
def plot_unit_characteristic(probe_insertion, clustering_method=None, axs=None):
5663
probe_insertion = probe_insertion.proj()
64+
65+
if clustering_method is None:
66+
try:
67+
clustering_method = _get_clustering_method(probe_insertion)
68+
except ValueError as e:
69+
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')
70+
5771
amp, snr, spk_rate, x, y, insertion_depth = (
5872
ephys.Unit * ephys.ProbeInsertion.InsertionLocation * ephys.UnitStat
59-
& probe_insertion & 'unit_quality != "all"').fetch(
73+
& probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"').fetch(
6074
'unit_amp', 'unit_snr', 'avg_firing_rate', 'unit_posx', 'unit_posy', 'dv_location')
6175

6276
insertion_depth = np.where(np.isnan(insertion_depth), 0, insertion_depth)
@@ -102,12 +116,20 @@ def plot_unit_characteristic(probe_insertion, axs=None):
102116
return fig
103117

104118

105-
def plot_unit_selectivity(probe_insertion, axs=None):
119+
def plot_unit_selectivity(probe_insertion, clustering_method=None, axs=None):
106120
probe_insertion = probe_insertion.proj()
121+
122+
if clustering_method is None:
123+
try:
124+
clustering_method = _get_clustering_method(probe_insertion)
125+
except ValueError as e:
126+
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')
127+
107128
attr_names = ['unit', 'period', 'period_selectivity', 'contra_firing_rate',
108-
'ipsi_firing_rate', 'unit_posx', 'unit_posy', 'dv_location']
129+
'ipsi_firing_rate', 'unit_posx', 'unit_posy', 'dv_location']
109130
selective_units = (psth.PeriodSelectivity * ephys.Unit * ephys.ProbeInsertion.InsertionLocation
110-
* experiment.Period & probe_insertion & 'period_selectivity != "non-selective"').fetch(*attr_names)
131+
* experiment.Period & probe_insertion & {'clustering_method': clustering_method}
132+
& 'period_selectivity != "non-selective"').fetch(*attr_names)
111133
selective_units = pd.DataFrame(selective_units).T
112134
selective_units.columns = attr_names
113135
selective_units.period_selectivity.astype('category')
@@ -162,10 +184,16 @@ def plot_unit_selectivity(probe_insertion, axs=None):
162184
return fig
163185

164186

165-
def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
187+
def plot_unit_bilateral_photostim_effect(probe_insertion, clustering_method=None, axs=None):
166188
probe_insertion = probe_insertion.proj()
189+
190+
if clustering_method is None:
191+
try:
192+
clustering_method = _get_clustering_method(probe_insertion)
193+
except ValueError as e:
194+
raise ValueError(str(e) + '\nPlease specify one with the kwarg "clustering_method"')
195+
167196
dv_loc = (ephys.ProbeInsertion.InsertionLocation & probe_insertion).fetch1('dv_location')
168-
cue_onset = (experiment.Period & 'period = "delay"').fetch1('period_start')
169197

170198
no_stim_cond = (psth.TrialCondition
171199
& {'trial_condition_name':
@@ -181,24 +209,32 @@ def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
181209
& probe_insertion).fetch('duration'))
182210
stim_dur = _extract_one_stim_dur(stim_durs)
183211

184-
units = ephys.Unit & probe_insertion & 'unit_quality != "all"'
212+
units = ephys.Unit & probe_insertion & {'clustering_method': clustering_method} & 'unit_quality != "all"'
213+
214+
metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change'])
185215

186-
metrics = pd.DataFrame(columns=['unit', 'x', 'y', 'frate_change']) # TODO: account for dv_location
216+
_, cue_onset = _get_trial_event_times(['delay'], units, 'all_noearlylick_both_alm_nostim')
217+
cue_onset = cue_onset[0]
187218

188219
# XXX: could be done with 1x fetch+join
189220
for u_idx, unit in enumerate(units.fetch('KEY', order_by='unit')):
190221

191222
x, y = (ephys.Unit & unit).fetch1('unit_posx', 'unit_posy')
192223

193-
nostim_psth, nostim_edge = (
194-
psth.UnitPsth & {**unit, **no_stim_cond}).fetch1('unit_psth')
224+
# obtain unit psth per trial, for all nostim and bistim trials
225+
nostim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(no_stim_cond['trial_condition_name'])
226+
bistim_trials = ephys.Unit.TrialSpikes & unit & psth.TrialCondition.get_trials(bi_stim_cond['trial_condition_name'])
195227

196-
bistim_psth, bistim_edge = (
197-
psth.UnitPsth & {**unit, **bi_stim_cond}).fetch1('unit_psth')
228+
nostim_psths, nostim_edge = psth.compute_unit_psth(unit, nostim_trials.fetch('KEY'), per_trial=True)
229+
bistim_psths, bistim_edge = psth.compute_unit_psth(unit, bistim_trials.fetch('KEY'), per_trial=True)
198230

199231
# compute the firing rate difference between contra vs. ipsi within the stimulation duration
200-
ctrl_frate = nostim_psth[np.logical_and(nostim_edge[1:] >= cue_onset, nostim_edge[1:] <= cue_onset + stim_dur)]
201-
stim_frate = bistim_psth[np.logical_and(bistim_edge[1:] >= cue_onset, bistim_edge[1:] <= cue_onset + stim_dur)]
232+
ctrl_frate = np.array([nostim_psth[np.logical_and(nostim_edge >= cue_onset,
233+
nostim_edge <= cue_onset + stim_dur)].mean()
234+
for nostim_psth in nostim_psths])
235+
stim_frate = np.array([bistim_psth[np.logical_and(bistim_edge >= cue_onset,
236+
bistim_edge <= cue_onset + stim_dur)].mean()
237+
for bistim_psth in bistim_psths])
202238

203239
frate_change = (stim_frate.mean() - ctrl_frate.mean()) / ctrl_frate.mean()
204240
frate_change = abs(frate_change) if frate_change < 0 else 0.0001
@@ -230,9 +266,8 @@ def plot_unit_bilateral_photostim_effect(probe_insertion, axs=None):
230266
def plot_stacked_contra_ipsi_psth(units, axs=None):
231267
units = units.proj()
232268

233-
period_starts = (experiment.Period
234-
& 'period in ("sample", "delay", "response")').fetch(
235-
'period_start')
269+
# get event start times: sample, delay, response
270+
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')
236271

237272
hemi = _get_units_hemisphere(units)
238273

@@ -285,9 +320,8 @@ def plot_stacked_contra_ipsi_psth(units, axs=None):
285320
def plot_avg_contra_ipsi_psth(units, axs=None):
286321
units = units.proj()
287322

288-
period_starts = (experiment.Period
289-
& 'period in ("sample", "delay", "response")').fetch(
290-
'period_start')
323+
# get event start times: sample, delay, response
324+
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')
291325

292326
hemi = _get_units_hemisphere(units)
293327

@@ -349,10 +383,6 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
349383

350384
hemi = _get_units_hemisphere(units)
351385

352-
period_starts = (experiment.Period
353-
& 'period in ("sample", "delay", "response")').fetch(
354-
'period_start')
355-
356386
psth_s_l = (psth.UnitPsth * psth.TrialCondition & units
357387
& {'trial_condition_name':
358388
'all_noearlylick_both_alm_stim_left'}).fetch('unit_psth')
@@ -369,6 +399,9 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
369399
& {'trial_condition_name':
370400
'all_noearlylick_both_alm_nostim_right'}).fetch('unit_psth')
371401

402+
# get event start times: sample, delay, response
403+
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')
404+
372405
# get photostim duration
373406
stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
374407
* psth.TrialCondition().get_trials('all_noearlylick_both_alm_stim')
@@ -402,9 +435,8 @@ def plot_psth_bilateral_photostim_effect(units, axs=None):
402435
ax.set_ylim((0, ymax))
403436

404437
# add shaded bar for photostim
405-
delay = (experiment.Period # TODO: use from period_starts
406-
& 'period = "delay"').fetch1('period_start')
407-
axs[1].axvspan(delay, delay + stim_dur, alpha=0.3, color='royalblue')
438+
stim_time = period_starts[np.where(period_names == 'delay')[0][0]]
439+
axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')
408440

409441
return fig
410442

@@ -423,10 +455,6 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
423455

424456
hemi = _get_units_hemisphere(units)
425457

426-
period_starts = (experiment.Period
427-
& 'period in ("sample", "delay", "response")').fetch(
428-
'period_start')
429-
430458
# no photostim:
431459
psth_n_l = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_left'])[0]
432460
psth_n_r = psth.TrialCondition.get_cond_name_from_keywords(['_nostim', '_right'])[0]
@@ -444,6 +472,9 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
444472
psth_s_r = (psth.UnitPsth * psth.TrialCondition & units
445473
& {'trial_condition_name': psth_s_r} & 'unit_psth is not NULL').fetch('unit_psth')
446474

475+
# get event start times: sample, delay, response
476+
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick_hit')
477+
447478
# get photostim duration
448479
stim_trial_cond_name = psth.TrialCondition.get_cond_name_from_keywords(condition_name_kw + ['_stim'])[0]
449480
stim_durs = np.unique((experiment.Photostim & experiment.PhotostimEvent
@@ -474,7 +505,7 @@ def plot_psth_photostim_effect(units, condition_name_kw=['both_alm'], axs=None):
474505
ax.set_xlim([_plt_xmin, _plt_xmax])
475506

476507
# add shaded bar for photostim
477-
stim_time = (experiment.Period & 'period = "delay"').fetch1('period_start')
508+
stim_time = period_starts[np.where(period_names == 'delay')[0][0]]
478509
axs[1].axvspan(stim_time, stim_time + stim_dur, alpha=0.3, color='royalblue')
479510

480511
return fig
@@ -484,7 +515,8 @@ def plot_coding_direction(units, time_period=None, axs=None):
484515
_, proj_contra_trial, proj_ipsi_trial, time_stamps, _ = psth.compute_CD_projected_psth(
485516
units.fetch('KEY'), time_period=time_period)
486517

487-
period_starts = (experiment.Period & 'period in ("sample", "delay", "response")').fetch('period_start')
518+
# get event start times: sample, delay, response
519+
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], units, 'good_noearlylick')
488520

489521
fig = None
490522
if axs is None:
@@ -515,7 +547,8 @@ def plot_paired_coding_direction(unit_g1, unit_g2, labels=None, time_period=None
515547
_, proj_contra_trial_g2, proj_ipsi_trial_g2, time_stamps, unit_g2_hemi = psth.compute_CD_projected_psth(
516548
unit_g2.fetch('KEY'), time_period=time_period)
517549

518-
period_starts = (experiment.Period & 'period in ("sample", "delay", "response")').fetch('period_start')
550+
# get event start times: sample, delay, response
551+
period_names, period_starts = _get_trial_event_times(['sample', 'delay', 'go'], unit_g1, 'good_noearlylick')
519552

520553
if labels:
521554
assert len(labels) == 2

0 commit comments

Comments
 (0)