Skip to content

Commit cc1955c

Browse files
committed
fixes for np fromfile
1 parent fbf92a8 commit cc1955c

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

src/trodes_to_nwb/convert_position.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,15 @@ def read_trodes_datafile(filename: Path) -> dict:
172172
return fields_text
173173

174174

175+
def convert_datafile_to_pandas(datafile) -> pd.DataFrame:
176+
"""Takes the output of read_trodes_datafile and converts it to a pandas dataframe.
177+
Added for changes identified in numpy 2.2.2
178+
"""
179+
return pd.DataFrame(
180+
{key: np.squeeze(datafile["data"][key]) for key in datafile["data"].dtype.names}
181+
)
182+
183+
175184
def get_framerate(timestamps: np.ndarray) -> float:
176185
"""
177186
Calculates the framerate of a video based on the timestamps of each frame.
@@ -472,15 +481,12 @@ def get_video_timestamps(video_timestamps_filepath: Path) -> np.ndarray:
472481
An array of video timestamps.
473482
"""
474483
# Get video timestamps
484+
video_timestamps = read_trodes_datafile(video_timestamps_filepath)["data"]
475485
video_timestamps = (
476-
pd.DataFrame(read_trodes_datafile(video_timestamps_filepath)["data"])
477-
.set_index("PosTimestamp")
478-
.rename(columns={"frameCount": "HWframeCount"})
479-
)
480-
return (
481-
np.asarray(video_timestamps.HWTimestamp, dtype=np.float64)
486+
np.squeeze(video_timestamps["HWTimestamp"]).astype(np.float64)
482487
/ NANOSECONDS_PER_SECOND
483488
)
489+
return video_timestamps
484490

485491

486492
def get_position_timestamps(
@@ -519,8 +525,9 @@ def get_position_timestamps(
519525
logger = logging.getLogger("convert")
520526

521527
# Get video timestamps
528+
datafile = read_trodes_datafile(position_timestamps_filepath)
522529
video_timestamps = (
523-
pd.DataFrame(read_trodes_datafile(position_timestamps_filepath)["data"])
530+
convert_datafile_to_pandas(datafile)
524531
.set_index("PosTimestamp")
525532
.rename(columns={"frameCount": "HWframeCount"})
526533
)
@@ -549,7 +556,7 @@ def get_position_timestamps(
549556
# Get position tracking information
550557
try:
551558
position_tracking = pd.DataFrame(
552-
read_trodes_datafile(position_tracking_filepath)["data"]
559+
convert_datafile_to_pandas(read_trodes_datafile(position_tracking_filepath))
553560
).set_index("time")
554561
is_repeat_timestamp = detect_repeat_timestamps(position_tracking.index)
555562
position_tracking = position_tracking.iloc[~is_repeat_timestamp]
@@ -1005,7 +1012,7 @@ def add_associated_video_files(
10051012
]
10061013
].full_path.to_list()[0]
10071014
# get the timestamps
1008-
video_timestamps = get_video_timestamps(video_timestamps_filepath)
1015+
video_timestamps = np.squeeze(get_video_timestamps(video_timestamps_filepath))
10091016

10101017
if convert_video:
10111018
video_file_name = convert_h264_to_mp4(video_path, video_directory)

src/trodes_to_nwb/tests/test_convert_position.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from trodes_to_nwb.convert_intervals import add_epochs, add_sample_count
1313
from trodes_to_nwb.convert_position import (
1414
add_position,
15+
convert_datafile_to_pandas,
1516
correct_timestamps_for_camera_to_mcu_lag,
1617
detect_repeat_timestamps,
1718
detect_trodes_time_repeats_or_frame_jumps,
@@ -79,7 +80,7 @@ def test_read_trodes_datafile_correct_settings(tmp_path):
7980
result = read_trodes_datafile(filename)
8081
assert result["clock rate"] == "30000"
8182

82-
expected_data = pd.DataFrame(result["data"])
83+
expected_data = convert_datafile_to_pandas(result)
8384
assert expected_data["field1"].dtype == np.uint32
8485
assert expected_data["field2"].dtype == np.int32
8586
assert np.array_equal(expected_data.field1, np.array([1, 3], dtype=np.uint32))

0 commit comments

Comments
 (0)