Skip to content

Commit 1a4f841

Browse files
Copilotyangsunhwang
andcommitted
Add plot_probability_matching function for trial data visualization
Co-authored-by: yangsunhwang <[email protected]>
1 parent b3f5956 commit 1a4f841

File tree

2 files changed

+149
-1
lines changed

2 files changed

+149
-1
lines changed

src/jdb_to_nwb/convert_behavior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from hdmf.common.table import DynamicTable, VectorData
99
from ndx_franklab_novela import AssociatedFiles
1010
from .timestamps_alignment import trim_sync_pulses, handle_timestamps_reset
11-
from .plotting.plot_behavior import plot_maze_configurations, plot_trial_time_histogram
11+
from .plotting.plot_behavior import plot_maze_configurations, plot_trial_time_histogram, plot_probability_matching
1212

1313

1414
def load_maze_configurations(maze_configuration_file_path: Path):

src/jdb_to_nwb/plotting/plot_behavior.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import numpy as np
23
import matplotlib.pyplot as plt
34
from hexmaze import plot_hex_maze
45

@@ -59,3 +60,150 @@ def plot_maze_configurations(block_data, fig_dir=None):
5960
fig2.savefig(os.path.join(fig_dir, "maze_configurations_with_optimal_paths.png"), dpi=300, bbox_inches="tight")
6061
plt.close(fig1)
6162
plt.close(fig2)
63+
64+
65+
def plot_probability_matching(trial_data, block_data, fig_dir=None):
66+
"""
67+
Plot within-session probability matching behavior.
68+
69+
Shows:
70+
- Rolling average of port choice frequencies over trials
71+
- Reward delivery events at each port (as bars)
72+
- Block transitions and reward probabilities
73+
74+
Parameters
75+
----------
76+
trial_data : list of dict
77+
List of trial dictionaries from parse_arduino_text.
78+
Each dict should have keys: 'end_port', 'reward', 'block', 'trial_within_session'
79+
block_data : list of dict
80+
List of block dictionaries from parse_arduino_text.
81+
Each dict should have keys: 'block', 'pA', 'pB', 'pC'
82+
fig_dir : str or None
83+
Directory to save the figure. If None, figure is not saved.
84+
85+
Returns
86+
-------
87+
fig : matplotlib.figure.Figure
88+
The created figure object
89+
"""
90+
# Create mapping from port letters to indices
91+
port_map = {'A': 0, 'B': 1, 'C': 2}
92+
93+
# Extract data from trial_data list
94+
num_trials = len(trial_data)
95+
x1 = np.linspace(0, num_trials, num_trials)
96+
97+
# Create arrays for reward events at each port
98+
yA1 = np.zeros(num_trials)
99+
yB1 = np.zeros(num_trials)
100+
yC1 = np.zeros(num_trials)
101+
102+
# Create arrays for port choice indicators
103+
choose_A = np.zeros(num_trials)
104+
choose_B = np.zeros(num_trials)
105+
choose_C = np.zeros(num_trials)
106+
107+
# Process each trial
108+
for i, trial in enumerate(trial_data):
109+
end_port = trial['end_port']
110+
reward = trial['reward']
111+
112+
# Mark reward events (offset by +2 for visibility)
113+
if reward == 1:
114+
if end_port == 'A':
115+
yA1[i] = reward + 2
116+
elif end_port == 'B':
117+
yB1[i] = reward + 2
118+
elif end_port == 'C':
119+
yC1[i] = reward + 2
120+
121+
# Mark port choices
122+
if end_port == 'A':
123+
choose_A[i] = 1
124+
elif end_port == 'B':
125+
choose_B[i] = 1
126+
elif end_port == 'C':
127+
choose_C[i] = 1
128+
129+
# Calculate rolling window averages for port choice frequency
130+
window = 10
131+
132+
# Use pandas-like rolling calculation with numpy
133+
def rolling_mean(arr, window):
134+
"""Calculate rolling mean with min_periods=1"""
135+
result = np.zeros(len(arr))
136+
for i in range(len(arr)):
137+
start_idx = max(0, i - window + 1)
138+
result[i] = np.mean(arr[start_idx:i+1])
139+
return result
140+
141+
freq_A = rolling_mean(choose_A, window)
142+
freq_B = rolling_mean(choose_B, window)
143+
freq_C = rolling_mean(choose_C, window)
144+
145+
# Create figure
146+
fig = plt.figure(figsize=(18, 12))
147+
plt.suptitle('Within-Session Probability Matching', fontweight='bold', fontsize=26)
148+
149+
# Main plot for port visit frequencies
150+
ax4 = plt.subplot2grid((18, 1), (3, 0), colspan=1, rowspan=15)
151+
ax4.plot(x1, freq_A, label='A', alpha=0.8, color='blue')
152+
ax4.plot(x1, freq_B, label='B', alpha=0.8, color='orange')
153+
ax4.plot(x1, freq_C, label='C', alpha=0.8, color='green')
154+
ax4.set_ylabel('Port Visits/trial', fontsize=20, fontweight='bold')
155+
ax4.set_ylim(0, 0.7)
156+
ax4.legend(bbox_to_anchor=(0.9, 1.4), loc=2, borderaxespad=0.)
157+
158+
# Add block transition lines and probability labels
159+
for i, block in enumerate(block_data):
160+
block_num = block['block']
161+
# Get trials for this block
162+
block_trials = [t for t in trial_data if t['block'] == block_num]
163+
164+
if block_trials:
165+
# Get the last trial index of this block
166+
last_trial_idx = block_trials[-1]['trial_within_session'] - 1
167+
first_trial_idx = block_trials[0]['trial_within_session'] - 1
168+
169+
# Calculate midpoint for text placement
170+
xmid = int(np.mean([first_trial_idx, last_trial_idx]))
171+
172+
# Draw vertical line at block boundary (after last trial)
173+
if block_num < len(block_data): # Don't draw line after last block
174+
xstart = last_trial_idx + 1
175+
if i == 0:
176+
ax4.axvline(x=xstart, color='r', linestyle='--', label='Block Change')
177+
else:
178+
ax4.axvline(x=xstart, color='r', linestyle='--')
179+
180+
# Add probability text labels
181+
plt.text(xmid - 12, 8, str(int(block['pA'])) + ': ',
182+
fontsize='xx-large', fontweight='bold', color='b', transform=ax4.transData)
183+
plt.text(xmid, 8, str(int(block['pB'])) + ': ',
184+
fontsize='xx-large', fontweight='bold', color='orange', transform=ax4.transData)
185+
plt.text(xmid + 12, 8, str(int(block['pC'])),
186+
fontsize='xx-large', fontweight='bold', color='g', transform=ax4.transData)
187+
188+
ax4.legend()
189+
190+
# Top subplot: Rewards at port A (blue bars)
191+
ax1 = plt.subplot2grid((18, 1), (0, 0), colspan=1, rowspan=1, sharex=ax4)
192+
ax1.bar(x1, yA1, color='blue')
193+
ax1.axis('off')
194+
195+
# Middle subplot: Rewards at port B (orange bars)
196+
ax2 = plt.subplot2grid((18, 1), (1, 0), colspan=1, rowspan=1, sharex=ax4)
197+
ax2.bar(x1, yB1, color='orange')
198+
ax2.axis('off')
199+
200+
# Bottom subplot: Rewards at port C (green bars)
201+
ax3 = plt.subplot2grid((18, 1), (2, 0), colspan=1, rowspan=1, sharex=ax4)
202+
ax3.bar(x1, yC1, color='g')
203+
ax3.axis('off')
204+
205+
if fig_dir:
206+
plt.savefig(os.path.join(fig_dir, "probability_matching.png"), dpi=300, bbox_inches="tight")
207+
plt.close(fig)
208+
209+
return fig

0 commit comments

Comments
 (0)