|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
1 | 7 | import csv
|
2 | 8 | from itertools import islice
|
3 | 9 | from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union
|
@@ -54,16 +60,22 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
|
54 | 60 | def _handle_initial_state(self, state: Dict[str, Any]):
|
55 | 61 | """Restore reader state from checkpoint."""
|
56 | 62 | # Validate header compatibility
|
57 |
| - if (not self.has_header and self.HEADER_KEY in state) or (self.has_header and state[self.HEADER_KEY] is None): |
58 |
| - raise ValueError(f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}") |
| 63 | + if (not self.has_header and self.HEADER_KEY in state) or ( |
| 64 | + self.has_header and state[self.HEADER_KEY] is None |
| 65 | + ): |
| 66 | + raise ValueError( |
| 67 | + f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}" |
| 68 | + ) |
59 | 69 |
|
60 | 70 | self._header = state.get(self.HEADER_KEY)
|
61 | 71 | target_line_num = state[self.NUM_LINES_YIELDED]
|
62 | 72 | assert self._file is not None
|
63 | 73 | # Create appropriate reader
|
64 | 74 | if self.return_dict:
|
65 | 75 |
|
66 |
| - self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header) |
| 76 | + self._reader = csv.DictReader( |
| 77 | + self._file, delimiter=self.delimiter, fieldnames=self._header |
| 78 | + ) |
67 | 79 | else:
|
68 | 80 | self._reader = csv.reader(self._file, delimiter=self.delimiter)
|
69 | 81 | # Skip header if needed (applies only when file has header)
|
@@ -112,5 +124,5 @@ def get_state(self) -> Dict[str, Any]:
|
112 | 124 | }
|
113 | 125 |
|
114 | 126 | def close(self):
|
115 |
| - if self._file and not self._file.closed: |
| 127 | + if self._file is not None and not self._file.closed: |
116 | 128 | self._file.close()
|
0 commit comments