
Commit 4fbced8

Support for writing with s3 file system (#539)

* Support for writing with s3 filesystem (fix #465)
* Fix test cases
* Apply fmt
* Remove unnecessary vars
* Fix fixture args
* Add test cases for file writing
* Add type hints

1 parent ee7748e commit 4fbced8
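
This commit adds write support to PyAthena's fsspec-compatible S3 filesystem. As a minimal usage sketch grounded in the register_filesystem fixture and the write tests added below: the bucket and key are placeholders, and credential discovery is assumed to follow the usual boto3 chain, which I have not verified against s3.py.

import fsspec

# Route s3:// URLs to PyAthena's implementation, as the new
# register_filesystem fixture does (clobber=True overrides any
# previously registered handler such as s3fs).
fsspec.register_implementation(
    "s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True
)

# Placeholder bucket/key; any fsspec-aware API can now write to S3,
# mirroring what test_pandas_write_csv exercises via pandas.
with fsspec.open("s3://my-bucket/path/to/output.dat", "wb") as f:
    f.write(b"0123456789")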

File tree

9 files changed: +1486 −409 lines changed

pyathena/filesystem/s3.py

Lines changed: 485 additions & 125 deletions
Large diffs are not rendered by default.
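
The s3.py diff is where the write path lands, but it is not rendered here. The following is only a rough, hypothetical sketch of how fsspec write support is typically structured (an AbstractBufferedFile subclass flushing its buffer through S3 multipart uploads); none of these names are taken from the actual diff, and self.fs.client / self.bucket / self.key are assumed attributes.

from fsspec.spec import AbstractBufferedFile


class SketchS3File(AbstractBufferedFile):
    """Illustrative only; not the real S3File in this commit."""

    def _initiate_upload(self):
        # fsspec calls this once buffered writes exceed the blocksize;
        # start a multipart upload and remember its id.
        mpu = self.fs.client.create_multipart_upload(
            Bucket=self.bucket, Key=self.key  # hypothetical attributes
        )
        self.mpu_id = mpu["UploadId"]
        self.parts = []

    def _upload_chunk(self, final=False):
        # Flush the current buffer as the next part; on the final call,
        # complete the multipart upload with the collected ETags.
        part_number = len(self.parts) + 1
        response = self.fs.client.upload_part(
            Bucket=self.bucket,
            Key=self.key,
            PartNumber=part_number,
            UploadId=self.mpu_id,
            Body=self.buffer.getvalue(),
        )
        self.parts.append({"PartNumber": part_number, "ETag": response["ETag"]})
        if final:
            self.fs.client.complete_multipart_upload(
                Bucket=self.bucket,
                Key=self.key,
                UploadId=self.mpu_id,
                MultipartUpload={"Parts": self.parts},
            )
        return True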

pyathena/filesystem/s3_object.py

Lines changed: 408 additions & 30 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -136,7 +136,6 @@ profile_file = "tests/sqlalchemy/profiles.txt"
 line-length = 100
 exclude = [
     ".venv",
-    "tests"
 ]
 target-version = "py38"

tests/pyathena/filesystem/test_s3.py

Lines changed: 110 additions & 16 deletions

@@ -1,14 +1,22 @@
 # -*- coding: utf-8 -*-
+import uuid
 from itertools import chain
-from typing import Dict

+import fsspec
 import pytest

 from pyathena.filesystem.s3 import S3File, S3FileSystem
 from tests import ENV
 from tests.pyathena.conftest import connect


+@pytest.fixture(scope="class")
+def register_filesystem():
+    fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
+    fsspec.register_implementation("s3a", "pyathena.filesystem.s3.S3FileSystem", clobber=True)
+
+
+@pytest.mark.usefixtures("register_filesystem")
 class TestS3FileSystem:
     def test_parse_path(self):
         actual = S3FileSystem.parse_path("s3://bucket")
@@ -109,35 +117,34 @@ def test_parse_path_invalid(self):
         S3FileSystem.parse_path("s3a://bucket/path/to/obj?foo=bar")

     @pytest.fixture(scope="class")
-    def fs(self) -> Dict[str, S3FileSystem]:
-        fs = {
-            "default": S3FileSystem(connect()),
-            "small_batches": S3FileSystem(connect(), default_block_size=3),
-        }
-        return fs
+    def fs(self, request):
+        if not hasattr(request, "param"):
+            setattr(request, "param", {})
+        return S3FileSystem(connect(), **request.param)

     @pytest.mark.parametrize(
-        ["start", "end", "batch_mode", "target_data"],
+        ["fs", "start", "end", "target_data"],
         list(
             chain(
                 *[
                     [
-                        (0, 5, x, b"01234"),
-                        (2, 7, x, b"23456"),
-                        (0, 10, x, b"0123456789"),
+                        ({"default_block_size": x}, 0, 5, b"01234"),
+                        ({"default_block_size": x}, 2, 7, b"23456"),
+                        ({"default_block_size": x}, 0, 10, b"0123456789"),
                     ]
-                    for x in ("default", "small_batches")
+                    for x in (S3FileSystem.DEFAULT_BLOCK_SIZE, 3)
                 ]
             )
         ),
+        indirect=["fs"],
     )
-    def test_read(self, fs, start, end, batch_mode, target_data):
+    def test_read(self, fs, start, end, target_data):
         # lowest level access: use _get_object
-        data = fs[batch_mode]._get_object(
+        data = fs._get_object(
             ENV.s3_staging_bucket, ENV.s3_filesystem_test_file_key, ranges=(start, end)
         )
         assert data == (start, target_data), data
-        with fs[batch_mode].open(
+        with fs.open(
             f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}", "rb"
         ) as file:
             # mid-level access: use _fetch_range
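
The reworked fs fixture relies on pytest's indirect parametrization: with indirect=["fs"], each parameter dict is delivered to the fixture as request.param instead of being passed straight to the test function. A self-contained illustration of the pattern follows; the names here are generic stand-ins, not taken from this diff.

import pytest


@pytest.fixture
def fs(request):
    # With indirect parametrization, pytest sets request.param before
    # invoking the fixture; fall back to defaults when the test is not
    # parametrized, as the updated fixture above does.
    param = getattr(request, "param", {})
    return {"default_block_size": 5 * 2**20, **param}  # stand-in object


@pytest.mark.parametrize("fs", [{"default_block_size": 3}], indirect=["fs"])
def test_small_block_size(fs):
    assert fs["default_block_size"] == 3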
@@ -148,7 +155,84 @@ def test_read(self, fs, start, end, batch_mode, target_data):
             data = file.read(end - start)
             assert data == target_data, data

-    def test_compatibility_with_s3fs(self):
+    @pytest.mark.parametrize(
+        ["base", "exp"],
+        [
+            (1, 2**10),
+            (10, 2**10),
+            (100, 2**10),
+            (1, 2**20),
+            (10, 2**20),
+            (100, 2**20),
+            (1000, 2**20),
+        ],
+    )
+    def test_write(self, fs, base, exp):
+        data = b"a" * (base * exp)
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
+        with fs.open(path, "wb") as f:
+            f.write(data)
+        with fs.open(path, "rb") as f:
+            actual = f.read()
+        assert len(actual) == len(data)
+        assert actual == data
+
+    @pytest.mark.parametrize(
+        ["base", "exp"],
+        [
+            (1, 2**10),
+            (10, 2**10),
+            (100, 2**10),
+            (1, 2**20),
+            (10, 2**20),
+            (100, 2**20),
+            (1000, 2**20),
+        ],
+    )
+    def test_append(self, fs, base, exp):
+        # TODO: Check the metadata is kept.
+        data = b"a" * (base * exp)
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.dat"
+        with fs.open(path, "ab") as f:
+            f.write(data)
+        extra = b"extra"
+        with fs.open(path, "ab") as f:
+            f.write(extra)
+        with fs.open(path, "rb") as f:
+            actual = f.read()
+        assert len(actual) == len(data + extra)
+        assert actual == data + extra
+
+    def test_exists(self, fs):
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_filesystem_test_file_key}"
+        assert fs.exists(path)
+
+        not_exists_path = (
+            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
+        )
+        assert not fs.exists(not_exists_path)
+
+    def test_touch(self, fs):
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}"
+        assert not fs.exists(path)
+        fs.touch(path)
+        assert fs.exists(path)
+        assert fs.size(path) == 0
+
+        with fs.open(path, "wb") as f:
+            f.write(b"data")
+        assert fs.size(path) == 4
+        fs.touch(path, truncate=True)
+        assert fs.size(path) == 0
+
+        with fs.open(path, "wb") as f:
+            f.write(b"data")
+        assert fs.size(path) == 4
+        with pytest.raises(ValueError):
+            fs.touch(path, truncate=False)
+        assert fs.size(path) == 4
+
+    def test_pandas_read_csv(self):
         import pandas

         df = pandas.read_csv(
@@ -158,6 +242,16 @@ def test_compatibility_with_s3fs(self):
         )
         assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)]

+    def test_pandas_write_csv(self):
+        import pandas
+
+        df = pandas.DataFrame({"a": [1], "b": [2]})
+        path = f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/{uuid.uuid4()}.csv"
+        df.to_csv(path, index=False)
+
+        actual = pandas.read_csv(path)
+        pandas.testing.assert_frame_equal(df, actual)
+

 class TestS3File:
     @pytest.mark.parametrize(
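
Taken together, the new tests double as usage documentation. A condensed sketch of writing and reading back with the filesystem class directly; connect() is the test-suite helper seen in the fs fixture, and the path is a placeholder.

from pyathena.filesystem.s3 import S3FileSystem
from tests.pyathena.conftest import connect  # test helper, as in the fs fixture

fs = S3FileSystem(connect())
path = "s3://my-bucket/some/prefix/example.dat"  # placeholder

# 10 MiB of data; presumably large enough to exercise the multipart
# write path rather than a single-shot upload (an assumption).
data = b"a" * (10 * 2**20)
with fs.open(path, "wb") as f:
    f.write(data)
with fs.open(path, "rb") as f:
    assert f.read() == data

# touch() semantics verified by test_touch above: truncate=True resets
# an existing object to zero bytes, while truncate=False on an existing
# object raises ValueError.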
