Skip to content

Commit 7edcb0e

Browse files
Copilotyangsunhwang
andcommitted
Fix code review issues: block line logic, text positioning, and figure closing
Co-authored-by: yangsunhwang <[email protected]>
1 parent 0217770 commit 7edcb0e

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/jdb_to_nwb/plotting/plot_behavior.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,24 @@ def rolling_mean(arr, window):
170170
xmid = int(np.mean([first_trial_idx, last_trial_idx]))
171171

172172
# Draw vertical line at block boundary (after last trial)
173-
if block_num < len(block_data): # Don't draw line after last block
173+
if i < len(block_data) - 1: # Don't draw line after last block
174174
xstart = last_trial_idx + 1
175175
if i == 0:
176176
ax4.axvline(x=xstart, color='r', linestyle='--', label='Block Change')
177177
else:
178178
ax4.axvline(x=xstart, color='r', linestyle='--')
179179

180-
# Add probability text labels
181-
plt.text(xmid - 12, 8, str(int(block['pA'])) + ': ',
182-
fontsize='xx-large', fontweight='bold', color='b', transform=ax4.transData)
183-
plt.text(xmid, 8, str(int(block['pB'])) + ': ',
184-
fontsize='xx-large', fontweight='bold', color='orange', transform=ax4.transData)
185-
plt.text(xmid + 12, 8, str(int(block['pC'])),
186-
fontsize='xx-large', fontweight='bold', color='g', transform=ax4.transData)
180+
# Add probability text labels using axes coordinates for y position
181+
# Place text at 90% of the y-axis height
182+
plt.text(xmid - 12, 0.9, str(int(block['pA'])) + ': ',
183+
fontsize='xx-large', fontweight='bold', color='b',
184+
transform=ax4.get_xaxis_transform())
185+
plt.text(xmid, 0.9, str(int(block['pB'])) + ': ',
186+
fontsize='xx-large', fontweight='bold', color='orange',
187+
transform=ax4.get_xaxis_transform())
188+
plt.text(xmid + 12, 0.9, str(int(block['pC'])),
189+
fontsize='xx-large', fontweight='bold', color='g',
190+
transform=ax4.get_xaxis_transform())
187191

188192
ax4.legend()
189193

@@ -204,6 +208,5 @@ def rolling_mean(arr, window):
204208

205209
if fig_dir:
206210
plt.savefig(os.path.join(fig_dir, "probability_matching.png"), dpi=300, bbox_inches="tight")
207-
plt.close(fig)
208211

209212
return fig

0 commit comments

Comments
 (0)