Skip to content

Commit 526446c

Browse files
committed
fix
1 parent d37cf2e commit 526446c

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

datasets/flwr_datasets/federated_dataset_test.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -629,14 +629,18 @@ def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
629629
for key in row1:
630630
if key == "audio":
631631
# Special handling for 'audio' key
632-
if not all(
633-
[
634-
np.array_equal(row1[key]["array"], row2[key]["array"]),
635-
row1[key]["path"] == row2[key]["path"],
636-
row1[key]["sampling_rate"] == row2[key]["sampling_rate"],
637-
]
638-
):
632+
# Check array and sampling_rate
633+
if not np.array_equal(row1[key]["array"], row2[key]["array"]):
639634
return False
635+
if row1[key]["sampling_rate"] != row2[key]["sampling_rate"]:
636+
return False
637+
638+
# Check path if available (AudioDecoder raises TypeError)
639+
try:
640+
if row1[key]["path"] != row2[key]["path"]:
641+
return False
642+
except TypeError:
643+
pass
640644
elif row1[key] != row2[key]:
641645
# Direct comparison for other keys
642646
return False

0 commit comments

Comments
 (0)