Skip to content

Commit 7572ff5

Browse files
committed
update files
1 parent bd54f40 commit 7572ff5

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

test/nodes/test_csv_reader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def test_load_wrong_state(self):
105105

106106
state = node.state_dict()
107107
state[CSVReader.HEADER_KEY] = None
108-
with self.assertRaisesRegex(ValueError, "Check if has_header=True matches the state header=None"):
108+
with self.assertRaisesRegex(
109+
ValueError, "Check if has_header=True matches the state header=None"
110+
):
109111
node.reset(state)
110112

111113
node.close()
@@ -133,7 +135,9 @@ def test_empty_file(self):
133135
node.close()
134136

135137
def test_header_validation(self):
136-
with self.assertRaisesRegex(ValueError, "return_dict=True requires has_header=True"):
138+
with self.assertRaisesRegex(
139+
ValueError, "return_dict=True requires has_header=True"
140+
):
137141
CSVReader("dummy.csv", has_header=False, return_dict=True)
138142

139143
def test_multi_epoch(self):

torchdata/nodes/csv_reader.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
import csv
28
from itertools import islice
39
from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union
@@ -54,16 +60,22 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
5460
def _handle_initial_state(self, state: Dict[str, Any]):
5561
"""Restore reader state from checkpoint."""
5662
# Validate header compatibility
57-
if (not self.has_header and self.HEADER_KEY in state) or (self.has_header and state[self.HEADER_KEY] is None):
58-
raise ValueError(f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}")
63+
if (not self.has_header and self.HEADER_KEY in state) or (
64+
self.has_header and state[self.HEADER_KEY] is None
65+
):
66+
raise ValueError(
67+
f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}"
68+
)
5969

6070
self._header = state.get(self.HEADER_KEY)
6171
target_line_num = state[self.NUM_LINES_YIELDED]
6272
assert self._file is not None
6373
# Create appropriate reader
6474
if self.return_dict:
6575

66-
self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
76+
self._reader = csv.DictReader(
77+
self._file, delimiter=self.delimiter, fieldnames=self._header
78+
)
6779
else:
6880
self._reader = csv.reader(self._file, delimiter=self.delimiter)
6981
# Skip header if needed (applies only when file has header)
@@ -112,5 +124,5 @@ def get_state(self) -> Dict[str, Any]:
112124
}
113125

114126
def close(self):
115-
if self._file and not self._file.closed:
127+
if self._file is not None and not self._file.closed:
116128
self._file.close()

0 commit comments

Comments
 (0)