@@ -170,20 +170,24 @@ def rolling_mean(arr, window):
170170 xmid = int (np .mean ([first_trial_idx , last_trial_idx ]))
171171
172172 # Draw vertical line at block boundary (after last trial)
173- if block_num < len (block_data ): # Don't draw line after last block
173+ if i < len (block_data ) - 1 : # Don't draw line after last block
174174 xstart = last_trial_idx + 1
175175 if i == 0 :
176176 ax4 .axvline (x = xstart , color = 'r' , linestyle = '--' , label = 'Block Change' )
177177 else :
178178 ax4 .axvline (x = xstart , color = 'r' , linestyle = '--' )
179179
180- # Add probability text labels
181- plt .text (xmid - 12 , 8 , str (int (block ['pA' ])) + ': ' ,
182- fontsize = 'xx-large' , fontweight = 'bold' , color = 'b' , transform = ax4 .transData )
183- plt .text (xmid , 8 , str (int (block ['pB' ])) + ': ' ,
184- fontsize = 'xx-large' , fontweight = 'bold' , color = 'orange' , transform = ax4 .transData )
185- plt .text (xmid + 12 , 8 , str (int (block ['pC' ])),
186- fontsize = 'xx-large' , fontweight = 'bold' , color = 'g' , transform = ax4 .transData )
180+ # Add probability text labels using axes coordinates for y position
181+ # Place text at 90% of the y-axis height
182+ plt .text (xmid - 12 , 0.9 , str (int (block ['pA' ])) + ': ' ,
183+ fontsize = 'xx-large' , fontweight = 'bold' , color = 'b' ,
184+ transform = ax4 .get_xaxis_transform ())
185+ plt .text (xmid , 0.9 , str (int (block ['pB' ])) + ': ' ,
186+ fontsize = 'xx-large' , fontweight = 'bold' , color = 'orange' ,
187+ transform = ax4 .get_xaxis_transform ())
188+ plt .text (xmid + 12 , 0.9 , str (int (block ['pC' ])),
189+ fontsize = 'xx-large' , fontweight = 'bold' , color = 'g' ,
190+ transform = ax4 .get_xaxis_transform ())
187191
188192 ax4 .legend ()
189193
@@ -204,6 +208,5 @@ def rolling_mean(arr, window):
204208
205209 if fig_dir :
206210 plt .savefig (os .path .join (fig_dir , "probability_matching.png" ), dpi = 300 , bbox_inches = "tight" )
207- plt .close (fig )
208211
209212 return fig
0 commit comments