Skip to content

Commit 4868480

Browse files
Implement Iterator protocol and resource management for CSV readers
- DefaultCSVReader and AthenaCSVReader now inherit from collections.abc.Iterator - Added close() method to properly release file resources - Added context manager support (__enter__/__exit__) - Updated ResultSet to use reader's close() method - Removed _csv_file instance variable from ResultSet (reader manages file lifecycle) - No _closed flag needed - use _file is None check instead 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent a2d9c20 commit 4868480

File tree

2 files changed

+47
-15
lines changed

2 files changed

+47
-15
lines changed

pyathena/s3fs/reader.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from __future__ import annotations
33

44
import csv
5-
from typing import Any, Iterator, List, Optional, Tuple
5+
from collections.abc import Iterator
6+
from typing import Any, List, Optional, Tuple
67

78

8-
class DefaultCSVReader:
9+
class DefaultCSVReader(Iterator[List[str]]):
910
"""CSV reader using Python's standard csv module.
1011
1112
This reader wraps Python's standard csv.reader and treats empty fields
@@ -33,9 +34,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
3334
file_obj: File-like object to read from.
3435
delimiter: Field delimiter character.
3536
"""
37+
self._file: Optional[Any] = file_obj
3638
self._reader = csv.reader(file_obj, delimiter=delimiter)
3739

38-
def __iter__(self) -> Iterator[List[str]]:
40+
def __iter__(self) -> "DefaultCSVReader":
3941
"""Iterate over rows in the CSV file."""
4042
return self
4143

@@ -46,17 +48,33 @@ def __next__(self) -> List[str]:
4648
List of field values as strings.
4749
4850
Raises:
49-
StopIteration: When end of file is reached.
51+
StopIteration: When end of file is reached or reader is closed.
5052
"""
53+
if self._file is None:
54+
raise StopIteration
5155
row = next(self._reader)
5256
# Python's csv.reader returns [] for empty lines; normalize to ['']
5357
# to represent a single empty field (consistent with single-value handling)
5458
if not row:
5559
return [""]
5660
return row
5761

62+
def close(self) -> None:
63+
"""Close the underlying file object."""
64+
if self._file is not None:
65+
self._file.close()
66+
self._file = None
67+
68+
def __enter__(self) -> "DefaultCSVReader":
69+
"""Enter context manager."""
70+
return self
71+
72+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
73+
"""Exit context manager and close resources."""
74+
self.close()
75+
5876

59-
class AthenaCSVReader:
77+
class AthenaCSVReader(Iterator[List[Optional[str]]]):
6078
"""CSV reader that distinguishes between NULL and empty string.
6179
6280
This is the default reader for S3FSCursor.
@@ -87,10 +105,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
87105
file_obj: File-like object to read from.
88106
delimiter: Field delimiter character.
89107
"""
90-
self._file = file_obj
108+
self._file: Optional[Any] = file_obj
91109
self._delimiter = delimiter
92110

93-
def __iter__(self) -> Iterator[List[Optional[str]]]:
111+
def __iter__(self) -> "AthenaCSVReader":
94112
"""Iterate over rows in the CSV file."""
95113
return self
96114

@@ -101,8 +119,10 @@ def __next__(self) -> List[Optional[str]]:
101119
List of field values, with None for NULL and '' for empty string.
102120
103121
Raises:
104-
StopIteration: When end of file is reached.
122+
StopIteration: When end of file is reached or reader is closed.
105123
"""
124+
if self._file is None:
125+
raise StopIteration
106126
line = self._file.readline()
107127
if not line:
108128
raise StopIteration
@@ -234,3 +254,17 @@ def _parse_unquoted_field(self, line: str, pos: int) -> Tuple[str, int]:
234254
pos += 1
235255

236256
return value, pos
257+
258+
def close(self) -> None:
259+
"""Close the underlying file object."""
260+
if self._file is not None:
261+
self._file.close()
262+
self._file = None
263+
264+
def __enter__(self) -> "AthenaCSVReader":
265+
"""Enter context manager."""
266+
return self
267+
268+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
269+
"""Exit context manager and close resources."""
270+
self.close()

pyathena/s3fs/result_set.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def __init__(
8080
self._csv_reader_class: CSVReaderType = csv_reader or AthenaCSVReader
8181
self._fs = self._create_s3_file_system()
8282
self._csv_reader: Optional[Any] = None
83-
self._csv_file: Optional[Any] = None
8483

8584
if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
8685
self._init_csv_reader()
@@ -122,8 +121,8 @@ def _init_csv_reader(self) -> None:
122121
path = f"{bucket}/{key}"
123122

124123
try:
125-
self._csv_file = self._fs._open(path, mode="rb")
126-
text_wrapper = TextIOWrapper(self._csv_file, encoding="utf-8")
124+
csv_file = self._fs._open(path, mode="rb")
125+
text_wrapper = TextIOWrapper(csv_file, encoding="utf-8")
127126

128127
if self.output_location.endswith(".txt"):
129128
# Tab-separated format (no header row)
@@ -228,8 +227,7 @@ def fetchall(
228227
def close(self) -> None:
229228
"""Close the result set and release resources."""
230229
super().close()
231-
if self._csv_file:
230+
if self._csv_reader:
232231
with contextlib.suppress(Exception):
233-
self._csv_file.close()
234-
self._csv_file = None
235-
self._csv_reader = None
232+
self._csv_reader.close()
233+
self._csv_reader = None

0 commit comments

Comments
 (0)