Commit 11bbbed

Bugfix: If OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG passed then explicitly override content encoding (#505)
1 parent 6f6ff83 commit 11bbbed
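
In practice, the fix means a caller can force s3_file_to_parquet to treat a mislabeled Parquet file as identity-encoded by passing OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG through the PyArrow kwargs provider. A minimal usage sketch mirroring the new tests below; the S3 URL is hypothetical, and the deltacat.types.media import path for ContentType/ContentEncoding is an assumption:

from deltacat.types.media import ContentEncoding, ContentType  # assumed import path
from deltacat.utils.pyarrow import (
    OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG,
    s3_file_to_parquet,
)

# The provider merges the override into the reader kwargs; s3_file_to_parquet
# pops it back out before the remaining kwargs reach ParquetFile.
pa_kwargs_provider = lambda content_type, kwargs: {
    "reader_type": "pyarrow",
    OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG: ContentEncoding.IDENTITY.value,
    **kwargs,
}

parquet_file = s3_file_to_parquet(
    "s3://my-bucket/mislabeled.parquet",  # hypothetical URL
    ContentType.PARQUET.value,
    ContentEncoding.GZIP.value,  # stored metadata says GZIP; override treats it as IDENTITY
    ["n_legs", "animal"],  # column arguments, passed positionally as in the new tests
    ["n_legs"],
    pa_read_func_kwargs_provider=pa_kwargs_provider,
)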

File tree

2 files changed: +117 -9 lines changed

deltacat/tests/utils/test_pyarrow.py (+106 -4)
@@ -2,9 +2,12 @@
 from deltacat.utils.pyarrow import (
     s3_partial_parquet_file_to_table,
     pyarrow_read_csv,
+    ContentTypeValidationError,
     content_type_to_reader_kwargs,
     _add_column_kwargs,
+    logger,
     s3_file_to_table,
+    s3_file_to_parquet,
     ReadKwargsProviderPyArrowSchemaOverride,
     RAISE_ON_EMPTY_CSV_KWARG,
     RAISE_ON_DECIMAL_OVERFLOW,
@@ -435,7 +438,7 @@ def test_read_csv_when_decimal_precision_overflows_and_raise_kwarg_specified(sel
             pa.lib.ArrowInvalid,
             lambda: pyarrow_read_csv(
                 OVERFLOWING_DECIMAL_PRECISION_UTSV_PATH,
-                **{**kwargs, RAISE_ON_DECIMAL_OVERFLOW: True}
+                **{**kwargs, RAISE_ON_DECIMAL_OVERFLOW: True},
             ),
         )
 
@@ -479,7 +482,7 @@ def test_read_csv_when_decimal_scale_overflows_and_raise_kwarg_specified(self):
             pa.lib.ArrowInvalid,
             lambda: pyarrow_read_csv(
                 OVERFLOWING_DECIMAL_SCALE_UTSV_PATH,
-                **{**kwargs, RAISE_ON_DECIMAL_OVERFLOW: True}
+                **{**kwargs, RAISE_ON_DECIMAL_OVERFLOW: True},
             ),
         )
 
@@ -590,7 +593,7 @@ def test_read_csv_when_decimal_scale_overflows_with_decimal256_and_raise_on_over
             pa.lib.ArrowNotImplementedError,
             lambda: pyarrow_read_csv(
                 OVERFLOWING_DECIMAL_SCALE_UTSV_PATH,
-                **{**kwargs, RAISE_ON_DECIMAL_OVERFLOW: True}
+                **{**kwargs, RAISE_ON_DECIMAL_OVERFLOW: True},
             ),
         )
 
@@ -818,8 +821,11 @@ def test_s3_file_to_table_when_utsv_gzip_and_content_type_overridden(self):
         schema = pa.schema(
             [("is_active", pa.string()), ("ship_datetime_utc", pa.timestamp("us"))]
         )
-
         # OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG has no effect on uTSV files
+        pa_kwargs_provider = lambda content_type, kwargs: {
+            "reader_type": "pyarrow",
+            **kwargs,
+        }
         pa_kwargs_provider = lambda content_type, kwargs: {
             "reader_type": "pyarrow",
             OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG: ContentEncoding.IDENTITY.value,
@@ -864,3 +870,99 @@ def test_s3_file_to_table_when_parquet_gzip_and_encoding_overridden(self):
         schema = result.schema
         schema_index = schema.get_field_index("n_legs")
         self.assertEqual(schema.field(schema_index).type, "int64")
+
+
+class TestS3FileToParquet(TestCase):
+    def test_s3_file_to_parquet_sanity(self):
+        test_s3_url = PARQUET_FILE_PATH
+        test_content_type = ContentType.PARQUET.value
+        test_content_encoding = ContentEncoding.IDENTITY.value
+        pa_kwargs_provider = lambda content_type, kwargs: {
+            "reader_type": "pyarrow",
+            **kwargs,
+        }
+        with self.assertLogs(logger=logger.name, level="DEBUG") as cm:
+            result_parquet_file: ParquetFile = s3_file_to_parquet(
+                test_s3_url,
+                test_content_type,
+                test_content_encoding,
+                ["n_legs", "animal"],
+                ["n_legs"],
+                pa_read_func_kwargs_provider=pa_kwargs_provider,
+            )
+        log_message_log_args = cm.records[0].getMessage()
+        log_message_presanitize_kwargs = cm.records[1].getMessage()
+        self.assertIn(
+            f"Reading {test_s3_url} to PyArrow ParquetFile. Content type: {test_content_type}. Encoding: {test_content_encoding}",
+            log_message_log_args,
+        )
+        self.assertIn("{'reader_type': 'pyarrow'}", log_message_presanitize_kwargs)
+        for index, field in enumerate(result_parquet_file.schema_arrow):
+            self.assertEqual(
+                field.name, result_parquet_file.schema_arrow.field(index).name
+            )
+        self.assertEqual(result_parquet_file.schema_arrow.field(0).type, "int64")
+
+    def test_s3_file_to_parquet_when_parquet_gzip_encoding_and_overridden_returns_success(
+        self,
+    ):
+        test_s3_url = PARQUET_FILE_PATH
+        test_content_type = ContentType.PARQUET.value
+        test_content_encoding = ContentEncoding.GZIP.value
+        pa_kwargs_provider = lambda content_type, kwargs: {
+            "reader_type": "pyarrow",
+            OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG: ContentEncoding.IDENTITY.value,
+            **kwargs,
+        }
+        with self.assertLogs(logger=logger.name, level="DEBUG") as cm:
+            result_parquet_file: ParquetFile = s3_file_to_parquet(
+                test_s3_url,
+                test_content_type,
+                test_content_encoding,
+                ["n_legs", "animal"],
+                ["n_legs"],
+                pa_read_func_kwargs_provider=pa_kwargs_provider,
+            )
+        log_message_log_args = cm.records[0].getMessage()
+        log_message_log_new_content_encoding = cm.records[1].getMessage()
+        log_message_presanitize_kwargs = cm.records[2].getMessage()
+        self.assertIn(
+            f"Reading {test_s3_url} to PyArrow ParquetFile. Content type: {test_content_type}. Encoding: {test_content_encoding}",
+            log_message_log_args,
+        )
+        self.assertIn(
+            f"Overriding {test_s3_url} content encoding from {ContentEncoding.GZIP.value} to {ContentEncoding.IDENTITY.value}",
+            log_message_log_new_content_encoding,
+        )
+        self.assertIn("{'reader_type': 'pyarrow'}", log_message_presanitize_kwargs)
+        for index, field in enumerate(result_parquet_file.schema_arrow):
+            self.assertEqual(
+                field.name, result_parquet_file.schema_arrow.field(index).name
+            )
+        self.assertEqual(result_parquet_file.schema_arrow.field(0).type, "int64")
+
+    def test_s3_file_to_parquet_when_parquet_gzip_encoding_not_overridden_throws_error(
+        self,
+    ):
+        test_s3_url = PARQUET_FILE_PATH
+        test_content_type = ContentType.PARQUET.value
+        test_content_encoding = ContentEncoding.GZIP.value
+        pa_kwargs_provider = lambda content_type, kwargs: {
+            "reader_type": "pyarrow",
+            **kwargs,
+        }
+        with self.assertRaises(ContentTypeValidationError):
+            with self.assertLogs(logger=logger.name, level="DEBUG") as cm:
+                s3_file_to_parquet(
+                    test_s3_url,
+                    test_content_type,
+                    test_content_encoding,
+                    ["n_legs", "animal"],
+                    ["n_legs"],
+                    pa_read_func_kwargs_provider=pa_kwargs_provider,
+                )
+        log_message_log_args = cm.records[0].getMessage()
+        self.assertIn(
+            f"Reading {test_s3_url} to PyArrow ParquetFile. Content type: {test_content_type}. Encoding: {test_content_encoding}",
+            log_message_log_args,
+        )
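
The new tests assert on the module's DEBUG output via unittest's assertLogs context manager, which captures records in emission order and lets each be rendered with getMessage(). A self-contained sketch of the capture pattern, with a hypothetical logger name standing in for deltacat's module logger:

import logging
import unittest

logger = logging.getLogger("example.pyarrow")  # hypothetical stand-in logger


class LogCaptureExample(unittest.TestCase):
    def test_captures_debug_records_in_order(self):
        with self.assertLogs(logger=logger.name, level="DEBUG") as cm:
            logger.debug("Reading s3://bucket/key to PyArrow ParquetFile.")
            logger.debug("Pre-sanitize kwargs: %s", {"reader_type": "pyarrow"})
        # cm.records holds LogRecord objects in emission order; getMessage()
        # renders each record's args into the final message string.
        self.assertIn("ParquetFile", cm.records[0].getMessage())
        self.assertIn("reader_type", cm.records[1].getMessage())


if __name__ == "__main__":
    unittest.main()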

deltacat/utils/pyarrow.py (+11 -5)
@@ -617,7 +617,18 @@ def s3_file_to_parquet(
         f"Reading {s3_url} to PyArrow ParquetFile. "
         f"Content type: {content_type}. Encoding: {content_encoding}"
     )
+    kwargs = {}
+    if pa_read_func_kwargs_provider:
+        kwargs = pa_read_func_kwargs_provider(content_type, kwargs)
 
+    if OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG in kwargs:
+        new_content_encoding = kwargs.pop(OVERRIDE_CONTENT_ENCODING_FOR_PARQUET_KWARG)
+        if content_type == ContentType.PARQUET.value:
+            logger.debug(
+                f"Overriding {s3_url} content encoding from {content_encoding} "
+                f"to {new_content_encoding}"
+            )
+            content_encoding = new_content_encoding
     if (
         content_type != ContentType.PARQUET.value
         or content_encoding != ContentEncoding.IDENTITY
@@ -630,15 +641,10 @@
     if s3_client_kwargs is None:
         s3_client_kwargs = {}
 
-    kwargs = {}
-
     if s3_url.startswith("s3://"):
         s3_file_system = create_s3_file_system(s3_client_kwargs)
         kwargs["filesystem"] = s3_file_system
 
-    if pa_read_func_kwargs_provider:
-        kwargs = pa_read_func_kwargs_provider(content_type, kwargs)
-
     logger.debug(f"Pre-sanitize kwargs for {s3_url}: {kwargs}")
 
     kwargs = sanitize_kwargs_to_callable(ParquetFile.__init__, kwargs)