Skip to content

Commit a2c3c73

Browse files
Implement Iterator protocol and resource management for CSV readers
- DefaultCSVReader and AthenaCSVReader now inherit from collections.abc.Iterator
- Added close() method to properly release file resources
- Added context manager support (__enter__/__exit__)
- Updated ResultSet to use reader's close() method (no exception suppression)
- Removed _csv_file instance variable from ResultSet (reader manages file lifecycle)
- Removed unnecessary contextlib.suppress (length check already handles empty files)
- Added unit tests for Iterator protocol, close(), and context manager
- No _closed flag needed - use _file is None check instead

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent a2d9c20 commit a2c3c73

File tree

3 files changed: 97 additions and 19 deletions

pyathena/s3fs/reader.py

Lines changed: 42 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,11 @@
22
from __future__ import annotations
33

44
import csv
5-
from typing import Any, Iterator, List, Optional, Tuple
5+
from collections.abc import Iterator
6+
from typing import Any, List, Optional, Tuple
67

78

8-
class DefaultCSVReader:
9+
class DefaultCSVReader(Iterator[List[str]]):
910
"""CSV reader using Python's standard csv module.
1011
1112
This reader wraps Python's standard csv.reader and treats empty fields
@@ -33,9 +34,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
3334
file_obj: File-like object to read from.
3435
delimiter: Field delimiter character.
3536
"""
37+
self._file: Optional[Any] = file_obj
3638
self._reader = csv.reader(file_obj, delimiter=delimiter)
3739

38-
def __iter__(self) -> Iterator[List[str]]:
40+
def __iter__(self) -> "DefaultCSVReader":
3941
"""Iterate over rows in the CSV file."""
4042
return self
4143

@@ -46,17 +48,33 @@ def __next__(self) -> List[str]:
4648
List of field values as strings.
4749
4850
Raises:
49-
StopIteration: When end of file is reached.
51+
StopIteration: When end of file is reached or reader is closed.
5052
"""
53+
if self._file is None:
54+
raise StopIteration
5155
row = next(self._reader)
5256
# Python's csv.reader returns [] for empty lines; normalize to ['']
5357
# to represent a single empty field (consistent with single-value handling)
5458
if not row:
5559
return [""]
5660
return row
5761

62+
def close(self) -> None:
63+
"""Close the underlying file object."""
64+
if self._file is not None:
65+
self._file.close()
66+
self._file = None
67+
68+
def __enter__(self) -> "DefaultCSVReader":
69+
"""Enter context manager."""
70+
return self
71+
72+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
73+
"""Exit context manager and close resources."""
74+
self.close()
75+
5876

59-
class AthenaCSVReader:
77+
class AthenaCSVReader(Iterator[List[Optional[str]]]):
6078
"""CSV reader that distinguishes between NULL and empty string.
6179
6280
This is the default reader for S3FSCursor.
@@ -87,10 +105,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
87105
file_obj: File-like object to read from.
88106
delimiter: Field delimiter character.
89107
"""
90-
self._file = file_obj
108+
self._file: Optional[Any] = file_obj
91109
self._delimiter = delimiter
92110

93-
def __iter__(self) -> Iterator[List[Optional[str]]]:
111+
def __iter__(self) -> "AthenaCSVReader":
94112
"""Iterate over rows in the CSV file."""
95113
return self
96114

@@ -101,8 +119,10 @@ def __next__(self) -> List[Optional[str]]:
101119
List of field values, with None for NULL and '' for empty string.
102120
103121
Raises:
104-
StopIteration: When end of file is reached.
122+
StopIteration: When end of file is reached or reader is closed.
105123
"""
124+
if self._file is None:
125+
raise StopIteration
106126
line = self._file.readline()
107127
if not line:
108128
raise StopIteration
@@ -234,3 +254,17 @@ def _parse_unquoted_field(self, line: str, pos: int) -> Tuple[str, int]:
234254
pos += 1
235255

236256
return value, pos
257+
258+
def close(self) -> None:
259+
"""Close the underlying file object."""
260+
if self._file is not None:
261+
self._file.close()
262+
self._file = None
263+
264+
def __enter__(self) -> "AthenaCSVReader":
265+
"""Enter context manager."""
266+
return self
267+
268+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
269+
"""Exit context manager and close resources."""
270+
self.close()

pyathena/s3fs/result_set.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4-
import contextlib
54
import logging
65
from io import TextIOWrapper
76
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -80,7 +79,6 @@ def __init__(
8079
self._csv_reader_class: CSVReaderType = csv_reader or AthenaCSVReader
8180
self._fs = self._create_s3_file_system()
8281
self._csv_reader: Optional[Any] = None
83-
self._csv_file: Optional[Any] = None
8482

8583
if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
8684
self._init_csv_reader()
@@ -122,17 +120,16 @@ def _init_csv_reader(self) -> None:
122120
path = f"{bucket}/{key}"
123121

124122
try:
125-
self._csv_file = self._fs._open(path, mode="rb")
126-
text_wrapper = TextIOWrapper(self._csv_file, encoding="utf-8")
123+
csv_file = self._fs._open(path, mode="rb")
124+
text_wrapper = TextIOWrapper(csv_file, encoding="utf-8")
127125

128126
if self.output_location.endswith(".txt"):
129127
# Tab-separated format (no header row)
130128
self._csv_reader = self._csv_reader_class(text_wrapper, delimiter="\t")
131129
else:
132130
# Standard CSV format (has header row, skip it)
133131
self._csv_reader = self._csv_reader_class(text_wrapper, delimiter=",")
134-
with contextlib.suppress(StopIteration):
135-
next(self._csv_reader)
132+
next(self._csv_reader)
136133

137134
except Exception as e:
138135
_logger.exception(f"Failed to open {path}.")
@@ -228,8 +225,6 @@ def fetchall(
228225
def close(self) -> None:
229226
"""Close the result set and release resources."""
230227
super().close()
231-
if self._csv_file:
232-
with contextlib.suppress(Exception):
233-
self._csv_file.close()
234-
self._csv_file = None
235-
self._csv_reader = None
228+
if self._csv_reader:
229+
self._csv_reader.close()
230+
self._csv_reader = None

tests/pyathena/s3fs/test_reader.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
from collections.abc import Iterator
23
from io import StringIO
34

45
from pyathena.s3fs.reader import AthenaCSVReader, DefaultCSVReader
@@ -60,6 +61,30 @@ def test_empty_line(self):
6061
rows = list(reader)
6162
assert rows == [[""]]
6263

64+
def test_implements_iterator_protocol(self):
65+
"""DefaultCSVReader implements collections.abc.Iterator."""
66+
data = StringIO("a,b\n")
67+
reader = DefaultCSVReader(data, delimiter=",")
68+
assert isinstance(reader, Iterator)
69+
70+
def test_close(self):
71+
"""close() releases file resources."""
72+
data = StringIO("a,b\n")
73+
reader = DefaultCSVReader(data, delimiter=",")
74+
reader.close()
75+
# After close, iteration should stop immediately
76+
rows = list(reader)
77+
assert rows == []
78+
79+
def test_context_manager(self):
80+
"""Reader can be used as context manager."""
81+
data = StringIO("a,b\n1,2\n")
82+
with DefaultCSVReader(data, delimiter=",") as reader:
83+
rows = list(reader)
84+
assert rows == [["a", "b"], ["1", "2"]]
85+
# After exiting context, reader should be closed
86+
assert list(reader) == []
87+
6388

6489
class TestAthenaCSVReader:
6590
"""Tests for AthenaCSVReader that distinguishes NULL from empty string."""
@@ -209,3 +234,27 @@ def test_multiple_rows_with_multiline_field(self):
209234
reader = AthenaCSVReader(data, delimiter=",")
210235
rows = list(reader)
211236
assert rows == [["a", "b"], ["multi\nline", "c"], ["d", "e"]]
237+
238+
def test_implements_iterator_protocol(self):
239+
"""AthenaCSVReader implements collections.abc.Iterator."""
240+
data = StringIO("a,b\n")
241+
reader = AthenaCSVReader(data, delimiter=",")
242+
assert isinstance(reader, Iterator)
243+
244+
def test_close(self):
245+
"""close() releases file resources."""
246+
data = StringIO("a,b\n")
247+
reader = AthenaCSVReader(data, delimiter=",")
248+
reader.close()
249+
# After close, iteration should stop immediately
250+
rows = list(reader)
251+
assert rows == []
252+
253+
def test_context_manager(self):
254+
"""Reader can be used as context manager."""
255+
data = StringIO("a,b\n1,2\n")
256+
with AthenaCSVReader(data, delimiter=",") as reader:
257+
rows = list(reader)
258+
assert rows == [["a", "b"], ["1", "2"]]
259+
# After exiting context, reader should be closed
260+
assert list(reader) == []

0 commit comments

Comments
 (0)