Skip to content

Conversation

Copy link

Copilot AI commented Nov 7, 2025

The existing plot_probability_matching function expected pandas DataFrame structure with columns like port, rwd, block, etc. After refactoring, parse_arduino_text now returns trial data as a list of dictionaries.

Changes

  • New plotting function (src/jdb_to_nwb/plotting/plot_behavior.py)

    • Accepts trial_data (list of dicts) and block_data (list of dicts) from parse_arduino_text
    • Implements rolling window calculations using NumPy instead of pandas
    • Generates 4-panel figure: reward bars for ports A/B/C + rolling choice frequencies with block transitions
  • Test coverage (tests/test_convert_behavior.py)

    • Validates plotting with real test data (188 trials, 3 blocks)
    • Verifies file output and figure creation
  • Example usage (examples/plot_probability_matching_example.py)

    • Demonstrates integration with parse_arduino_text workflow

Usage

from jdb_to_nwb.convert_behavior import parse_arduino_text
from jdb_to_nwb.plotting.plot_behavior import plot_probability_matching

trial_data, block_data = parse_arduino_text(arduino_text, arduino_timestamps, logger)
fig = plot_probability_matching(trial_data, block_data, fig_dir="./output")

Trial dict structure:

{
    'end_port': 'B',           # Port visited (A/B/C)
    'reward': 1,               # 1 if rewarded, 0 otherwise
    'block': 1,                # Block number
    'trial_within_session': 3  # Trial index
}
Original prompt

def plot_probability_matching(trial_data):
x1 = np.linspace(0,len(trial_data),len(trial_data))
A1 = trial_data.rwd.loc[trial_data.port==0] + 2
yA1 = np.zeros(len(trial_data))
yA1[A1.index.values] = A1
B1 = trial_data.rwd.loc[trial_data.port==1] + 2
yB1 = np.zeros(len(trial_data))
yB1[B1.index.values] = B1
C1 = trial_data.rwd.loc[trial_data.port==2] + 2
yC1 = np.zeros(len(trial_data))
yC1[C1.index.values] = C1

#frequency of each port visit over time
window = 10
#wintype = 'boxcar'

trial_data['ChooseA'] = np.where(trial_data.port==0,1,0)
trial_data['ChooseB'] = np.where(trial_data.port==1,1,0)
trial_data['ChooseC'] = np.where(trial_data.port==2,1,0)

bdt = plt.figure(figsize = (18,12))
plt.suptitle('Within-Session Probability Matching',fontweight='bold',fontsize = 26)
ax4 = plt.subplot2grid((18,1),(3,0),colspan = 1, rowspan =15)
ax4.plot(x1,trial_data.ChooseA.rolling(window,min_periods=1).sum().divide(window),label = 'A',alpha=0.8)
ax4.plot(x1,trial_data.ChooseB.rolling(window,min_periods=1).sum().divide(window),label = 'B',alpha=0.8)
ax4.plot(x1,trial_data.ChooseC.rolling(window,min_periods=1).sum().divide(window),label = 'C',alpha=0.8)
ax4.set_ylabel('Port Visits/trial',fontsize=20,fontweight='bold')
ax4.set_ylim(0,.7)
#ax4.set_xlabel('Time (min)',fontsize=20,fontweight='bold')
ax4.legend(bbox_to_anchor=(0.9, 1.4), loc=2, borderaxespad=0.)
ax1 = plt.subplot2grid((18,1),(0,0),colspan = 1, rowspan =1,sharex=ax4)
ax1.bar(x1,yA1,color = 'blue')
ax1.axis('off')
ax2 = plt.subplot2grid((18,1),(1,0),colspan = 1, rowspan =1,sharex=ax4)
for b in trial_data.block.unique():
    xstart = int((max(trial_data.loc[trial_data.block==b].index)+1))
    ind = max(trial_data.loc[trial_data.block==b].index)
    xmin = int(min(trial_data.loc[trial_data.block==b].index))
    xmid = int(np.mean(trial_data.loc[trial_data.block==b].index))#int(xmin+(xstart-xmin)/2)
    print(xmid)
    if b==1:
        ax4.axvline(x=xstart,color ='r',linestyle='--', label = 'Block Change')
    else:
        ax4.axvline(x=xstart,color ='r',linestyle='--')
    plt.text(xmid-12,8,str(int(trial_data.pA[ind-1]))+': ',fontsize='xx-large',fontweight='bold',color = 'b')
    plt.text(xmid,8,str(int(trial_data.pB[ind-1]))+': ',fontsize='xx-large',fontweight='bold',color = 'orange')
    plt.text(xmid+12,8,str(int(trial_data.pC[ind-1])),fontsize='xx-large',fontweight='bold',color = 'g')
ax4.legend()
ax2.bar(x1,yB1,color='orange')
ax2.axis('off')
ax3 = plt.subplot2grid((18,1),(2,0),colspan = 1, rowspan =1,sharex=ax4)
ax3.bar(x1,yC1,color='g')
ax3.axis('off')
return bdt

I AM Trying to use 'parse_arduino_text' and its 'trial_data', and make probability plot that I used previously with different dataframe structure.
trial_data[i] is looking like this
{'start_time': 41203791.308800004,
'beam_break_start': 41221720.832,
'start_port': 'C',
'end_port': 'B',
'trial_within_block': 3,
'trial_within_session': 3,
'block': 1,
'reward': 1,
'beam_break_end': 41221720.832,
'end_time': 41221720.832}


💬 We'd love your input! Share your thoughts on Copilot coding agent in our 2 minute survey.

Copilot AI changed the title [WIP] Add plot for within-session probability matching Add plot_probability_matching function for list-based trial data Nov 7, 2025
Copilot AI requested a review from yangsunhwang November 7, 2025 21:30
Copilot finished work on behalf of yangsunhwang November 7, 2025 21:30
@yangsunhwang
Copy link
Collaborator

@copilot ---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[22], line 1
----> 1 fig = plot_probability_matching(trial_data, block_data, fig_dir="./output")

Cell In[21], line 82, in plot_probability_matching(trial_data, block_data, fig_dir)
79 freq_C = rolling_mean(choose_C, window)
81 # Create figure
---> 82 fig = plt.figure(figsize=(18, 12))
83 plt.suptitle('Within-Session Probability Matching', fontweight='bold', fontsize=26)
85 # Main plot for port visit frequencies

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\pyplot.py:1042, in figure(num, figsize, dpi, facecolor, edgecolor, frameon, FigureClass, clear, **kwargs)
1032 if len(allnums) == max_open_warning >= 1:
1033 _api.warn_external(
1034 f"More than {max_open_warning} figures have been opened. "
1035 f"Figures created through the pyplot interface "
(...) 1039 f"Consider using matplotlib.pyplot.close().",
1040 RuntimeWarning)
-> 1042 manager = new_figure_manager(
1043 num, figsize=figsize, dpi=dpi,
1044 facecolor=facecolor, edgecolor=edgecolor, frameon=frameon,
1045 FigureClass=FigureClass, **kwargs)
1046 fig = manager.canvas.figure
1047 if fig_label:

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\pyplot.py:551, in new_figure_manager(*args, **kwargs)
549 def new_figure_manager(*args, **kwargs):
550 """Create a new figure manager instance."""
--> 551 _warn_if_gui_out_of_main_thread()
552 return _get_backend_mod().new_figure_manager(*args, **kwargs)

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\pyplot.py:528, in _warn_if_gui_out_of_main_thread()
526 def _warn_if_gui_out_of_main_thread() -> None:
527 warn = False
--> 528 canvas_class = cast(type[FigureCanvasBase], _get_backend_mod().FigureCanvas)
529 if canvas_class.required_interactive_framework:
530 if hasattr(threading, 'get_native_id'):
531 # This compares native thread ids because even if Python-level
532 # Thread objects match, the underlying OS thread (which is what
533 # really matters) may be different on Python implementations with
534 # green threads.

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\pyplot.py:369, in _get_backend_mod()
360 """
361 Ensure that a backend is selected and return it.
362
363 This is currently private, but may be made public in the future.
364 """
365 if _backend_mod is None:
366 # Use rcParams._get("backend") to avoid going through the fallback
367 # logic (which will (re)import pyplot and then call switch_backend if
368 # we need to resolve the auto sentinel)
--> 369 switch_backend(rcParams._get("backend"))
370 return cast(type[matplotlib.backend_bases._Backend], _backend_mod)

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\pyplot.py:425, in switch_backend(newbackend)
422 return
423 old_backend = rcParams._get('backend') # get without triggering backend resolution
--> 425 module = backend_registry.load_backend_module(newbackend)
426 canvas_class = module.FigureCanvas
428 required_framework = canvas_class.required_interactive_framework

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\backends\registry.py:317, in BackendRegistry.load_backend_module(self, backend)
303 """
304 Load and return the module containing the specified backend.
305
(...) 314 Module containing backend.
315 """
316 module_name = self._backend_module_name(backend)
--> 317 return importlib.import_module(module_name)

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\importlib_init_.py:90, in import_module(name, package)
88 break
89 level += 1
---> 90 return _bootstrap._gcd_import(name[level:], package, level)

File :1387, in _gcd_import(name, package, level)

File :1360, in find_and_load(name, import)

File :1331, in find_and_load_unlocked(name, import)

File :935, in _load_unlocked(spec)

File :995, in exec_module(self, module)

File :488, in _call_with_frames_removed(f, *args, **kwds)

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\backends\backend_qtagg.py:9
5 import ctypes
7 from matplotlib.transforms import Bbox
----> 9 from .qt_compat import QT_API, QtCore, QtGui
10 from .backend_agg import FigureCanvasAgg
11 from .backend_qt import _BackendQT, FigureCanvasQT

File c:\Users\yangsunhwang\anaconda3\envs\jdb_to_nwb\Lib\site-packages\matplotlib\backends\qt_compat.py:130
128 break
129 else:
--> 130 raise ImportError(
131 "Failed to import any of the following Qt binding modules: {}"
132 .format(", ".join([QT_API for _, QT_API in _candidates]))
133 )
134 else: # We should not get there.
135 raise AssertionError(f"Unexpected QT_API: {QT_API}")

ImportError: Failed to import any of the following Qt binding modules: PyQt6, PySide6, PyQt5, PySide2

not sure we need this qt binding modules?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants