Commit ea39ffb

Merge pull request #548 from laughingman7743/#547
Adjusted size of the last part of a multipart request (fix #547)
2 parents: 8c72f89 + 2e48aa2

2 files changed: +77 additions, −31 deletions

2 files changed

+77
-31
lines changed

pyathena/filesystem/s3.py

Lines changed: 48 additions & 22 deletions
@@ -215,12 +215,9 @@ def _head_object(
 
     def _ls_buckets(self, refresh: bool = False) -> List[S3Object]:
         if "" not in self.dircache or refresh:
-            try:
-                response = self._call(
-                    self._client.list_buckets,
-                )
-            except botocore.exceptions.ClientError as e:
-                raise
+            response = self._call(
+                self._client.list_buckets,
+            )
             buckets = [
                 S3Object(
                     init={
@@ -550,6 +547,8 @@ def cp_file(self, path1: str, path2: str, **kwargs):
         bucket2, key2, version_id2 = self.parse_path(path2)
         if version_id2:
             raise ValueError("Cannot copy to a versioned file.")
+        if not key1 or not key2:
+            raise ValueError("Cannot copy buckets.")
 
         info1 = self.info(path1)
         size1 = info1.get("size", 0)
@@ -662,7 +661,7 @@ def _copy_object_with_multipart_upload(
                 }
             )
 
-        parts.sort(key=lambda x: x["PartNumber"])
+        parts.sort(key=lambda x: x["PartNumber"])  # type: ignore
         self._complete_multipart_upload(
             bucket=bucket2,
             key=key2,
@@ -677,14 +676,20 @@ def cat_file(
         if start is not None or end is not None:
             size = self.info(path).get("size", 0)
             if start is None:
-                start = 0
+                range_start = 0
             elif start < 0:
-                start = size + start
+                range_start = size + start
+            else:
+                range_start = start
+
             if end is None:
-                end = size
+                range_end = size
             elif end < 0:
-                end = size + end
-            ranges = (start, end)
+                range_end = size + end
+            else:
+                range_end = end
+
+            ranges = (range_start, range_end)
         else:
             ranges = None
 
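A note on the cat_file hunk above: the rewrite stops mutating the caller's start and end arguments and instead normalizes them into separate range_start and range_end variables, so negative offsets are resolved against the object size without changing the inputs. A minimal sketch of the same normalization, using a hypothetical normalize_range helper that is not part of PyAthena:

def normalize_range(start, end, size):
    # Mirrors the diff: None falls back to the full object, and negative
    # offsets count back from the end, as in Python slicing.
    if start is None:
        range_start = 0
    elif start < 0:
        range_start = size + start
    else:
        range_start = start
    if end is None:
        range_end = size
    elif end < 0:
        range_end = size + end
    else:
        range_end = end
    return range_start, range_end

# For a 100-byte object:
assert normalize_range(None, None, 100) == (0, 100)
assert normalize_range(-10, None, 100) == (90, 100)
assert normalize_range(10, -20, 100) == (10, 80)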
@@ -1082,17 +1087,38 @@ def _upload_chunk(self, final: bool = False) -> bool:
         part_number = len(self.multipart_upload_parts)
         self.buffer.seek(0)
         while data := self.buffer.read(self.blocksize):
-            part_number += 1
-            self.multipart_upload_parts.append(
-                self._executor.submit(
-                    self.fs._upload_part,
-                    bucket=self.bucket,
-                    key=self.key,
-                    upload_id=cast(str, self.multipart_upload.upload_id),
-                    part_number=part_number,
-                    body=data,
+            # The last part of a multipart request should be adjusted
+            # to be larger than the minimum part size.
+            next_data = self.buffer.read(self.blocksize)
+            next_data_size = len(next_data)
+            if 0 < next_data_size < self.fs.MULTIPART_UPLOAD_MIN_PART_SIZE:
+                upload_data = data + next_data
+                upload_data_size = len(upload_data)
+                if upload_data_size < self.fs.MULTIPART_UPLOAD_MAX_PART_SIZE:
+                    uploads = [upload_data]
+                else:
+                    split_size = upload_data_size // 2
+                    uploads = [upload_data[:split_size], upload_data[split_size:]]
+            else:
+                uploads = [data]
+                if next_data:
+                    uploads.append(next_data)
+
+            for upload in uploads:
+                part_number += 1
+                self.multipart_upload_parts.append(
+                    self._executor.submit(
+                        self.fs._upload_part,
+                        bucket=self.bucket,
+                        key=self.key,
+                        upload_id=cast(str, self.multipart_upload.upload_id),
+                        part_number=part_number,
+                        body=upload,
+                    )
                 )
-            )
+
+            if not next_data:
+                break
 
         if self.autocommit and final:
             self.commit()
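The reworked _upload_chunk loop reads one chunk ahead of the current one: if the look-ahead chunk would fall below the minimum multipart part size, it is merged into the current part, and if that merged part would exceed the maximum part size, it is split in half. A rough standalone sketch of that sizing decision, assuming illustrative 5 MiB / 5 GiB limits and a hypothetical plan_uploads helper (the real code uses the filesystem's MULTIPART_UPLOAD_MIN_PART_SIZE and MULTIPART_UPLOAD_MAX_PART_SIZE constants):

MIN_PART_SIZE = 5 * 2**20  # assumed minimum part size (5 MiB)
MAX_PART_SIZE = 5 * 2**30  # assumed maximum part size (5 GiB)

def plan_uploads(data: bytes, next_data: bytes) -> list:
    # Decide how the current chunk and the look-ahead chunk are packed
    # into upload parts, mirroring the branch structure in the diff.
    if 0 < len(next_data) < MIN_PART_SIZE:
        merged = data + next_data
        if len(merged) < MAX_PART_SIZE:
            return [merged]
        # A merged part that is too large is split into two halves.
        half = len(merged) // 2
        return [merged[:half], merged[half:]]
    uploads = [data]
    if next_data:
        uploads.append(next_data)
    return uploads

# A 6 MiB chunk followed by a 1 MiB remainder is sent as one 7 MiB part,
# so no non-final part falls below the minimum size.
parts = plan_uploads(b"x" * (6 * 2**20), b"y" * (2**20))
assert [len(p) for p in parts] == [7 * 2**20]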

tests/pyathena/filesystem/test_s3.py

Lines changed: 29 additions & 9 deletions
@@ -733,18 +733,38 @@ def test_pandas_read_csv(self):
         )
         assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)]
 
-    def test_pandas_write_csv(self):
+    @pytest.mark.parametrize(
+        ["line_count"],
+        [
+            (1 * (2**20),),  # Generates files of about 2 MB.
+            (2 * (2**20),),  # 4MB
+            (3 * (2**20),),  # 6MB
+            (4 * (2**20),),  # 8MB
+            (5 * (2**20),),  # 10MB
+            (6 * (2**20),),  # 12MB
+        ],
+    )
+    def test_pandas_write_csv(self, line_count):
         import pandas
 
-        df = pandas.DataFrame({"a": [1], "b": [2]})
-        path = (
-            f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/"
-            f"filesystem/test_pandas_write_csv/{uuid.uuid4()}.csv"
-        )
-        df.to_csv(path, index=False)
+        with tempfile.NamedTemporaryFile("w") as tmp:
+            tmp.write("col1")
+            tmp.write("\n")
+            for i in range(0, line_count):
+                tmp.write("a")
+                tmp.write("\n")
+            tmp.flush()
+
+            tmp.seek(0)
+            df = pandas.read_csv(tmp.name)
+            path = (
+                f"s3://{ENV.s3_staging_bucket}/{ENV.s3_staging_key}{ENV.schema}/"
+                f"filesystem/test_pandas_write_csv/{uuid.uuid4()}.csv"
+            )
+            df.to_csv(path, index=False)
 
-        actual = pandas.read_csv(path)
-        pandas.testing.assert_frame_equal(df, actual)
+            actual = pandas.read_csv(path)
+            pandas.testing.assert_frame_equal(actual, df)
 
 
 class TestS3File:
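The parametrized line counts in the updated test are chosen so the generated CSV grows in roughly 2 MB steps, matching the inline comments. A quick sanity check of those sizes, assuming each data row serializes as the two bytes "a\n":

header = len("col1\n")
for n in (1, 2, 3, 4, 5, 6):
    line_count = n * (2**20)
    size = header + line_count * len("a\n")
    print(f"{line_count} lines -> about {size / 2**20:.0f} MiB")
# 1048576 lines -> about 2 MiB, ..., 6291456 lines -> about 12 MiB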
