Skip to content

Commit 9a5ef79

Browse files
Fix DataJoint query errors with NaN values in probe geometry fields (#1346)
* Initial plan * Implement NaN value replacement fix for probe geometry fields Co-authored-by: samuelbray32 <[email protected]> * Complete NaN value replacement fix implementation Co-authored-by: samuelbray32 <[email protected]> * Move _replace_nan_with_default to utils and simplify tests Co-authored-by: samuelbray32 <[email protected]> * lint * import test function from fixtures --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: samuelbray32 <[email protected]> Co-authored-by: samuelbray32 <[email protected]>
1 parent fef4f64 commit 9a5ef79

File tree

4 files changed

+97
-25
lines changed

4 files changed

+97
-25
lines changed

src/spyglass/common/common_device.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from spyglass.common.errors import PopulateException
55
from spyglass.settings import test_mode
66
from spyglass.utils import SpyglassMixin, logger
7-
from spyglass.utils.dj_helper_fn import accept_divergence
7+
from spyglass.utils.dj_helper_fn import (
8+
accept_divergence,
9+
_replace_nan_with_default,
10+
)
811
from spyglass.utils.nwb_helper_fn import get_nwb_file
912

1013
schema = dj.schema("common_device")
@@ -563,15 +566,17 @@ def __read_ndx_probe_data(
563566
for electrode in shank.shanks_electrodes.values():
564567
# the next line will need to be fixed if we have different sized
565568
# contacts on a shank
566-
elect_dict[electrode.name] = {
567-
"probe_id": new_probe_dict["probe_id"],
568-
"probe_shank": shank_dict[shank.name]["probe_shank"],
569-
"contact_size": nwb_probe_obj.contact_size,
570-
"probe_electrode": int(electrode.name),
571-
"rel_x": electrode.rel_x,
572-
"rel_y": electrode.rel_y,
573-
"rel_z": electrode.rel_z,
574-
}
569+
elect_dict[electrode.name] = _replace_nan_with_default(
570+
{
571+
"probe_id": new_probe_dict["probe_id"],
572+
"probe_shank": shank_dict[shank.name]["probe_shank"],
573+
"contact_size": nwb_probe_obj.contact_size,
574+
"probe_electrode": int(electrode.name),
575+
"rel_x": electrode.rel_x,
576+
"rel_y": electrode.rel_y,
577+
"rel_z": electrode.rel_z,
578+
}
579+
)
575580

576581
@classmethod
577582
def _read_config_probe_data(
@@ -597,15 +602,17 @@ def _read_config_probe_data(
597602
"Electrode", []
598603
)
599604
for i, e in enumerate(elect_dict_list):
600-
elect_dict[str(i)] = {
601-
"probe_id": probe_id,
602-
"probe_shank": e["probe_shank"],
603-
"probe_electrode": e["probe_electrode"],
604-
"contact_size": e.get("contact_size"),
605-
"rel_x": e.get("rel_x"),
606-
"rel_y": e.get("rel_y"),
607-
"rel_z": e.get("rel_z"),
608-
}
605+
elect_dict[str(i)] = _replace_nan_with_default(
606+
{
607+
"probe_id": probe_id,
608+
"probe_shank": e["probe_shank"],
609+
"probe_electrode": e["probe_electrode"],
610+
"contact_size": e.get("contact_size"),
611+
"rel_x": e.get("rel_x"),
612+
"rel_y": e.get("rel_y"),
613+
"rel_z": e.get("rel_z"),
614+
}
615+
)
609616

610617
# make the probe type if not in database
611618
new_probe_type_dict.update(
@@ -780,11 +787,15 @@ def create_from_nwbfile(
780787
"probe_electrode": elec_index,
781788
}
782789

783-
for dim in ["rel_x", "rel_y", "rel_z"]:
790+
for dim in ["rel_x", "rel_y", "rel_z", "contact_size"]:
784791
if dim in nwbfile.electrodes[elec_index]:
785-
elect_dict[elec_index][dim] = nwbfile.electrodes[
786-
elec_index, dim
787-
]
792+
value = nwbfile.electrodes[elec_index, dim]
793+
elect_dict[elec_index][dim] = value
794+
795+
# Apply NaN replacement to the entire electrode dictionary
796+
elect_dict[elec_index] = _replace_nan_with_default(
797+
elect_dict[elec_index]
798+
)
788799

789800
if not device_found:
790801
logger.warning(

src/spyglass/utils/dj_helper_fn.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Helper functions for manipulating information from DataJoint fetch calls."""
22

33
import inspect
4+
import math
45
import multiprocessing.pool
56
import os
67
import re
@@ -701,3 +702,28 @@ def accept_divergence(key, new_value, existing_value, test_mode=False):
701702
+ f"'{new_value}' ?\n"
702703
)
703704
return str_to_bool(response)
705+
706+
707+
def _replace_nan_with_default(data_dict, default_value=-1.0):
708+
"""
709+
Replace NaN values in a dictionary with a default value.
710+
711+
This is necessary because DataJoint cannot properly format queries
712+
with NaN values, causing errors during probe insertion/validation.
713+
714+
Args:
715+
data_dict: Dictionary that may contain NaN values
716+
default_value: Value to replace NaN with (default: -1.0)
717+
718+
Returns:
719+
Dictionary with NaN values replaced
720+
"""
721+
if not isinstance(data_dict, dict):
722+
return data_dict
723+
724+
result = data_dict.copy()
725+
for key, value in result.items():
726+
if isinstance(value, float) and math.isnan(value):
727+
result[key] = default_value
728+
729+
return result

tests/common/test_device.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,30 @@ def test_create_probe(common, mini_devices, mini_path, mini_copy_name):
4545
assert array_equal(
4646
before, after
4747
), "Probe create_from_nwbfile had unexpected effect"
48+
49+
50+
def test_replace_nan_with_default(utils):
51+
"""Test that NaN values in probe geometry fields are properly replaced with -1.0."""
52+
# Test with NaN values (similar to the issue case)
53+
test_data = {
54+
"probe_id": "nTrode32_probe description",
55+
"probe_shank": 0,
56+
"contact_size": float("nan"),
57+
"probe_electrode": 194,
58+
"rel_x": float("nan"),
59+
"rel_y": float("nan"),
60+
"rel_z": float("nan"),
61+
}
62+
63+
result = utils.dj_helper_fn._replace_nan_with_default(test_data)
64+
65+
# Check that NaN values were replaced with -1.0
66+
assert result["contact_size"] == -1.0
67+
assert result["rel_x"] == -1.0
68+
assert result["rel_y"] == -1.0
69+
assert result["rel_z"] == -1.0
70+
71+
# Check that non-NaN values were preserved
72+
assert result["probe_id"] == "nTrode32_probe description"
73+
assert result["probe_shank"] == 0
74+
assert result["probe_electrode"] == 194

tests/conftest.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ def mini_insert(
290290
):
291291
from spyglass.common import LabMember, Nwbfile, Session # noqa: E402
292292
from spyglass.data_import import insert_sessions # noqa: E402
293-
from spyglass.spikesorting.spikesorting_merge import (
293+
from spyglass.spikesorting.spikesorting_merge import ( # noqa: E402
294294
SpikeSortingOutput,
295-
) # noqa: E402
295+
)
296296
from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402
297297

298298
_ = SpikeSortingOutput()
@@ -400,6 +400,14 @@ def populate_exception():
400400
yield PopulateException
401401

402402

403+
@pytest.fixture(scope="session")
404+
def utils():
405+
"""Spyglass utils module."""
406+
from spyglass import utils
407+
408+
yield utils
409+
410+
403411
@pytest.fixture(scope="session")
404412
def frequent_imports():
405413
"""Often needed for graph cascade."""

0 commit comments

Comments
 (0)