Skip to content

Commit 316acd9

Browse files
committed
fixed the code again and added all the processing code needed for each plot into the plot function itself. ignore plot_behavior for now. thats a task for later.
1 parent 8953084 commit 316acd9

File tree

1 file changed

+164
-9
lines changed

1 file changed

+164
-9
lines changed

src/plotting/plot_photometry.py

Lines changed: 164 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from matplotlib.widgets import Slider
44
from scipy.stats import linregress
55

6-
def plot_raw_ratiometric_and_565_photometry_signals(xvals, pulse_times_in_mins, raw_green, raw_red, raw_405, relative_raw_signal, sampling_rate):
6+
def plot_raw_photometry_signals(visits, raw_green, raw_red, raw_405, relative_raw_signal, sampling_rate):
77
"""
88
Plots the raw 470, 405, 565 and ratiometric 470/405 fluorescence signals.
99
"""
10+
xvals = np.arange(0,len(raw_green))/sampling_rate/60
11+
pulse_times_in_mins = [time / 60000 for time in visits]
1012

1113
raw = plt.figure(figsize=(16, 10))
1214
plt.suptitle('Raw & ratiometric 470/405 fluorescence signals', fontsize=16)
@@ -42,6 +44,7 @@ def plot_405_470_correlation(raw_405, raw_green, slope_405x470, intercep_405x470
4244
"""
4345
Plots the correlation between the 405 and 470 signals.
4446
"""
47+
slope_405x470, intercep_405x470, r_value_405x470 = linregress(x=raw_405, y=raw_green)
4548
plt.figure()
4649
plt.scatter(raw_405[::5], raw_green[::5],alpha=0.1, marker='.')
4750
x = np.array(plt.xlim())
@@ -62,6 +65,7 @@ def plot_405_565_correlation(raw_405, raw_red, slope_405x565, intercep_405x565,
6265
"""
6366
Plots the correlation between the 405 and 565 signals.
6467
"""
68+
slope_405x565, intercep_405x565, r_value_405x565 = linregress(x=raw_405, y=raw_red)
6569
plt.figure()
6670
plt.scatter(raw_405[::5], raw_red[::5],alpha=0.1, marker='.')
6771
x = np.array(plt.xlim())
@@ -82,7 +86,7 @@ def plot_470_565_correlation(raw_green, raw_red, slope_470x565, intercep_470x565
8286
"""
8387
Plots the correlation between the 470 and 565 signals.
8488
"""
85-
89+
slope_470x565, intercep_470x565, r_value_470x565 = linregress(x=raw_green, y=raw_red)
8690
plt.figure()
8791
plt.scatter(raw_green[::5], raw_red[::5],alpha=0.1, marker='.')
8892
x = np.array(plt.xlim())
@@ -103,7 +107,7 @@ def plot_ratio_565_correlation(ratio_highpass, red_highpass, slope_filtered, int
103107
"""
104108
Plots the correlation between the 470/405 ratio and the 565 signal.
105109
"""
106-
110+
slope_filtered, intercept_filtered, r_value_filtered = linregress(x=red_highpass, y=ratio_highpass)
107111
plt.figure(figsize=(13, 10))
108112
plt.scatter(red_highpass[::5], ratio_highpass[::5],alpha=0.1, marker='.')
109113
x = np.array(plt.xlim())
@@ -175,10 +179,62 @@ def update(val):
175179
plt.show()
176180
return None
177181

178-
def plot_signals_aligned_port_entry(port_entry_indices, aligned_time, mean_ratio, sem_ratio, mean_green, sem_green, mean_405, sem_405, mean_red, sem_red, total_rewarded_trials, total_omitted_trials):
182+
def plot_signals_aligned_port_entry(sampling_rate, time_seconds, visits, port_entry_indices, aligned_time,ratio_highpass, highpass_405, green_highpass, red_highpass, mean_ratio, sem_ratio, mean_green, sem_green, mean_405, sem_405, mean_red, sem_red, total_rewarded_trials, total_omitted_trials):
179183
"""
180184
"""
181-
185+
time_window_in_sec = 10
186+
187+
samples_window = time_window_in_sec * sampling_rate
188+
total_rewarded_trials = (rwd == 1).sum()
189+
print("Total rewarded trials: "+str(total_rewarded_trials))
190+
191+
total_omitted_trials = (rwd == 0).sum()
192+
print("Total omitted trials: "+str(total_omitted_trials))
193+
194+
# Initialize lists for aligned traces
195+
aligned_ratio = []
196+
aligned_green = []
197+
aligned_405 = []
198+
aligned_red = []
199+
200+
#################################################################################
201+
# QUESTION FOR STEPH:
202+
# I need this code to run some of the plotting from this point forward.
203+
# Is there a way to get this information from convert_behvaior?
204+
# For now, I'll keep everything the way it is using "visinds" and "port_entry_indeces"
205+
visinds = sampledata.loc[sampledata.port.notnull()].index.values
206+
port_entry_indices = visinds
207+
#################################################################################
208+
209+
# Align windows
210+
for idx in visits:
211+
if idx - samples_window >= 0 and idx + samples_window < len(time_seconds): # Ensure indices are in bounds
212+
aligned_ratio.append(ratio_highpass[idx - samples_window:idx + samples_window])
213+
aligned_green.append(green_highpass[idx - samples_window:idx + samples_window])
214+
aligned_405.append(highpass_405[idx - samples_window:idx + samples_window])
215+
aligned_red.append(red_highpass[idx - samples_window:idx + samples_window])
216+
217+
# Convert to numpy arrays for easier averaging
218+
aligned_ratio = np.array(aligned_ratio)
219+
aligned_green = np.array(aligned_green)
220+
aligned_405 = np.array(aligned_405)
221+
aligned_red = np.array(aligned_red)
222+
223+
# Calculate means and SEMs
224+
mean_ratio = np.mean(aligned_ratio, axis=0)
225+
sem_ratio = np.std(aligned_ratio, axis=0) / np.sqrt(len(aligned_ratio))
226+
227+
mean_green = np.mean(aligned_green, axis=0)
228+
sem_green = np.std(aligned_green, axis=0) / np.sqrt(len(aligned_green))
229+
230+
mean_405 = np.mean(aligned_405, axis=0)
231+
sem_405 = np.std(aligned_405, axis=0) / np.sqrt(len(aligned_405))
232+
233+
mean_red = np.mean(aligned_red, axis=0)
234+
sem_red = np.std(aligned_red, axis=0) / np.sqrt(len(aligned_red))
235+
236+
# Generate time axis for aligned data
237+
aligned_time = np.linspace(-time_window_in_sec, time_window_in_sec, 2 * samples_window)
182238
# Plot averaged traces
183239
plt.figure(figsize=(16, 10))
184240
plt.title("Average Traces Aligned to Port Entry ("+str(len(port_entry_indices))+" total trials)")
@@ -208,10 +264,10 @@ def plot_signals_aligned_port_entry(port_entry_indices, aligned_time, mean_ratio
208264

209265
return None
210266

211-
def plot_normalized_signals(pulse_times_in_mins, green_zscored, zscored_405, red_zscored, ratio_zscored, xvals):
267+
def plot_normalized_signals(sampling_rate, pulse_times_in_mins, green_zscored, zscored_405, red_zscored, ratio_zscored):
212268
"""
213269
"""
214-
270+
xvals = np.arange(0,len(green_zscored))/sampling_rate/60
215271
zscrd = plt.figure(figsize=(16, 10))
216272
plt.suptitle('Z-scored signals calculated after preprocessing raw singals \n by applying a high-pass filter at 0.001 Hz and a low-pass filter at 10 Hz')
217273

@@ -245,7 +301,42 @@ def plot_normalized_signals(pulse_times_in_mins, green_zscored, zscored_405, red
245301
def plot_ratio_and_565_signals_aligned_port_entry(rat, date, time_window, time_window_xvals, ratio_traces_mean, ratio_traces_sem, red_traces_mean, red_traces_sem, visinds):
246302
"""
247303
"""
248-
304+
ratio_traces = []
305+
red_traces = []
306+
time_window_in_sec = 10
307+
308+
time_window_xvals = np.arange(-time_window_in_sec*86,time_window_in_sec*86+1)/86
309+
310+
# Loop through each index in visinds
311+
for i in visinds:
312+
# Extract a time window of data around the current index
313+
start_idx = i - 86 * time_window_in_sec
314+
end_idx = i + 86 * time_window_in_sec
315+
ratio_trace = sampledata.loc[start_idx:end_idx, "ratio_z_scored"].values
316+
red_trace = sampledata.loc[start_idx:end_idx, "red_z_scored"].values
317+
318+
# Only use traces of the correct length (matching xvals)
319+
if len(ratio_trace) == len(time_window_xvals):
320+
ratio_traces.append(ratio_trace)
321+
322+
if len(red_trace) == len(time_window_xvals):
323+
red_traces.append(red_trace)
324+
325+
# Calculate mean and SEM
326+
if ratio_traces:
327+
ratio_traces_mean = np.mean(ratio_traces, axis=0)
328+
ratio_traces_sem = np.std(ratio_traces, axis=0) / np.sqrt(len(ratio_traces))
329+
else: # If no traces, create empty arrays
330+
ratio_traces_rwd_mean = np.zeros_like(time_window_xvals)
331+
ratio_traces_rwd_sem = np.zeros_like(time_window_xvals)
332+
333+
# Calculate mean and SEM
334+
if red_traces:
335+
red_traces_mean = np.mean(red_traces, axis=0)
336+
red_traces_sem = np.std(red_traces, axis=0) / np.sqrt(len(red_traces))
337+
else: # If no traces, create empty arrays
338+
red_traces_mean = np.zeros_like(time_window_xvals)
339+
red_traces_sem = np.zeros_like(time_window_xvals)
249340
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(19, 10))
250341

251342
plt.suptitle(f'Z-scored ratiometric 470/405 & 565 fluorescence signals {rat} {date} {time_window}sec', fontsize=16)
@@ -269,9 +360,73 @@ def plot_ratio_and_565_signals_aligned_port_entry(rat, date, time_window, time_w
269360

270361
return None
271362

272-
def plot_ratio_and_565_signals_aligned_port_entry_separated_by_rewarded_or_omitted(rat, date, time_window_in_sec, xvals, ratio_rwd_mean, ratio_rwd_sem, ratio_om_mean, ratio_om_sem, red_rwd_mean, red_rwd_sem, red_om_mean, red_om_sem, total_rewarded_trials, total_omitted_trials):
363+
def plot_ratio_and_565_signals_aligned_port_entry_separated_by_rewarded_or_omitted(rat, date, visinds, time_window_in_sec, xvals, ratio_rwd_mean, ratio_rwd_sem, ratio_om_mean, ratio_om_sem, red_rwd_mean, red_rwd_sem, red_om_mean, red_om_sem, total_rewarded_trials, total_omitted_trials):
273364
"""
274365
"""
366+
ratio_rwd_traces = []
367+
ratio_om_traces = []
368+
red_rwd_traces = []
369+
red_om_traces = []
370+
371+
time_window_in_sec = 10
372+
373+
time_window_xvals = np.arange(-time_window_in_sec*86,time_window_in_sec*86+1)/86
374+
375+
# Loop through each index in visinds
376+
for i in visinds:
377+
# Extract a time window of data around the current index
378+
start_idx = i - 86 * time_window_in_sec
379+
end_idx = i + 86 * time_window_in_sec
380+
ratio_trace = sampledata.loc[start_idx:end_idx, "ratio_z_scored"].values
381+
red_trace = sampledata.loc[start_idx:end_idx, "red_z_scored"].values
382+
383+
# Only use traces of the correct length (matching xvals)
384+
if len(ratio_trace) == len(time_window_xvals):
385+
# Check if the trial is rewarded or omitted
386+
if sampledata.loc[i, 'rwd'] == 1:
387+
ratio_rwd_traces.append(ratio_trace) # Add to rewarded traces
388+
else:
389+
ratio_om_traces.append(ratio_trace) # Add to omitted traces
390+
391+
# Only use traces of the correct length (matching xvals)
392+
if len(red_trace) == len(time_window_xvals):
393+
# Check if the trial is rewarded or omitted
394+
if sampledata.loc[i, 'rwd'] == 1:
395+
red_rwd_traces.append(red_trace) # Add to rewarded traces
396+
else:
397+
red_om_traces.append(red_trace) # Add to omitted traces
398+
399+
# Calculate mean and SEM for rewarded traces
400+
if ratio_rwd_traces: # If there are any rewarded traces
401+
ratio_rwd_mean = np.mean(ratio_rwd_traces, axis=0)
402+
ratio_rwd_sem = np.std(ratio_rwd_traces, axis=0) / np.sqrt(len(ratio_rwd_traces))
403+
else: # If no rewarded traces, create empty arrays
404+
ratio_rwd_mean = np.zeros_like(time_window_xvals)
405+
ratio_rwd_sem = np.zeros_like(time_window_xvals)
406+
407+
# Calculate mean and SEM for rewarded traces
408+
if red_rwd_traces: # If there are any rewarded traces
409+
red_rwd_mean = np.mean(red_rwd_traces, axis=0)
410+
red_rwd_sem = np.std(red_rwd_traces, axis=0) / np.sqrt(len(red_rwd_traces))
411+
else: # If no rewarded traces, create empty arrays
412+
red_rwd_mean = np.zeros_like(time_window_xvals)
413+
red_rwd_sem = np.zeros_like(time_window_xvals)
414+
415+
# Calculate mean and SEM for omitted traces
416+
if ratio_om_traces: # If there are any omitted traces
417+
ratio_om_mean = np.mean(ratio_om_traces, axis=0)
418+
ratio_om_sem = np.std(ratio_om_traces, axis=0) / np.sqrt(len(ratio_om_traces))
419+
else: # If no omitted traces, create empty arrays
420+
ratio_om_mean = np.zeros_like(time_window_xvals)
421+
ratio_om_sem = np.zeros_like(time_window_xvals)
422+
423+
# Calculate mean and SEM for omitted traces
424+
if red_om_traces: # If there are any omitted traces
425+
red_om_mean = np.mean(red_om_traces, axis=0)
426+
red_om_sem = np.std(red_om_traces, axis=0) / np.sqrt(len(red_om_traces))
427+
else: # If no omitted traces, create empty arrays
428+
red_om_mean = np.zeros_like(time_window_xvals)
429+
red_om_sem = np.zeros_like(time_window_xvals)
275430

276431
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(19, 10))
277432

0 commit comments

Comments
 (0)