Skip to content

Commit a0a3b67

Browse files
authored
Merge pull request #104 from LorenFrankLab/np_upgrade
Fixes for numpy 2.0
2 parents fbf92a8 + f34e7ba commit a0a3b67

File tree

4 files changed

+32
-22
lines changed

4 files changed

+32
-22
lines changed

.github/workflows/test_package_build.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ jobs:
2929
run: python -m build
3030
- run: twine check dist/*
3131
- name: Upload sdist and wheel artifacts
32-
uses: actions/upload-artifact@v3
32+
uses: actions/upload-artifact@v4
3333
with:
3434
name: dist
3535
path: dist/
3636
- name: Build git archive
3737
run: mkdir archive && git archive -v -o archive/archive.tgz HEAD
3838
- name: Upload git archive artifact
39-
uses: actions/upload-artifact@v3
39+
uses: actions/upload-artifact@v4
4040
with:
4141
name: archive
4242
path: archive/
@@ -51,13 +51,13 @@ jobs:
5151
steps:
5252
- name: Download sdist and wheel artifacts
5353
if: matrix.package != 'archive'
54-
uses: actions/download-artifact@v3
54+
uses: actions/download-artifact@v4
5555
with:
5656
name: dist
5757
path: dist/
5858
- name: Download git archive artifact
5959
if: matrix.package == 'archive'
60-
uses: actions/download-artifact@v3
60+
uses: actions/download-artifact@v4
6161
with:
6262
name: archive
6363
path: archive/
@@ -117,7 +117,7 @@ jobs:
117117
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
118118
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
119119
steps:
120-
- uses: actions/download-artifact@v3
120+
- uses: actions/download-artifact@v4
121121
with:
122122
name: dist
123123
path: dist/

src/trodes_to_nwb/convert_position.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def parse_dtype(fieldstr: str) -> np.dtype:
9494
--------
9595
>>> fieldstr = '<time uint32><x float32><y float32><z float32>'
9696
>>> parse_dtype(fieldstr)
97-
dtype([('time', '<u4'), ('x', '<f4'), ('y', '<f4'), ('z', '<f4')])
97+
dtype([('time', '<u4', (1,)), ('x', '<f4', (1,)), ('y', '<f4', (1,)), ('z', '<f4', (1,))])
9898
9999
"""
100100
# Returns np.dtype from field string
@@ -172,6 +172,15 @@ def read_trodes_datafile(filename: Path) -> dict:
172172
return fields_text
173173

174174

175+
def convert_datafile_to_pandas(datafile) -> pd.DataFrame:
176+
"""Takes the output of read_trodes_datafile and converts it to a pandas dataframe.
177+
Added for changes identified in numpy 2.2.2
178+
"""
179+
return pd.DataFrame(
180+
{key: np.squeeze(datafile["data"][key]) for key in datafile["data"].dtype.names}
181+
)
182+
183+
175184
def get_framerate(timestamps: np.ndarray) -> float:
176185
"""
177186
Calculates the framerate of a video based on the timestamps of each frame.
@@ -472,15 +481,12 @@ def get_video_timestamps(video_timestamps_filepath: Path) -> np.ndarray:
472481
An array of video timestamps.
473482
"""
474483
# Get video timestamps
484+
video_timestamps = read_trodes_datafile(video_timestamps_filepath)["data"]
475485
video_timestamps = (
476-
pd.DataFrame(read_trodes_datafile(video_timestamps_filepath)["data"])
477-
.set_index("PosTimestamp")
478-
.rename(columns={"frameCount": "HWframeCount"})
479-
)
480-
return (
481-
np.asarray(video_timestamps.HWTimestamp, dtype=np.float64)
486+
np.squeeze(video_timestamps["HWTimestamp"]).astype(np.float64)
482487
/ NANOSECONDS_PER_SECOND
483488
)
489+
return video_timestamps
484490

485491

486492
def get_position_timestamps(
@@ -519,8 +525,9 @@ def get_position_timestamps(
519525
logger = logging.getLogger("convert")
520526

521527
# Get video timestamps
528+
datafile = read_trodes_datafile(position_timestamps_filepath)
522529
video_timestamps = (
523-
pd.DataFrame(read_trodes_datafile(position_timestamps_filepath)["data"])
530+
convert_datafile_to_pandas(datafile)
524531
.set_index("PosTimestamp")
525532
.rename(columns={"frameCount": "HWframeCount"})
526533
)
@@ -549,7 +556,7 @@ def get_position_timestamps(
549556
# Get position tracking information
550557
try:
551558
position_tracking = pd.DataFrame(
552-
read_trodes_datafile(position_tracking_filepath)["data"]
559+
convert_datafile_to_pandas(read_trodes_datafile(position_tracking_filepath))
553560
).set_index("time")
554561
is_repeat_timestamp = detect_repeat_timestamps(position_tracking.index)
555562
position_tracking = position_tracking.iloc[~is_repeat_timestamp]
@@ -1005,7 +1012,7 @@ def add_associated_video_files(
10051012
]
10061013
].full_path.to_list()[0]
10071014
# get the timestamps
1008-
video_timestamps = get_video_timestamps(video_timestamps_filepath)
1015+
video_timestamps = np.squeeze(get_video_timestamps(video_timestamps_filepath))
10091016

10101017
if convert_video:
10111018
video_file_name = convert_h264_to_mp4(video_path, video_directory)

src/trodes_to_nwb/spike_gadgets_raw_io.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -560,8 +560,9 @@ def get_analogsignal_multiplexed(self, channel_names=None) -> np.ndarray:
560560
)
561561
# read the data into int16
562562
data = (
563-
self._raw_memmap[:, data_offsets[:, 0]]
564-
+ self._raw_memmap[:, data_offsets[:, 0] + 1] * INT_16_CONVERSION
563+
self._raw_memmap[:, data_offsets[:, 0]].astype(np.int16)
564+
+ self._raw_memmap[:, data_offsets[:, 0] + 1].astype(np.int16)
565+
* INT_16_CONVERSION
565566
)
566567
# initialize the first row
567568
analog_multiplexed_data[0] = data[0]
@@ -646,8 +647,8 @@ def get_analogsignal_multiplexed_partial(
646647
)
647648
# read the data into int16
648649
data = (
649-
self._raw_memmap[i_start:i_stop, data_offsets[:, 0]]
650-
+ self._raw_memmap[i_start:i_stop, data_offsets[:, 0] + 1]
650+
self._raw_memmap[i_start:i_stop, data_offsets[:, 0]].astype(np.int16)
651+
+ self._raw_memmap[i_start:i_stop, data_offsets[:, 0] + 1].astype(np.int16)
651652
* INT_16_CONVERSION
652653
)
653654
# initialize the first row
@@ -969,8 +970,9 @@ def get_analogsignal_multiplexed(self, channel_names=None) -> np.ndarray:
969970
)
970971
# read the data into int16
971972
data = (
972-
self._raw_memmap[:, data_offsets[:, 0]]
973-
+ self._raw_memmap[:, data_offsets[:, 0] + 1] * INT_16_CONVERSION
973+
self._raw_memmap[:, data_offsets[:, 0]].astype(np.int16)
974+
+ self._raw_memmap[:, data_offsets[:, 0] + 1].astype(np.int16)
975+
* INT_16_CONVERSION
974976
)
975977
# initialize the first row
976978
# if no previous state, assume first segment. Default to superclass behavior

src/trodes_to_nwb/tests/test_convert_position.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from trodes_to_nwb.convert_intervals import add_epochs, add_sample_count
1313
from trodes_to_nwb.convert_position import (
1414
add_position,
15+
convert_datafile_to_pandas,
1516
correct_timestamps_for_camera_to_mcu_lag,
1617
detect_repeat_timestamps,
1718
detect_trodes_time_repeats_or_frame_jumps,
@@ -79,7 +80,7 @@ def test_read_trodes_datafile_correct_settings(tmp_path):
7980
result = read_trodes_datafile(filename)
8081
assert result["clock rate"] == "30000"
8182

82-
expected_data = pd.DataFrame(result["data"])
83+
expected_data = convert_datafile_to_pandas(result)
8384
assert expected_data["field1"].dtype == np.uint32
8485
assert expected_data["field2"].dtype == np.int32
8586
assert np.array_equal(expected_data.field1, np.array([1, 3], dtype=np.uint32))

0 commit comments

Comments
 (0)