
Commit fd2e18d

Merge pull request #222 from aodn/165-rnd-create-a-co-like-partitioned-parquet-file-v2
165 rnd create a co like partitioned parquet file v2
2 parents 7b54630 + 933b960 commit fd2e18d

7 files changed: +586 −69 lines changed


aodn_cloud_optimised/bin/create_dataset_config.py

Lines changed: 62 additions & 48 deletions
@@ -37,12 +37,15 @@
 import importlib.util
 import json
 import os
+import pathlib
 import uuid
 from collections import OrderedDict
 from importlib.resources import files
 
 import nbformat
 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 import s3fs
 from s3path import PureS3Path
 from termcolor import colored
@@ -734,21 +737,19 @@ def main():
 
     # Handle S3 path
     if args.file.startswith("s3://"):
-        nc_file = args.file
-        p = PureS3Path.from_uri(nc_file)
-        bucket = p.bucket
-        obj_key = str(p.key)
+        fp = args.file
+        s3_path = PureS3Path.from_uri(fp)
+        bucket = s3_path.bucket
+        obj_key = str(s3_path.key)
     else:
         obj_key = args.file
         bucket = args.bucket
-        nc_file = (
-            PureS3Path.from_uri(f"s3://{args.bucket}").joinpath(args.file).as_uri()
-        )
+        fp = PureS3Path.from_uri(f"s3://{args.bucket}").joinpath(args.file).as_uri()
 
     # Create an empty NetCDF with NaN variables alongside the JSON files. Acts as the source of truth for restoring missing dimensions.
     # only useful for Zarr to concatenate NetCDF together with missing var/dim in some NetCDF files
     if args.cloud_format == "zarr":
-        nc_nullify_path = nullify_netcdf_variables(nc_file, args.dataset_name)
+        nc_nullify_path = nullify_netcdf_variables(fp, args.dataset_name)
 
     # optionals s3fs options
     if args.s3fs_opts:
@@ -758,43 +759,58 @@ def main():
             anon=False,
         )
 
-    # Generate schema based on input type (NetCDF or CSV)
-    if obj_key.lower().endswith(".csv"):
-        csv_file = nc_file  # TODO: rename
-
-        csv_opts = json.loads(args.csv_opts) if args.csv_opts else {}
-        with fs.open(csv_file, "rb") as f:
-            df = pd.read_csv(f, **csv_opts)
-
-        dataset_config_schema = {"type": "object", "properties": {}}
-        for col, dtype in df.dtypes.items():
-            if pd.api.types.is_integer_dtype(dtype):
-                js_type = "integer"
-            elif pd.api.types.is_float_dtype(dtype):
-                js_type = "number"
-            elif pd.api.types.is_bool_dtype(dtype):
-                js_type = "boolean"
-            elif pd.api.types.is_object_dtype(dtype) | pd.api.types.is_string_dtype(
-                dtype
-            ):
-                js_type = "string"
-            else:
-                raise NotImplementedError(
-                    f"found dtype that did not fit into configured categories: `{dtype}`"
-                )
+    # Route by file type
+    obj_key_suffix = pathlib.Path(obj_key.lower()).suffix
+    match obj_key_suffix:
+        case ".nc":
 
-            dataset_config_schema["properties"][col] = {"type": js_type}
-
-    elif obj_key.lower().endswith(".nc"):
-        # Generate JSON schema from the NetCDF file
-        temp_file_path = generate_json_schema_from_s3_netcdf(
-            nc_file, cloud_format=args.cloud_format, s3_fs=fs
-        )
-        with open(temp_file_path, "r") as file:
-            dataset_config_schema = json.load(file)
-        os.remove(temp_file_path)
-    else:
-        raise NotImplementedError(f"input file type `{obj_key}` not implemented")
+            # Generate JSON schema from the NetCDF file
+            temp_file_path = generate_json_schema_from_s3_netcdf(
+                fp, cloud_format=args.cloud_format, s3_fs=fs
+            )
+            with open(temp_file_path, "r") as file:
+                dataset_config_schema = json.load(file)
+            os.remove(temp_file_path)
+
+        case ".csv":
+
+            csv_opts = json.loads(args.csv_opts) if args.csv_opts else {}
+            with fs.open(fp, "rb") as f:
+                df = pd.read_csv(f, **csv_opts)
+
+            dataset_config_schema = {"type": "object", "properties": {}}
+            for col, dtype in df.dtypes.items():
+                if pd.api.types.is_integer_dtype(dtype):
+                    js_type = "integer"
+                elif pd.api.types.is_float_dtype(dtype):
+                    js_type = "number"
+                elif pd.api.types.is_bool_dtype(dtype):
+                    js_type = "boolean"
+                elif pd.api.types.is_object_dtype(dtype) | pd.api.types.is_string_dtype(
+                    dtype
+                ):
+                    js_type = "string"
+                else:
+                    raise NotImplementedError(
+                        f"found dtype that did not fit into configured categories: `{dtype}`"
+                    )
+
+                dataset_config_schema["properties"][col] = {"type": js_type}
+
+        case ".parquet":
+
+            with fs.open(fp, "rb") as f:
+                schema = pq.read_schema(f)
+            dataset_config_schema = dict()
+
+            for field in schema:
+                dataset_config_schema[field.name] = {"type": str(field.type)}
+
+        # Default: Raise NotImplemented
+        case _:
+            raise NotImplementedError(
+                f"input file type `{obj_key_suffix}` not implemented"
+            )
 
     dataset_config = {"schema": dataset_config_schema}
     # Define the path to the validation schema file
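A minimal standalone sketch of what the new `.parquet` branch produces: `pyarrow.parquet.read_schema` is read once and each field is mapped to a `{"type": ...}` entry. The file name and columns below are hypothetical, and the file is written and read locally rather than through `s3fs`:

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Hypothetical table; the real script reads the object from S3 via s3fs.
table = pa.table({
    "TIME": pa.array([1, 2], type=pa.int64()),
    "TEMP": pa.array([12.3, 12.5], type=pa.float64()),
})
pq.write_table(table, "example.parquet")

# Mirrors the `case ".parquet":` branch above
schema = pq.read_schema("example.parquet")
dataset_config_schema = {field.name: {"type": str(field.type)} for field in schema}
print(dataset_config_schema)
# {'TIME': {'type': 'int64'}, 'TEMP': {'type': 'double'}}
```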
@@ -835,7 +851,7 @@ def main():
         "mode": f"{TO_REPLACE_PLACEHOLDER}",
         "restart_every_path": False,
     }
-    parent_s3_path = PureS3Path.from_uri(nc_file).parent.as_uri()
+    parent_s3_path = PureS3Path.from_uri(fp).parent.as_uri()
     dataset_config["run_settings"]["paths"] = [
         {"s3_uri": parent_s3_path, "filter": [".*\\.nc"], "year_range": []}
     ]
@@ -941,9 +957,7 @@ def main():
     with open(f"{module_path}/config/dataset/{args.dataset_name}.json", "w") as f:
         json.dump(dataset_config, f, indent=2)
 
-    create_dataset_script(
-        args.dataset_name, f"{args.dataset_name}.json", nc_file, bucket
-    )
+    create_dataset_script(args.dataset_name, f"{args.dataset_name}.json", fp, bucket)
     update_pyproject_toml(args.dataset_name)
 
     # fill up aws registry with GN3 uuid

aodn_cloud_optimised/lib/GenericParquetHandler.py

Lines changed: 144 additions & 12 deletions
@@ -1,6 +1,7 @@
 import gc
 import importlib.resources
 import os
+import pathlib
 import re
 import timeit
 import traceback
@@ -15,6 +16,7 @@
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
+import s3fs.core
 import xarray as xr
 from dask.distributed import wait
 from shapely.geometry import Point, Polygon
@@ -226,8 +228,60 @@ def preprocess_data_netcdf(
                 f"{self.uuid_log}: The NetCDF file does not conform to the pre-defined schema."
             )
 
+    def preprocess_data_parquet(
+        self, parquet_fp
+    ) -> Generator[Tuple[pd.DataFrame, xr.Dataset], None, None]:
+        """
+        Preprocesses a parquet file using pyarrow and converts it into an xarray Dataset based on the dataset configuration.
+
+        Args:
+            parquet_fp (str or s3fs.core.S3File): File path or s3fs object of the parquet file to be processed.
+
+        Yields:
+            Tuple[pd.DataFrame, xr.Dataset]: A generator yielding a tuple containing the processed pandas DataFrame
+            and its corresponding xarray Dataset.
+
+        This method reads a parquet file (`parquet_fp`) using the pyarrow.parquet `read_table` function.
+
+        The resulting DataFrame (`df`) is then converted into an xarray Dataset using `xr.Dataset.from_dataframe()`.
+
+        # TODO: Document `pq.read_table` options
+
+        The method also uses the 'schema' from the dataset configuration to assign attributes to variables in the
+        xarray Dataset. Each variable's attributes are extracted from the 'schema' and assigned to the Dataset variable's
+        attributes. The 'type' attribute from the `pyarrow_schema` is removed from the Dataset variables' attributes since it
+        is considered unnecessary.
+
+        If a variable in the Dataset is not found in the schema, an error is logged.
+
+        Notes:
+            Ensure that the config schema includes a column named "index" of type int64. When the internal conversions
+            occur between xarray, pandas and pyarrow, an "index" column is added to the pyarrow table. Rather than
+            detect when "index" should not have been added, it is easier to add "index" as an expected column that is
+            added by the cloud optimisation process.
+        """
+
+        table = pq.read_table(parquet_fp)
+        df = table.to_pandas()
+        df = df.drop(columns=self.drop_variables, errors="ignore")
+        ds = xr.Dataset.from_dataframe(df)
+
+        for var in ds.variables:
+            if var not in self.schema:
+                self.logger.error(
+                    f"{self.uuid_log}: Missing variable: {var} from dataset config"
+                )
+            else:
+                ds[var].attrs = self.schema.get(var)
+                del ds[var].attrs[
+                    "type"
+                ]  # remove the type attribute which is not necessary at all
+
+        yield df, ds
+
     def preprocess_data(
-        self, fp
+        self,
+        fp: str | s3fs.core.S3File,
     ) -> Generator[Tuple[pd.DataFrame, xr.Dataset], None, None]:
         """
         Overwrites the preprocess_data method from CommonHandler.
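A self-contained sketch of the round trip `preprocess_data_parquet` relies on (Parquet → pandas → xarray), and of one place the extra "index" column mentioned in the Notes comes from: `xr.Dataset.from_dataframe()` promotes an unnamed DataFrame index to a dimension and variable called `index`. Column names and the local file are hypothetical; schema/attribute handling is omitted:

```python
import pyarrow as pa
import pyarrow.parquet as pq
import xarray as xr

pq.write_table(pa.table({"TIME": [0, 1], "TEMP": [12.3, 12.5]}), "example.parquet")

# Same chain as preprocess_data_parquet, without the schema/attrs handling
df = pq.read_table("example.parquet").to_pandas()
ds = xr.Dataset.from_dataframe(df)

print(ds.sizes)                  # the unnamed RangeIndex becomes an "index" dimension
print("index" in ds.variables)   # True -> hence "index" is expected in the config schema
```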
@@ -239,12 +293,31 @@ def preprocess_data(
             tuple: A tuple containing DataFrame and Dataset.
 
         If `fp` ends with ".nc", it delegates to `self.preprocess_data_netcdf(fp)`.
-        If `fp` ends with ".csv", it delegates to `self.preprocess_data_csv(fp)`.
+        Elif `fp` ends with ".csv", it delegates to `self.preprocess_data_csv(fp)`.
+        Elif `fp` ends with ".parquet", it delegates to `self.preprocess_data_parquet(fp)`.
+        Else raises a NotImplementedError
+
+        Raises:
+            NotImplementedError: Where the file type is not yet implemented
         """
-        if fp.path.endswith(".nc"):
-            return self.preprocess_data_netcdf(fp)
-        if fp.path.endswith(".csv"):
-            return self.preprocess_data_csv(fp)
+        # Extract file suffix
+        if isinstance(fp, str):
+            file_suffix = pathlib.Path(fp).suffix
+        elif isinstance(fp, s3fs.core.S3File):
+            file_suffix = pathlib.Path(fp.path).suffix
+
+        # Match preprocess method
+        match file_suffix.lower():
+            case ".nc":
+                return self.preprocess_data_netcdf(fp)
+            case ".csv":
+                return self.preprocess_data_csv(fp)
+            case ".parquet":
+                return self.preprocess_data_parquet(fp)
+            case _:
+                raise NotImplementedError(
+                    f"files with suffix `{file_suffix}` not yet implemented in preprocess_data"
+                )
 
     @staticmethod
     def cast_table_by_schema(table, schema) -> pa.Table:
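The dispatch above keys on the file suffix rather than on `fp.path`, so plain string paths now work as well as `s3fs` file handles. A tiny sketch of the suffix normalisation, with hypothetical paths:

```python
import pathlib

for path in ["s3://bucket/data/file.nc", "data/profile.CSV", "s3://bucket/part-0.parquet"]:
    # Same normalisation as preprocess_data: take the suffix, then lower-case it
    print(path, "->", pathlib.Path(path).suffix.lower())
# s3://bucket/data/file.nc -> .nc
# data/profile.CSV -> .csv
# s3://bucket/part-0.parquet -> .parquet
```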
@@ -396,9 +469,9 @@ def _add_polygon(self, df: pd.DataFrame) -> pd.DataFrame:
             self.logger.warning(
                 f"{self.uuid_log}: The NetCDF contains NaN values of {geo_var}. Removing corresponding data"
             )
-            df = df.dropna(
-                subset=[geo_var]
-            ).reset_index()  # .reset_index(drop=True)
+            df = df.dropna(subset=[geo_var]).reset_index(
+                drop=False
+            )  # For now leaving drop false to ensure no breaking changes
 
         point_geometry = [
             Point(lon, lat) for lon, lat in zip(df[lon_varname], df[lat_varname])
@@ -451,9 +524,11 @@ def _add_timestamp_df(self, df: pd.DataFrame, f) -> pd.DataFrame:
             if item.get("time_extent") is not None:
                 timestamp_info = item
 
+        # Extract time partition information
         timestamp_varname = timestamp_info.get("source_variable")
         time_varname = timestamp_info["time_extent"].get("time_varname", "TIME")
         partition_period = timestamp_info["time_extent"].get("partition_period")
+
         # look for the variable or column with datetime64 type
         if isinstance(df.index, pd.MultiIndex) and (time_varname in df.index.names):
             # for example, files with timeSeries and TIME dimensions such as
@@ -476,6 +551,37 @@ def _add_timestamp_df(self, df: pd.DataFrame, f) -> pd.DataFrame:
             if pd.api.types.is_datetime64_any_dtype(df.index):
                 datetime_var = df.index
 
+        # Finally attempt to validate the defined time partition column
+        if "datetime_var" not in locals():
+
+            # Else look for the time columns with a different time related dtype
+            time_partition_column = df[time_varname]
+
+            # Validate no missing values
+            if time_partition_column.isnull().any():
+                raise ValueError(
+                    "time partition column may not contain null values"
+                )
+
+            # Validate that the time partition column translated via pd.to_datetime
+            try:
+                pd.to_datetime(time_partition_column)
+            except Exception as e:
+                raise ValueError(
+                    f"time partition column failed to translate to pandas datetime dtype: {e}"
+                )
+
+            # Because the df does not have a date time index, we have to create and fill the column in separately here
+            datetime_index = pd.DatetimeIndex(pd.to_datetime(time_partition_column))
+            df[timestamp_varname] = (
+                np.int64(datetime_index.to_period(partition_period).to_timestamp())
+                / 10**9
+            )
+            return df
+
+        if "datetime_var" not in locals():
+            raise ValueError("could not determine the datetime column/variable")
+
         if not isinstance(df.index, pd.MultiIndex) and (time_varname in df.index.names):
             today = datetime.today()
             # assume that todays year + 1 is the future, and no in-situ data should be in the future, since we're not dealing
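A standalone sketch of the fallback added above: validate a non-datetime time column, then collapse each timestamp to the start of its partition period and express it as epoch seconds. The values and the `"M"` (monthly) period are hypothetical, and `.astype("int64")` is used here for the array-wide integer conversion; the committed code wraps the same expression with `np.int64`:

```python
import pandas as pd

time_partition_column = pd.Series(["2021-03-05T10:00:00", "2021-07-20T23:30:00"])

# Validation mirrors the new branch: no nulls, and convertible to datetime
assert not time_partition_column.isnull().any()
datetime_index = pd.DatetimeIndex(pd.to_datetime(time_partition_column))

# Collapse to the start of the partition period, then convert to epoch seconds
partition_period = "M"
epoch_seconds = (
    datetime_index.to_period(partition_period).to_timestamp().astype("int64") / 10**9
)
print(epoch_seconds.tolist())  # [1614556800.0, 1625097600.0] -> 2021-03-01, 2021-07-01
```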
@@ -661,13 +767,16 @@ def _rm_bad_timestamp_df(self, df: pd.DataFrame, f) -> pd.DataFrame:
         timestamp_varname = timestamp_info.get("source_variable")
         time_varname = timestamp_info["time_extent"].get("time_varname", "TIME")
 
-        if any(df[timestamp_varname] <= 0):
+        # Check any timestamps are before `1900-01-01 00:00:00`
+        if any(df[timestamp_varname] < -2208988800):
             self.logger.warning(
                 f"{self.uuid_log}: {f.path}: Bad values detected in {time_varname} time variable. Trimming corresponding data."
             )
-            df2 = df[df[timestamp_varname] > 0].copy()
+            df2 = df[df[timestamp_varname] >= -2208988800].copy()
             df = df2
-            df = df.reset_index()
+            df = df.reset_index(
+                drop=False
+            )  # For now leaving drop false to ensure no breaking changes
 
         if df.empty:
             self.logger.error(
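The new cutoff `-2208988800` is `1900-01-01 00:00:00 UTC` expressed as seconds since the Unix epoch, so only pre-1900 timestamps are trimmed now rather than anything negative (i.e. pre-1970). For reference:

```python
import pandas as pd

# 1900-01-01 as seconds since the Unix epoch (negative because it pre-dates 1970)
print(pd.Timestamp("1900-01-01", tz="UTC").timestamp())  # -2208988800.0
```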
@@ -788,6 +897,28 @@ def check_var_attributes(self, ds):
         else:
             return True
 
+    def validate_dataset_dimensions(self, ds: xr.Dataset) -> None:
+        """Validate that all dataset dimensions have corresponding variables as defined in the schema.
+        For each dimension present in the dataset (TIME, LATITUDE, LONGITUDE), this function checks whether the
+        dimension is declared in ``dataset_config["schema"]``. If it is, it ensures
+        that a variable of the same name exists in the dataset (for example, a dimension such as id won't be defined). If a required
+        variable is missing, a ``ValueError`` is raised.
+        Args:
+            ds: The xarray Dataset to validate.
+            dataset_config: Configuration dictionary containing a ``"schema"`` key
+                mapping variable names to their definitions.
+        Raises:
+            ValueError: If a dimension is defined in the schema but the corresponding
+                variable is missing in the dataset.
+        """
+        schema = self.dataset_config.get("schema", {})
+
+        for dim in ds.dims:
+            if dim in schema and dim not in ds.variables:
+                raise ValueError(
+                    f"{self.uuid_log}: Dimension '{dim}' is defined in schema but missing as a variable in dataset."
+                )
+
     def publish_cloud_optimised(
         self, df: pd.DataFrame, ds: xr.Dataset, s3_file_handle
     ) -> None:
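A minimal sketch of the condition `validate_dataset_dimensions` enforces, using a hypothetical two-entry schema: any dimension named in the schema must also exist as a variable in the dataset.

```python
import xarray as xr

schema = {"TIME": {"type": "timestamp[ns]"}, "TEMP": {"type": "double"}}  # hypothetical

ds_bad = xr.Dataset({"TEMP": ("TIME", [12.3, 12.5])})   # TIME dimension, no TIME variable
ds_ok = ds_bad.assign_coords(TIME=[0, 1])               # TIME variable now present

for ds in (ds_bad, ds_ok):
    missing = [d for d in ds.dims if d in schema and d not in ds.variables]
    print(missing)  # ['TIME'] for ds_bad, [] for ds_ok -> the handler raises on the first
```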
@@ -805,6 +936,7 @@ def publish_cloud_optimised(
             x["source_variable"]
             for x in self.dataset_config["schema_transformation"]["partitioning"]
         ]
+        self.validate_dataset_dimensions(ds)
         df = self._fix_datetimejulian(df)
         df = self._add_timestamp_df(df, s3_file_handle)
         df = self._add_columns_df(df, ds, s3_file_handle)
