Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions extensions/common_trials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""SpyGlass extension for extracting and storing trial intervals from NWB files.

This module defines a DataJoint table for storing trial intervals (start/stop times) extracted from NWB files.
Each entry corresponds to a single trial from the NWB file's trials table, identified by trial index.
"""

import datajoint as dj
import numpy as np
from spyglass.common.common_nwbfile import Nwbfile
from spyglass.common import Session
from spyglass.utils import SpyglassMixin
from spyglass.utils.nwb_helper_fn import get_nwb_file

schema = dj.schema("common_trial")


@schema
class Trials(SpyglassMixin, dj.Imported):
"""Table for storing trial intervals from NWB files.

Each entry represents a single trial, identified by its index (trial_id) in the NWB file's trials table,
and stores the corresponding start and stop times as a numpy array.

Primary key:
- nwb_file_name (from Session)
- trial_id (trial index)

Attributes:
valid_times: numpy array of shape (1, 2) with start and stop time for the trial.
pipeline: optional string for interval list type (default empty).
"""

definition = """
# Time intervals used for analysis
-> Session
trial_id: int # trial index from NWB file
---
valid_times: longblob # numpy array with start/stop times for this trial
pipeline = "": varchar(64) # type of interval list
"""

def make(self, key):
"""Extract trial intervals from the NWB file and insert them into the table.

For each trial in the NWB file's trials table, insert a row with the trial index and its start/stop times.

Parameters
----------
key : dict
Must contain 'nwb_file_name' specifying the NWB file to process.

Notes
-----
- Only inserts if the NWB file contains a trials table.
- Each trial is stored as a separate row, with valid_times as a (1, 2) numpy array.
"""
nwb_file_name = key["nwb_file_name"]
nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)

if nwbf.trials is None:
return

trials = nwbf.trials.to_dataframe()

inserts = trials.apply(
lambda row: {
"nwb_file_name": nwb_file_name,
"trial_id": row.name,
"valid_times": np.asarray(
[[row.start_time, row.stop_time]]
),
},
axis=1,
).tolist()

self.insert(inserts, allow_direct_insert=True, skip_duplicates=True)
112 changes: 112 additions & 0 deletions insert/insert_trials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Functions for inserting NWB session and trial interval data into a Spyglass database.

This module provides functions to insert session and trial interval data from NWB files into a Spyglass database.
It connects to a DataJoint database using configuration from a local file and provides:

1. Insertion of session data using Spyglass's built-in tools.
2. Extraction and insertion of trial intervals using a custom Trials table.

Example usage is provided in the __main__ section, which demonstrates how to:
- Check if an NWB file exists
- Clear existing database tables if needed
- Insert session and trial data
- Query and display the inserted data
"""

import sys
from pathlib import Path

import datajoint as dj
from numpy.testing import assert_array_equal
from pynwb import NWBHDF5IO

dj_local_conf_path = "/Users/weian/catalystneuro/pagan-lab-to-nwb/src/pagan_lab_to_nwb/spyglass_mock/dj_local_conf.json"
dj.config.load(dj_local_conf_path) # load config for database connection info

dj.conn(use_tls=False)

# spyglass.common has the most frequently used tables
import spyglass.common as sgc # this import connects to the database

# spyglass.data_import has tools for inserting NWB files into the database
import spyglass.data_import as sgi
from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename

# Custom Table Imports
sys.path.append("/Users/weian/catalystneuro/neuroconv-spyglass/extensions")
from extensions.common_trials import Trials


def insert_session(nwbfile_path: Path):
"""Insert session and trial interval data from an NWB file into the Spyglass database.

This function:
- Removes any existing session with the same NWB file name.
- Inserts session metadata using Spyglass's insert_sessions().
- Inserts trial intervals using the custom Trials table.

Parameters
----------
nwbfile_path : Path
Path to the NWB file containing the session and trial information.
"""
if not nwbfile_path.exists():
raise FileNotFoundError(f"NWB file does not exist: {nwbfile_path}.")

nwb_copy_file_name = get_nwb_copy_filename(nwbfile_path.name)
# this removes all tables from the database
sgc_nwbfile = sgc.Nwbfile & {"nwb_file_name": nwb_copy_file_name}
sgc_nwbfile.delete()

sgi.insert_sessions(str(nwbfile_path), rollback_on_fail=True, raise_err=True)
insert_trials(nwbfile_path=nwbfile_path)


def insert_trials(nwbfile_path: Path):
"""Insert trial intervals from the NWB file into the Spyglass database.

This function extracts trial start/stop times from the NWB file and inserts them
into the custom Trials table.

Parameters
----------
nwbfile_path : Path
Path to the NWB file containing the trial intervals.
"""
nwb_copy_file_name = get_nwb_copy_filename(nwbfile_path.name)
Trials().make(key={"nwb_file_name": nwb_copy_file_name})


def test_trials(nwbfile_path: Path):
"""Test that trial intervals in the database match those in the NWB file.

For each trial, fetch the stored valid_times from the Trials table and compare
to the corresponding start/stop times in the NWB file.
"""
with NWBHDF5IO(nwbfile_path, "r") as io:
nwbfile = io.read()
trials = nwbfile.trials.to_dataframe()
trials = trials[["start_time", "stop_time"]].to_numpy()
nwb_copy_file_name = get_nwb_copy_filename(nwbfile_path.name)
for i in range(len(trials)):
trials_from_spyglass = (Trials() & {"nwb_file_name": nwb_copy_file_name, "trial_id": i}).fetch1("valid_times")
assert_array_equal(trials[i], trials_from_spyglass.squeeze())


def main():
nwbfile_path = Path("/Volumes/T9/data/Pagan/raw/mock_trials.nwb")
insert_session(nwbfile_path=nwbfile_path)

nwb_copy_file_name = get_nwb_copy_filename(nwbfile_path.name)
print("=== Session ===")
print(sgc.Session & {"nwb_file_name": nwb_copy_file_name})
print("=== NWB File ===")
print(sgc.Nwbfile & {"nwb_file_name": nwb_copy_file_name})
print("=== Trials ===")
print(Trials() & {"nwb_file_name": nwb_copy_file_name})

test_trials(nwbfile_path=nwbfile_path)

if __name__ == "__main__":
main()
print("Done!")
22 changes: 22 additions & 0 deletions mock/mock_trials_nwbfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pathlib import Path

from pynwb import NWBHDF5IO
from pynwb.testing.mock.file import mock_NWBFile


def main():
nwbfile = mock_NWBFile(identifier="my_identifier", session_description="my_session_description")
nwbfile.add_trial(start_time=0.0, stop_time=1.0, tags=["trial_01"])
nwbfile.add_trial(start_time=1.0, stop_time=2.0, tags=["trial_02"])

# add processing module to make spyglass happy
nwbfile.create_processing_module(name="behavior", description="dummy behavior module")
nwbfile_path = Path("/Volumes/T9/data/Pagan/raw/mock_trials.nwb")
if nwbfile_path.exists():
nwbfile_path.unlink()
with NWBHDF5IO(nwbfile_path, "w") as io:
io.write(nwbfile)
print(f"mock trials NWB file successfully written at {nwbfile_path}")

if __name__ == "__main__":
main()