diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index 83bb0c8fa8..e172f71557 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -16,7 +16,7 @@ import typing -from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer +from .types import DataFormat, FlyteDirectory, FlyteDirToMultipartBlobTransformer, StreamingKwargs # The following section provides some predefined aliases for commonly used FlyteDirectory formats. diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 699278b0b6..e55c3cdf80 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -6,25 +6,37 @@ import random import typing from dataclasses import dataclass, field +from enum import Enum from functools import partial from pathlib import Path -from typing import Any, Dict, Generator, Tuple +from typing import Annotated, Any, Dict, Generator, Tuple from uuid import UUID import fsspec +import jsonlines import msgpack from dataclasses_json import DataClassJsonMixin, config from fsspec.utils import get_protocol from google.protobuf import json_format as _json_format from google.protobuf.struct_pb2 import Struct -from marshmallow import fields +from marshmallow import fields, validate from mashumaro.types import SerializableType +from typing_extensions import get_args, get_origin from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size +from flytekit.core.type_engine import ( + AsyncTypeTransformer, + TypeEngine, + TypeTransformerFailedError, + get_batch_size, +) from flytekit.exceptions.user import FlyteAssertion -from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator +from flytekit.extras.pydantic_transformer.decorator import ( + model_serializer, + model_validator, +) +from flytekit.loggers import logger from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.core.types import BlobType @@ -32,6 +44,18 @@ from flytekit.models.types import LiteralType from flytekit.types.file import FileExt, FlyteFile +try: + import streaming # noqa: F401 + + _has_streaming = True +except ImportError: + _has_streaming = False + +if _has_streaming: + from streaming.base import MDSWriter, StreamingDataset +else: + logger.info("Streaming is unavailable.") + T = typing.TypeVar("T") PathType = typing.Union[str, os.PathLike] @@ -39,6 +63,27 @@ def noop(): ... 
+class DataFormat(Enum):
+    JSONL = "jsonl"
+    PARQUET = "parquet"
+    ARROW = "arrow"
+
+
+@dataclass
+class StreamingKwargs(DataClassJsonMixin):
+    shards_config: typing.Optional[typing.Dict[str, Any]] = field(
+        default=None, metadata=config(mm_field=fields.Dict())
+    )
+    stream_config: typing.Optional[typing.Dict[str, Any]] = field(
+        default=None, metadata=config(mm_field=fields.Dict())
+    )
+    data_format: DataFormat = field(
+        default=DataFormat.JSONL,
+        metadata=config(mm_field=fields.String(validate=validate.OneOf([fmt.value for fmt in DataFormat]))),
+    )
+
+    def __post_init__(self):
+        columns = self.shards_config.get("columns") if self.shards_config else None
+        if columns:
+            self.shards_config["columns"] = {k: v.__name__ if isinstance(v, type) else v for k, v in columns.items()}
+
+
 @dataclass
 class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.Generic[T]):
     path: PathType = field(default=None, metadata=config(mm_field=fields.String()))  # type: ignore
@@ -47,7 +92,8 @@ class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.G
     This class should not be used on very large datasets, as merely listing the dataset will cause the entire dataset
     to be downloaded. Listing on S3 and other backend object stores is not consistent
-    and we should not need data to be downloaded to list.
+    and we should not need data to be downloaded to list. To work with large datasets efficiently, consider
+    annotating a FlyteDirectory with ``StreamingKwargs`` so the contents are **streamed** rather than downloaded.

     Please first read through the comments on the :py:class:`flytekit.types.file.FlyteFile` class as the
     implementation here is similar.
@@ -126,6 +172,21 @@ def t1(in1: FlyteDirectory["svg"]):
     The format [] bit is still there because in Flyte, directories are stored as Blob Types also, just like files, and
     the Blob type has the format field. The difference in the type field is represented in the ``dimensionality``
     field in the ``BlobType``.
+
+    To stream a FlyteDirectory, annotate it with ``StreamingKwargs``:
+
+    .. code-block:: python
+
+        from typing import Annotated
+
+        from flytekit.types.directory import FlyteDirectory, StreamingKwargs
+
+        def t2(dataset: Annotated[FlyteDirectory, StreamingKwargs(shards_config={}, stream_config={})]):
+            # `dataset` is handed to the task as a StreamingDataset, a subclass of PyTorch's
+            # IterableDataset, so samples are loaded lazily instead of downloading the whole directory.
+            for i in range(dataset.num_samples):
+                print(dataset[i])
+
+    This leverages MosaicML's streaming library under the hood: the directory is exposed as a
+    StreamingDataset, which extends PyTorch's IterableDataset and enables efficient, on-the-fly data loading.
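+
+    The entries in ``shards_config`` are forwarded to MosaicML's ``MDSWriter`` and the entries in
+    ``stream_config`` to ``StreamingDataset``. As a rough illustration only (the column names, encodings,
+    and settings below are examples, not required values):
+
+    .. code-block:: python
+
+        from flytekit.types.directory import DataFormat, StreamingKwargs
+
+        kwargs = StreamingKwargs(
+            shards_config={"columns": {"text": str, "label": int}, "compression": "zstd"},
+            stream_config={"batch_size": 8, "shuffle": True},
+            data_format=DataFormat.JSONL,
+        )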
""" def _serialize(self) -> typing.Dict[str, str]: @@ -632,20 +693,157 @@ def wf(dc: DC): python_val = json.loads(json_str) return self.dict_to_flyte_directory(python_val, expected_python_type) + def _is_valid_jsonl_file(self, file_path: str) -> bool: + try: + with jsonlines.open(file_path) as reader: + for _ in reader: + pass + return True + except (jsonlines.InvalidLineError, UnicodeDecodeError): + return False + + def _is_valid_parquet_file(self, file_path: str) -> bool: + import pyarrow.parquet as pq + + try: + pq.ParquetFile(file_path) + return True + except Exception: + return False + + def _is_valid_arrow_file(self, file_path: str) -> bool: + import pyarrow as pa + + try: + with pa.memory_map(file_path, "r") as mmap: + pa.RecordBatchStreamReader(mmap) + return True + except Exception: + return False + + def _write_jsonl(self, out, src: str): + if self._is_valid_jsonl_file(src): + with jsonlines.open(src) as reader: + for obj in reader: + out.write(obj) + + def _process_batches(self, out, reader): + import numpy as np + + for batch in reader: + records = ( + { + name: np.array(val.as_py()) + if hasattr(val, "as_py") and isinstance(val.as_py(), list) + else val.as_py() + if hasattr(val, "as_py") + else np.array(val) + if isinstance(val, list) + else val + for name, val in zip(batch.schema.names, row) + } + for row in zip(*batch.columns) + ) + for record in records: + out.write(record) + + def _write_parquet_or_arrow(self, out, src: str, is_parquet: bool = True): + import pyarrow as pa + import pyarrow.parquet as pq + + if is_parquet: + if self._is_valid_parquet_file(src): + reader = pq.ParquetFile(src).iter_batches(batch_size=5) + self._process_batches(out, reader) + else: + if self._is_valid_arrow_file(src): + with pa.memory_map(src, "r") as mmap, pa.RecordBatchStreamReader(mmap) as reader: + self._process_batches(out, reader) + return + return + + def _create_shards( + self, ctx: FlyteContext, uri: str, fd: FlyteDirectory = None, aa: StreamingKwargs = None + ) -> typing.Union[bool, str]: + if not aa.shards_config: + return None + + aa.shards_config.setdefault("out", ctx.file_access.get_random_local_directory()) + + if aa.data_format == DataFormat.JSONL: + writer_func = self._write_jsonl + elif aa.data_format == DataFormat.PARQUET: + writer_func = self._write_parquet_or_arrow + elif aa.data_format == DataFormat.ARROW: + writer_func = partial(self._write_parquet_or_arrow, is_parquet=False) + else: + raise ValueError(f"Unsupported data format: {aa.data_format}") + + with MDSWriter(**aa.shards_config) as out: + if fd: + sources = (FlyteFile.from_source(os.path.join(base, x)).download() for base, x in fd.crawl()) + else: + sources = Path(uri).rglob(f"*.{aa.data_format.name.lower()}") + + try: + first_source = next(sources) # Check if at least one file exists + writer_func(out, str(first_source)) + for src in sources: # Continue processing the rest + writer_func(out, src) + except StopIteration: + raise ValueError(f"No {aa.data_format.name.lower()} files found in {uri}") + + return aa.shards_config["out"] + + def _process_directory_with_streaming( + self, + ctx: FlyteContext, + uri: str, + fd: FlyteDirectory = None, + base_type: typing.Type = None, + annotate_args: typing.List = None, + ) -> typing.Union[StreamingDataset, FlyteDirectory]: + for aa in annotate_args: + if isinstance(aa, StreamingKwargs): + # Process shards if configured + output_path = self._create_shards(ctx, uri, fd, aa) if aa.shards_config else None + + return StreamingDataset( + local=output_path or uri if output_path 
or not fd else None, + remote=uri if not output_path and fd else None, + **(aa.stream_config or {}), # Make config optional + ) + + # Return appropriate object if we didn't create a StreamingDataset + return fd or base_type(uri, remote_directory=False) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory] - ) -> FlyteDirectory: + ) -> typing.Any: + base_type = None + has_streaming_kwargs = False + annotate_args = [] + + # Extract base type and annotations for StreamingKwargs + if get_origin(expected_python_type) is Annotated: + base_type, *annotate_args = get_args(expected_python_type) + has_streaming_kwargs = any(isinstance(arg, StreamingKwargs) for arg in annotate_args) + if has_streaming_kwargs and not _has_streaming: + raise TypeTransformerFailedError( + "In order to use StreamingKwargs, you need to install mosaicml-streaming first." + ) + # Handle dataclass attribute access if lv.scalar: if lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + return self.from_binary_idl(lv.scalar.binary, base_type or expected_python_type) if lv.scalar.generic: - return self.from_generic_idl(lv.scalar.generic, expected_python_type) + return self.from_generic_idl(lv.scalar.generic, base_type or expected_python_type) try: uri = lv.scalar.blob.uri except AttributeError: - raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {base_type or expected_python_type}") if lv.scalar.blob.metadata.type.dimensionality != BlobType.BlobDimensionality.MULTIPART: raise TypeTransformerFailedError(f"{lv.scalar.blob.uri} is not a directory.") @@ -653,22 +851,37 @@ async def async_to_python_value( if not ctx.file_access.is_remote(uri) and not os.path.isdir(uri): raise FlyteAssertion(f"Expected a directory, but the given uri '{uri}' is not a directory.") - # This is a local file path, like /usr/local/my_dir, don't mess with it. Certainly, downloading it doesn't - # make any sense. + # Local file path handling if not ctx.file_access.is_remote(uri): - return expected_python_type(uri, remote_directory=False) + if base_type and has_streaming_kwargs: + return self._process_directory_with_streaming( + ctx, + uri, + fd=None, + base_type=base_type, + annotate_args=annotate_args, + ) + else: + return (base_type or expected_python_type)(uri, remote_directory=False) - # For the remote case, return a FlyteDirectory object that can download + # Remote file path handling local_folder = ctx.file_access.get_random_local_directory() - - batch_size = get_batch_size(expected_python_type) - + batch_size = get_batch_size(base_type or expected_python_type) _downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size) - - expected_format = self.get_format(expected_python_type) + expected_format = self.get_format(base_type or expected_python_type) fd = FlyteDirectory.__class_getitem__(expected_format)(local_folder, _downloader) fd._remote_source = uri + + if base_type and has_streaming_kwargs: + return self._process_directory_with_streaming( + ctx, + uri, + fd=fd, + base_type=base_type, + annotate_args=annotate_args, + ) + return fd def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirectory[typing.Any]]:
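A minimal end-to-end sketch of how the annotation introduced here might be used from a task (assuming
mosaicml-streaming is installed; the task name, column encodings, and config values are illustrative,
not prescribed by this change):

    from typing import Annotated

    from flytekit import task
    from flytekit.types.directory import DataFormat, FlyteDirectory, StreamingKwargs

    @task
    def count_samples(
        dataset: Annotated[
            FlyteDirectory,
            StreamingKwargs(
                shards_config={"columns": {"text": str, "label": int}},
                stream_config={"batch_size": 1},
                data_format=DataFormat.JSONL,
            ),
        ],
    ) -> int:
        # The JSONL files in the directory are converted into MDS shards and handed to the task
        # as a StreamingDataset, so samples are read lazily rather than loaded all at once.
        return sum(1 for _ in dataset)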