
Commit 5d4543a

Feat: improve hive partitioning detection for parquet reader
1 parent b67d976 commit 5d4543a

4 files changed: +206 additions, -83 deletions

aodn_cloud_optimised/bin/create_dataset_config.py

Lines changed: 40 additions & 5 deletions
@@ -817,10 +817,13 @@ def main():
         case ".parquet":
             dataset_config_schema = dict()

+            # TODO: at this stage, we don't know yet if it's a hive or single parquet file. Could add another option in the create_dataset_config script for parquet only.
             try:
                 # Try reading as a single Parquet file
                 with fs.open(fp, "rb") as f:
                     schema = pq.read_schema(f)
+
+                parquet_partitioning = None
             except Exception:
                 # If that fails, assume it's a Hive-partitioned dataset

@@ -834,12 +837,18 @@ def main():
                     dataset_path, format="parquet", partitioning="hive", filesystem=fs
                 )
                 schema = dataset.schema
+                parquet_partitioning = "hive"

             dataset_config_schema = dict()
             for field in schema:
                 dataset_config_schema[field.name] = {"type": str(field.type)}

-            regex_filter = [".*\\.parquet$"]
+        case ".zarr":
+            # TODO: implement a zarr reader
+
+            raise NotImplementedError(
+                f"input file type `{obj_key_suffix}` not yet implemented"
+            )

         # Default: Raise NotImplemented
         case _:
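
The hunks above decide between a single Parquet file and a Hive-partitioned dataset by first trying to read the Parquet footer, then falling back to pyarrow's dataset discovery. A minimal standalone sketch of that detection idea, assuming an s3fs filesystem; the bucket, key, and anon=True credentials are placeholders, not values from the script:

import pyarrow.dataset as pds
import pyarrow.parquet as pq
import s3fs

fs = s3fs.S3FileSystem(anon=True)  # illustrative credentials
fp = "example-bucket/products/some_dataset.parquet"  # hypothetical key

try:
    # A single Parquet file exposes its schema directly from the footer.
    with fs.open(fp, "rb") as f:
        schema = pq.read_schema(f)
    parquet_partitioning = None
except Exception:
    # A Hive-partitioned dataset is a directory of key=value/ folders, so
    # opening the root as one object fails; discover it as a pyarrow dataset.
    dataset = pds.dataset(fp, format="parquet", partitioning="hive", filesystem=fs)
    schema = dataset.schema
    parquet_partitioning = "hive"

dataset_config_schema = {field.name: {"type": str(field.type)} for field in schema}
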
@@ -886,10 +895,36 @@ def main():
         "mode": f"{TO_REPLACE_PLACEHOLDER}",
         "restart_every_path": False,
     }
-    parent_s3_path = PureS3Path.from_uri(fp).parent.as_uri()
-    dataset_config["run_settings"]["paths"] = [
-        {"s3_uri": parent_s3_path, "filter": regex_filter, "year_range": []}
-    ]
+
+    match obj_key_suffix:
+        case ".nc" | ".csv":
+            parent_s3_path = PureS3Path.from_uri(fp).parent.as_uri()
+            dataset_config["run_settings"]["paths"] = [
+                {
+                    "type": "files",
+                    "s3_uri": parent_s3_path,
+                    "filter": regex_filter,
+                    "year_range": [],
+                }
+            ]
+        case ".zarr":
+            # TODO: partially implemented
+            parent_s3_path = PureS3Path.from_uri(fp).as_uri()
+            dataset_config["run_settings"]["paths"] = [
+                {
+                    "type": "zarr",
+                    "s3_uri": parent_s3_path,
+                }
+            ]
+        case ".parquet":
+            parent_s3_path = PureS3Path.from_uri(fp).as_uri()
+            dataset_config["run_settings"]["paths"] = [
+                {
+                    "type": "parquet",
+                    "partitioning": parquet_partitioning,
+                    "s3_uri": parent_s3_path,
+                }
+            ]

     if args.s3fs_opts:
         dataset_config.setdefault("run_settings", {})["s3_bucket_opts"] = {
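
For reference, the three shapes of run_settings["paths"] entries the new match statement can produce, written as Python literals with placeholder URIs (not values from the repository):

paths_for_files = [
    {
        "type": "files",
        "s3_uri": "s3://example-bucket/raw/dataset",
        "filter": [".*\\.nc$"],
        "year_range": [],
    }
]
paths_for_zarr = [
    {"type": "zarr", "s3_uri": "s3://example-bucket/products/dataset.zarr"}
]
paths_for_parquet = [
    {
        "type": "parquet",
        "partitioning": "hive",  # or None when the input is a single Parquet file
        "s3_uri": "s3://example-bucket/products/dataset.parquet",
    }
]
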

aodn_cloud_optimised/bin/generic_cloud_optimised_creation.py

Lines changed: 131 additions & 60 deletions
@@ -72,11 +72,22 @@ class PathConfig(BaseModel):

     Attributes:
         s3_uri: S3 URI as a POSIX path string.
-        filter: List of regex patterns to filter files.
-        year_range: Year filter: None, one year, or a two-year inclusive range, or a list of exclusive years to process.
+        type: Type of dataset. Can be "files", "parquet", or "zarr".
+        partitioning: Optional, used only for Parquet datasets (e.g., "hive").
+        filter: List of regex patterns to filter files (only valid for type="files").
+        year_range: Optional Year filter: None, one year, or a two-year inclusive range, or a list of exclusive years to process. (only valid for type="files")
+
     """

     s3_uri: str
+    type: Optional[Literal["files", "parquet", "zarr"]] = Field(
+        default=None,
+        description="Dataset type. One of 'files', 'parquet', or 'zarr'. Defaults to 'files' if not specified.",
+    )
+    partitioning: Optional[str] = Field(
+        default=None,
+        description="Partitioning scheme, only valid when type='parquet'. Currently supports 'hive'.",
+    )
     filter: List[str] = Field(
         default_factory=list,
         description="List of regular expression patterns used to filter matching files.",
@@ -152,6 +163,55 @@ def validate_regex(cls, v):
                 raise ValueError(f"Invalid regex: {pattern} ({e})")
         return v

+    @model_validator(mode="after")
+    def validate_cross_fields(cls, values):
+        dataset_type = values.type or "files"
+        if values.type is None:
+            warnings.warn(
+                "No 'type' specified in PathConfig. Assuming 'files' as default.",
+                UserWarning,
+                stacklevel=2,
+            )
+            values.type = "files"
+            if (
+                any(".parquet" in f for f in values.filter)
+                or ".parquet" in values.s3_uri
+            ):
+                raise ValueError(
+                    "type must be defined as 'parquet' in run_settings.paths config if ingesting a parquet dataset."
+                )
+            elif any(".zarr" in f for f in values.filter) or ".zarr" in values.s3_uri:
+                raise ValueError(
+                    "type must be defined as 'zarr' in run_settings.paths config if ingesting a zarr dataset."
+                )
+
+        if dataset_type == "parquet":
+            if values.filter:
+                raise ValueError("filter must not be defined when type='parquet'")
+            if values.year_range:
+                raise ValueError("year_range must not be defined when type='parquet'")
+            if values.partitioning not in (None, "hive"):
+                raise ValueError(
+                    f"Invalid partitioning='{values.partitioning}' for parquet dataset. Only 'hive' is supported."
+                )
+
+        elif dataset_type == "zarr":
+            if values.filter:
+                raise ValueError("filter must not be defined when type='zarr'")
+            if values.year_range:
+                raise ValueError("year_range must not be defined when type='zarr'")
+            if values.partitioning:
+                raise ValueError("partitioning is not applicable when type='zarr'")
+
+        elif dataset_type == "files":
+            if values.partitioning:
+                raise ValueError("partitioning is not applicable when type='files'")
+
+        else:
+            raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+        return values
+

 class WorkerOptions(BaseModel):
     """Worker configuration for Coiled clusters.
@@ -1139,88 +1199,99 @@ def load_config_and_validate(config_filename: str) -> DatasetConfig:
     return DatasetConfig.model_validate(dataset_config)


-def json_update(base: dict, updates: dict) -> dict:
-    """Recursively update nested dictionaries."""
-    for k, v in updates.items():
-        if isinstance(v, dict) and isinstance(base.get(k), dict):
-            base[k] = json_update(base[k], v)
-        else:
-            base[k] = v
-    return base
-
-
 def collect_files(
     path_cfg: PathConfig,
-    suffix: str,
+    suffix: Optional[str],
     exclude: Optional[str],
     bucket_raw: Optional[str],
     s3_client_opts: Optional[dict] = None,
 ) -> List[str]:
-    """Collect files from an S3 bucket using suffix and optional regex filtering.
+    """Collect dataset paths from S3 based on dataset type.
+
+    Supports:
+    - 'files': lists and filters regular files (e.g., NetCDF, CSV)
+    - 'parquet': handles both single Parquet files and Hive-partitioned datasets
+    - 'zarr': returns the Zarr store path directly

     Args:
-        path_cfg: Configuration object including the S3 URI and optional regex filters.
+        path_cfg: Configuration object including type, S3 URI, and optional regex filters.
         suffix: File suffix to filter by, e.g., '.nc'. Set to None to disable suffix filtering.
         exclude: Optional regex string to exclude files.
         bucket_raw: Required if `path_cfg.s3_uri` is not a full S3 URI.
+        s3_client_opts: Optional dict with boto3 S3 client options.

     Returns:
-        List of matching file keys (paths) as strings.
+        List of dataset paths (files or root URIs) as strings.
     """
-    s3_uri = path_cfg.s3_uri
-
-    if s3_uri.startswith("s3://"):
-        parsed = urlparse(s3_uri)
-        bucket = parsed.netloc
-        prefix = parsed.path.lstrip("/")
-    else:
-        if not bucket_raw:
-            raise ValueError(
-                "bucket_raw must be provided when s3_uri is not a full S3 URI."
-            )
-        bucket = bucket_raw
-        prefix = s3_uri
-
-    prefix = str(PurePosixPath(prefix))  # normalise path
+    dataset_type = getattr(path_cfg, "type", "files")  # default value
+    s3_uri = path_cfg.s3_uri.rstrip("/")
+
+    # ---------------------------------------------------------------------
+    # Handle 'file' collection (NetCDF, CSV)
+    # ---------------------------------------------------------------------
+    if dataset_type == "files":
+        if s3_uri.startswith("s3://"):
+            parsed = urlparse(s3_uri)
+            bucket = parsed.netloc
+            prefix = parsed.path.lstrip("/")
+        else:
+            if not bucket_raw:
+                raise ValueError(
+                    "bucket_raw must be provided when s3_uri is not a full S3 URI."
+                )
+            bucket = bucket_raw
+            prefix = s3_uri

-    # handle case when collecting files for a parquet hive partition
-    for f in path_cfg.filter:
-        s3_uri = path_cfg.s3_uri.rstrip("/") + "/"
-        pattern_simplified = f.rstrip("$")
+        prefix = str(PurePosixPath(prefix))  # normalise path

-        # Only handle .parquet ending
-        if pattern_simplified.endswith(".parquet") or pattern_simplified.endswith(
-            ".parquet/"
-        ):
-            filename_candidate = pattern_simplified.split("/")[-1]
+        matching_files = s3_ls(
+            bucket,
+            prefix,
+            suffix=suffix,
+            exclude=exclude,
+            s3_client_opts=s3_client_opts,
+        )

-            # Check for unsupported regex characters
-            if re.search(r"[\*\[\]\(\)\+\?]", filename_candidate):
+        for pattern in path_cfg.filter or []:
+            logger.info(f"Filtering files with regex pattern: {pattern}")
+            regex = re.compile(pattern)
+            matching_files = [f for f in matching_files if regex.search(f)]
+            if not matching_files:
                 raise ValueError(
-                    f"In the case of a parquet dataset input, the filter value should match a dataset name without complex regex patterns. Filter '{f}' is too complex to convert to a filename. Please modify config"
+                    f"No files matching {pattern} under {s3_uri}. Modify regexp filter or path in configuration file. Abort"
                 )

-            # Remove escaped characters like \.
-            filename = re.sub(r"\\(.)", r"\1", filename_candidate)
-            return [s3_uri + filename]
+        logger.info(f"Matched {len(matching_files)} files")

-    # matching_files = s3_ls(bucket, prefix, suffix=suffix, exclude=exclude)
-    matching_files = s3_ls(
-        bucket, prefix, suffix=None, exclude=exclude, s3_client_opts=s3_client_opts
-    )
+        return matching_files

-    for pattern in path_cfg.filter or []:
-        logger.info(f"Filtering files with regex pattern: {pattern}")
-        regex = re.compile(pattern)
-        matching_files = [f for f in matching_files if regex.search(f)]
-        if matching_files == []:
-            raise ValueError(
-                f"No files matching {pattern} under {s3_uri}. Modify regexp filter or path in configuration file. Abort"
-            )
+    # ---------------------------------------------------------------------
+    # Handle 'parquet' (single Parquet file or Hive-partitioned dataset)
+    # ---------------------------------------------------------------------
+    elif dataset_type == "parquet":
+        # No filters
+        return [s3_uri]

-    logger.info(f"Matched {len(matching_files)} files")
+    # ---------------------------------------------------------------------
+    # Handle 'zarr' (Zarr store)
+    # ---------------------------------------------------------------------
+    elif dataset_type == "zarr":
+        raise ValueError("zarr store as an input dataset is not yet implemented")
+        # return [s3_uri]

-    return matching_files
+    # Unsupported type
+    else:
+        raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def json_update(base: dict, updates: dict) -> dict:
+    """Recursively update nested dictionaries."""
+    for k, v in updates.items():
+        if isinstance(v, dict) and isinstance(base.get(k), dict):
+            base[k] = json_update(base[k], v)
+        else:
+            base[k] = v
+    return base


 def join_s3_uri(base_uri: str, *parts: str) -> str:
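
Hypothetical call sites for the reworked collect_files(), using only the signature shown above; bucket names, suffixes, and filters are placeholders:

# 'files': list objects under the prefix, then apply the regex filters.
nc_cfg = PathConfig(
    type="files",
    s3_uri="s3://example-bucket/raw/netcdf",
    filter=[".*\\.nc$"],
)
nc_files = collect_files(nc_cfg, suffix=".nc", exclude=None, bucket_raw=None)

# 'parquet': no listing or filtering; the dataset root URI is returned as-is.
pq_cfg = PathConfig(
    type="parquet",
    partitioning="hive",
    s3_uri="s3://example-bucket/products/dataset.parquet",
)
pq_paths = collect_files(pq_cfg, suffix=None, exclude=None, bucket_raw=None)
assert pq_paths == ["s3://example-bucket/products/dataset.parquet"]
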

aodn_cloud_optimised/config/dataset/diver_photoquadrat_score_qc.json

Lines changed: 2 additions & 5 deletions
@@ -5,11 +5,8 @@
   "run_settings": {
     "paths": [
       {
-        "s3_uri": "s3://data-uplift-public/products/reef_life_survey",
-        "filter": [
-          "public_reef_life_survey_2025-11-04T03:14:37\\.parquet$"
-        ],
-        "year_range": []
+        "type": "parquet",
+        "s3_uri": "s3://data-uplift-public/products/reef_life_survey/public_reef_life_survey_2025-11-04T03:14:37.parquet"
       }
     ],
     "cluster": {

aodn_cloud_optimised/lib/GenericParquetHandler.py

Lines changed: 33 additions & 13 deletions
@@ -293,19 +293,39 @@ def preprocess_data_parquet(
         added by the cloud optimisation process.
         """

-        # Try reading as a single Parquet file
-        try:
-            table = pq.read_table(parquet_fp)
-        except (pa.ArrowInvalid, OSError):
-            # Treat as Hive-partitioned dataset
-            # parquet_fp is a file-like object: extract the key prefix
-            key_prefix = parquet_fp.path  # S3File objects have `.path` attribute
-            table = pds.dataset(
-                key_prefix,
-                format="parquet",
-                partitioning="hive",
-                filesystem=self.s3_fs_output,
-            ).to_table()
+        key_path = getattr(parquet_fp, "path", None)
+        full_path = key_path if key_path.startswith("s3://") else f"s3://{key_path}"
+
+        # matching the parquet file with the correct config in the paths array
+        matched_cfg = None
+        for path_cfg in self.dataset_config["run_settings"]["paths"]:
+            s3_uri = path_cfg.get("s3_uri", "").rstrip("/")
+            if full_path.startswith(s3_uri):
+                matched_cfg = path_cfg
+                break
+
+        if matched_cfg is None:
+            raise ValueError(f"No matching path configuration found for {full_path}")
+
+        partitioning = matched_cfg.get("partitioning", None)
+
+        match partitioning:
+            case None:
+                # reading as a single Parquet file
+                table = pq.read_table(parquet_fp)
+
+            case "hive":
+                key_prefix = parquet_fp.path  # S3File objects have `.path` attribute
+                table = pds.dataset(
+                    key_prefix,
+                    format="parquet",
+                    partitioning=partitioning,
+                    filesystem=self.s3_fs_output,
+                ).to_table()
+            case _:
+                raise ValueError(
+                    f"Partitioning value {partitioning} is not yet supported"
+                )

         df = table.to_pandas()
         df = df.drop(columns=self.drop_variables, errors="ignore")