Skip to content

Commit 8953084

Browse files
committed
added code for all the plots in plot_photometry. also added the preprocessing needed for the new plots within the convert_photometry.py file. There is still some things that need to be rewritten given the new nwb converted data. Basically anything that calls the previous sampledata dataframe and visinds
1 parent 9702583 commit 8953084

File tree

3 files changed

+227
-241
lines changed

3 files changed

+227
-241
lines changed

src/jdb_to_nwb/convert_photometry.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from scipy.sparse import diags, eye, csc_matrix
1111
from scipy.sparse.linalg import spsolve
1212
from sklearn.linear_model import Lasso
13+
from scipy.stats import linregress
14+
15+
from plot_photometry import .
1316

1417
# Some of these imports are unused for now but will be used for photometry metadata
1518
from ndx_fiber_photometry import (
@@ -430,22 +433,105 @@ def process_ppd_photometry(nwbfile: NWBFile, ppd_file_path):
430433
relative_raw_signal = raw_green / raw_405
431434

432435
sampling_rate = ppd_data['sampling_rate']
436+
time_seconds = ppd_data['time']/1000
433437
visits = ppd_data['pulse_inds_1'][1:]
434438

439+
pulse_times_in_mins = [time / 60000 for time in visits]
440+
pulse_times_in_seconds = np.array(pulse_times_in_mins) * 60
441+
442+
xvals = np.arange(0,len(raw_green))/sampling_rate/60
443+
444+
plot_raw_ratiometric_and_565_photometry_signals(visits, xvals, pulse_times_in_mins, raw_green, raw_red, raw_405, relative_raw_signal, sampling_rate)
445+
446+
# Calculate the correlation between the signals. This I am not too sure if we want to keep or not. Best to discuss with Josh?
447+
slope_405x470, intercep_405x470, r_value_405x470 = linregress(x=raw_405, y=raw_green)
448+
plot_405_470_correlation(raw_405, raw_green,slope_405x470, intercep_405x470, r_value_405x470)
449+
450+
slope_405x565, intercep_405x565, r_value_405x565 = linregress(x=raw_405, y=raw_red)
451+
plot_405_565_correlation(raw_405, raw_red, slope_405x565, intercep_405x565, r_value_405x565)
452+
453+
slope_470x565, intercep_470x565, r_value_470x565 = linregress(x=raw_green, y=raw_red)
454+
plot_470_565_correlation(raw_green, raw_red, slope_470x565, intercep_470x565, r_value_470x565)
455+
435456
# low pass at 10Hz to remove high frequency noise
436457
print('Filtering data...')
437458
b,a = butter(2, 10, btype='low', fs=sampling_rate)
438459
green_denoised = filtfilt(b,a, raw_green)
439460
red_denoised = filtfilt(b,a, raw_red)
440461
ratio_denoised = filtfilt(b,a, relative_raw_signal)
441462
denoised_405 = filtfilt(b,a, raw_405)
463+
442464
# high pass at 0.001Hz which removes the drift due to bleaching, but will also remove any physiological variation in the signal on very slow timescales.
443465
b,a = butter(2, 0.001, btype='high', fs=sampling_rate)
444466
green_highpass = filtfilt(b,a, green_denoised, padtype='even')
445467
red_highpass = filtfilt(b,a, red_denoised, padtype='even')
446468
ratio_highpass = filtfilt(b,a, ratio_denoised, padtype='even')
447469
highpass_405 = filtfilt(b,a, denoised_405, padtype='even')
448470

471+
# Plot the filtered signals of interest against each other
472+
slope_filtered, intercept_filtered, r_value_filtered = linregress(x=red_highpass, y=ratio_highpass)
473+
plot_ratio_565_correlation(ratio_highpass, red_highpass, slope_filtered, intercept_filtered, r_value_filtered)
474+
475+
plot_interactive_filtered_signals(time_seconds, ratio_highpass, green_highpass, highpass_405, red_highpass)
476+
477+
478+
time_window_in_sec = input('Enter the time window in seconds for the photometry signal alignment: ') # >:)
479+
480+
samples_window = time_window_in_sec * sampling_rate
481+
total_rewarded_trials = (rwd == 1).sum()
482+
print("Total rewarded trials: "+str(total_rewarded_trials))
483+
484+
total_omitted_trials = (rwd == 0).sum()
485+
print("Total omitted trials: "+str(total_omitted_trials))
486+
487+
# Initialize lists for aligned traces
488+
aligned_ratio = []
489+
aligned_green = []
490+
aligned_405 = []
491+
aligned_red = []
492+
493+
#################################################################################
494+
# QUESTION FOR STEPH:
495+
# I need this code to run some of the plotting from this point forward.
496+
# Is there a way to get this information from convert_behvaior?
497+
# For now, I'll keep everything the way it is using "visinds" and "port_entry_indeces"
498+
visinds = sampledata.loc[sampledata.port.notnull()].index.values
499+
port_entry_indices = visinds
500+
#################################################################################
501+
502+
# Align windows
503+
for idx in visits:
504+
if idx - samples_window >= 0 and idx + samples_window < len(time_seconds): # Ensure indices are in bounds
505+
aligned_ratio.append(ratio_highpass[idx - samples_window:idx + samples_window])
506+
aligned_green.append(green_highpass[idx - samples_window:idx + samples_window])
507+
aligned_405.append(highpass_405[idx - samples_window:idx + samples_window])
508+
aligned_red.append(red_highpass[idx - samples_window:idx + samples_window])
509+
510+
# Convert to numpy arrays for easier averaging
511+
aligned_ratio = np.array(aligned_ratio)
512+
aligned_green = np.array(aligned_green)
513+
aligned_405 = np.array(aligned_405)
514+
aligned_red = np.array(aligned_red)
515+
516+
# Calculate means and SEMs
517+
mean_ratio = np.mean(aligned_ratio, axis=0)
518+
sem_ratio = np.std(aligned_ratio, axis=0) / np.sqrt(len(aligned_ratio))
519+
520+
mean_green = np.mean(aligned_green, axis=0)
521+
sem_green = np.std(aligned_green, axis=0) / np.sqrt(len(aligned_green))
522+
523+
mean_405 = np.mean(aligned_405, axis=0)
524+
sem_405 = np.std(aligned_405, axis=0) / np.sqrt(len(aligned_405))
525+
526+
mean_red = np.mean(aligned_red, axis=0)
527+
sem_red = np.std(aligned_red, axis=0) / np.sqrt(len(aligned_red))
528+
529+
# Generate time axis for aligned data
530+
aligned_time = np.linspace(-time_window_in_sec, time_window_in_sec, 2 * samples_window)
531+
532+
533+
plot_signals_aligned_port_entry(port_entry_indices, aligned_time, mean_ratio, sem_ratio, mean_green, sem_green, mean_405, sem_405, mean_red, sem_red, total_rewarded_trials, total_omitted_trials)
534+
449535
# Z-score of each signal to normalize the data
450536
print('Z-scoring data...')
451537
green_zscored = np.divide(np.subtract(green_highpass,green_highpass.mean()),green_highpass.std())
@@ -457,6 +543,110 @@ def process_ppd_photometry(nwbfile: NWBFile, ppd_file_path):
457543
ratio_zscored = np.divide(np.subtract(ratio_highpass,ratio_highpass.mean()),ratio_highpass.std())
458544
print('Done processing photometry data!')
459545

546+
plot_normalized_signals(xvals, pulse_times_in_mins, green_zscored, zscored_405, red_zscored, ratio_zscored)
547+
548+
ratio_traces = []
549+
red_traces = []
550+
time_window = int(time_window_in_sec)
551+
552+
time_window_xvals = np.arange(-time_window*86,time_window*86+1)/86
553+
554+
# Loop through each index in visinds
555+
for i in visinds:
556+
# Extract a time window of data around the current index
557+
start_idx = i - 86 * time_window
558+
end_idx = i + 86 * time_window
559+
ratio_trace = sampledata.loc[start_idx:end_idx, "ratio_z_scored"].values
560+
red_trace = sampledata.loc[start_idx:end_idx, "red_z_scored"].values
561+
562+
# Only use traces of the correct length (matching xvals)
563+
if len(ratio_trace) == len(time_window_xvals):
564+
ratio_traces.append(ratio_trace)
565+
566+
if len(red_trace) == len(time_window_xvals):
567+
red_traces.append(red_trace)
568+
569+
# Calculate mean and SEM
570+
if ratio_traces:
571+
ratio_traces_mean = np.mean(ratio_traces, axis=0)
572+
ratio_traces_sem = np.std(ratio_traces, axis=0) / np.sqrt(len(ratio_traces))
573+
else: # If no traces, create empty arrays
574+
ratio_traces_rwd_mean = np.zeros_like(time_window_xvals)
575+
ratio_traces_rwd_sem = np.zeros_like(time_window_xvals)
576+
577+
# Calculate mean and SEM
578+
if red_traces:
579+
red_traces_mean = np.mean(red_traces, axis=0)
580+
red_traces_sem = np.std(red_traces, axis=0) / np.sqrt(len(red_traces))
581+
else: # If no traces, create empty arrays
582+
red_traces_mean = np.zeros_like(time_window_xvals)
583+
red_traces_sem = np.zeros_like(time_window_xvals)
584+
585+
plot_ratio_and_565_signals_aligned_port_entry(rat, date, time_window, time_window_xvals, ratio_traces_mean, ratio_traces_sem, red_traces_mean, red_traces_sem, visinds) # 'rat' and 'date' are only used for the title of the plot
586+
587+
ratio_rwd_traces = []
588+
ratio_om_traces = []
589+
red_rwd_traces = []
590+
red_om_traces = []
591+
592+
# Loop through each index in visinds
593+
for i in visinds:
594+
# Extract a time window of data around the current index
595+
start_idx = i - 86 * time_window_in_sec
596+
end_idx = i + 86 * time_window_in_sec
597+
ratio_trace = sampledata.loc[start_idx:end_idx, "ratio_z_scored"].values
598+
red_trace = sampledata.loc[start_idx:end_idx, "red_z_scored"].values
599+
600+
# Only use traces of the correct length (matching xvals)
601+
if len(ratio_trace) == len(time_window_xvals):
602+
# Check if the trial is rewarded or omitted
603+
if sampledata.loc[i, 'rwd'] == 1:
604+
ratio_rwd_traces.append(ratio_trace) # Add to rewarded traces
605+
else:
606+
ratio_om_traces.append(ratio_trace) # Add to omitted traces
607+
608+
# Only use traces of the correct length (matching xvals)
609+
if len(red_trace) == len(time_window_xvals):
610+
# Check if the trial is rewarded or omitted
611+
if sampledata.loc[i, 'rwd'] == 1:
612+
red_rwd_traces.append(red_trace) # Add to rewarded traces
613+
else:
614+
red_om_traces.append(red_trace) # Add to omitted traces
615+
616+
# Calculate mean and SEM for rewarded traces
617+
if ratio_rwd_traces: # If there are any rewarded traces
618+
ratio_rwd_mean = np.mean(ratio_rwd_traces, axis=0)
619+
ratio_rwd_sem = np.std(ratio_rwd_traces, axis=0) / np.sqrt(len(ratio_rwd_traces))
620+
else: # If no rewarded traces, create empty arrays
621+
ratio_rwd_mean = np.zeros_like(time_window_xvals)
622+
ratio_rwd_sem = np.zeros_like(time_window_xvals)
623+
624+
# Calculate mean and SEM for rewarded traces
625+
if red_rwd_traces: # If there are any rewarded traces
626+
red_rwd_mean = np.mean(red_rwd_traces, axis=0)
627+
red_rwd_sem = np.std(red_rwd_traces, axis=0) / np.sqrt(len(red_rwd_traces))
628+
else: # If no rewarded traces, create empty arrays
629+
red_rwd_mean = np.zeros_like(time_window_xvals)
630+
red_rwd_sem = np.zeros_like(time_window_xvals)
631+
632+
# Calculate mean and SEM for omitted traces
633+
if ratio_om_traces: # If there are any omitted traces
634+
ratio_om_mean = np.mean(ratio_om_traces, axis=0)
635+
ratio_om_sem = np.std(ratio_om_traces, axis=0) / np.sqrt(len(ratio_om_traces))
636+
else: # If no omitted traces, create empty arrays
637+
ratio_om_mean = np.zeros_like(time_window_xvals)
638+
ratio_om_sem = np.zeros_like(time_window_xvals)
639+
640+
# Calculate mean and SEM for omitted traces
641+
if red_om_traces: # If there are any omitted traces
642+
red_om_mean = np.mean(red_om_traces, axis=0)
643+
red_om_sem = np.std(red_om_traces, axis=0) / np.sqrt(len(red_om_traces))
644+
else: # If no omitted traces, create empty arrays
645+
red_om_mean = np.zeros_like(time_window_xvals)
646+
red_om_sem = np.zeros_like(time_window_xvals)
647+
648+
plot_ratio_and_565_signals_aligned_port_entry_separated_by_rewarded_or_omitted(rat, date, time_window_in_sec, time_window_xvals, ratio_rwd_mean, ratio_rwd_sem, ratio_om_mean, ratio_om_sem, red_rwd_mean, red_rwd_sem, red_om_mean, red_om_sem, total_rewarded_trials, total_omitted_trials)
649+
460650
# Add actual photometry data to the NWB
461651
print("Adding photometry signals to NWB ...")
462652

src/plotting/plot_behavior.py

Whitespace-only changes.

0 commit comments

Comments
 (0)