Skip to content

mosaic data streaming integration #3186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/types/directory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import typing

from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from .types import DataFormat, FlyteDirectory, FlyteDirToMultipartBlobTransformer, StreamingKwargs

# The following section provides some predefined aliases for commonly used FlyteDirectory formats.

Expand Down
249 changes: 231 additions & 18 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,84 @@
import random
import typing
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Dict, Generator, Tuple
from typing import Annotated, Any, Dict, Generator, Tuple
from uuid import UUID

import fsspec
import jsonlines
import msgpack
from dataclasses_json import DataClassJsonMixin, config
from fsspec.utils import get_protocol
from google.protobuf import json_format as _json_format
from google.protobuf.struct_pb2 import Struct
from marshmallow import fields
from marshmallow import fields, validate
from mashumaro.types import SerializableType
from typing_extensions import get_args, get_origin

from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size
from flytekit.core.type_engine import (
AsyncTypeTransformer,
TypeEngine,
TypeTransformerFailedError,
get_batch_size,
)
from flytekit.exceptions.user import FlyteAssertion
from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator
from flytekit.extras.pydantic_transformer.decorator import (
model_serializer,
model_validator,
)
from flytekit.loggers import logger
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Binary, Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.file import FileExt, FlyteFile

try:
import streaming # noqa: F401

_has_streaming = True

Check warning on line 50 in flytekit/types/directory/types.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/directory/types.py#L50

Added line #L50 was not covered by tests
except ImportError:
_has_streaming = False

if _has_streaming:
from streaming.base import MDSWriter, StreamingDataset

Check warning on line 55 in flytekit/types/directory/types.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/directory/types.py#L55

Added line #L55 was not covered by tests
else:
logger.info("Streaming is unavailable.")

T = typing.TypeVar("T")
PathType = typing.Union[str, os.PathLike]


def noop():
    """Do nothing; placeholder callable (returns None)."""


class DataFormat(Enum):
    """Source-file formats supported when converting a directory into MDS shards."""

    JSONL = "jsonl"
    PARQUET = "parquet"
    ARROW = "arrow"


@dataclass
class StreamingKwargs(DataClassJsonMixin):
    """Annotation payload that turns a ``FlyteDirectory`` into a ``StreamingDataset``.

    Attach via ``Annotated[FlyteDirectory, StreamingKwargs(...)]``.

    Attributes:
        shards_config: kwargs forwarded to ``MDSWriter`` when converting the
            directory's files into MDS shards; ``None`` skips shard creation.
        stream_config: kwargs forwarded to ``StreamingDataset``; optional.
        data_format: format of the source files in the directory.
    """

    # Optional: both configs may legitimately be omitted (was annotated as a
    # plain Dict with a None default).
    shards_config: typing.Optional[typing.Dict[str, Any]] = field(
        default=None, metadata=config(mm_field=fields.Dict())
    )
    stream_config: typing.Optional[typing.Dict[str, Any]] = field(
        default=None, metadata=config(mm_field=fields.Dict())
    )
    # Annotated as DataFormat: the default was already an enum member, and the
    # shard-writing code compares against DataFormat members and reads `.name`.
    data_format: DataFormat = field(
        default=DataFormat.JSONL,
        metadata=config(mm_field=fields.String(validate=validate.OneOf([fmt.value for fmt in DataFormat]))),
    )

    def __post_init__(self):
        # Accept a raw string (e.g. from deserialization) and normalize it to
        # the enum so downstream comparisons and `.name` access work.
        if isinstance(self.data_format, str):
            self.data_format = DataFormat(self.data_format)

        # MDSWriter expects column dtypes as strings; allow users to pass
        # Python types (e.g. ``str``) and normalize them to their names.
        columns = self.shards_config.get("columns") if self.shards_config else None
        if columns:
            self.shards_config["columns"] = {k: v.__name__ if isinstance(v, type) else v for k, v in columns.items()}


@dataclass
class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.Generic[T]):
path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore
Expand All @@ -47,7 +92,8 @@

This class should not be used on very large datasets, as merely listing the dataset will cause
the entire dataset to be downloaded. Listing on S3 and other backend object stores is not consistent
and we should not need data to be downloaded to list.
and we should not need data to be downloaded to list. If you need to work with large datasets efficiently,
consider using FlyteDirectory with **streaming** instead of downloading everything at once.

Please first read through the comments on the :py:class:`flytekit.types.file.FlyteFile` class as the
implementation here is similar.
Expand Down Expand Up @@ -126,6 +172,21 @@
The format [] bit is still there because in Flyte, directories are stored as Blob Types also, just like files, and
the Blob type has the format field. The difference in the type field is represented in the ``dimensionality``
field in the ``BlobType``.

To stream a FlyteDirectory, use the following approach:

.. code-block:: python

from typing import Annotated
from flytekit.types.directory.types import FlyteDirectory

def t2(dataset: Annotated[FlyteDirectory, StreamingKwargs(shards_config={}, stream_config={})]):
# Returns an instance of a subclass of PyTorch's IterableDataset, yielding data samples as an iterator.
for i in range(dataset.num_samples):
print(dataset[i])

This leverages MosaicML's streaming library under the hood.
The dataset is represented as a StreamingDataset, which extends PyTorch's IterableDataset, enabling efficient, on-the-fly data loading.
"""

def _serialize(self) -> typing.Dict[str, str]:
Expand Down Expand Up @@ -632,43 +693,195 @@
python_val = json.loads(json_str)
return self.dict_to_flyte_directory(python_val, expected_python_type)

def _is_valid_jsonl_file(self, file_path: str) -> bool:
    """Return True when *file_path* parses end-to-end as JSON Lines."""
    try:
        with jsonlines.open(file_path) as reader:
            # Exhaust the reader: a malformed line raises mid-iteration.
            for _record in reader:
                continue
    except (jsonlines.InvalidLineError, UnicodeDecodeError):
        return False
    return True
def _is_valid_parquet_file(self, file_path: str) -> bool:
    """Return True when pyarrow can open *file_path* as a Parquet file."""
    import pyarrow.parquet as pq

    try:
        # Opening parses the footer/metadata; failure of any kind means "not parquet".
        pq.ParquetFile(file_path)
    except Exception:
        return False
    return True
def _is_valid_arrow_file(self, file_path: str) -> bool:
    """Return True when *file_path* is a readable Arrow IPC stream."""
    import pyarrow as pa

    try:
        # Memory-map and construct a stream reader; either step failing
        # (bad magic, truncated stream, ...) means "not an arrow file".
        with pa.memory_map(file_path, "r") as source:
            pa.RecordBatchStreamReader(source)
    except Exception:
        return False
    return True
def _write_jsonl(self, out, src: str):
    """Copy every JSON record from *src* into the MDS writer *out*.

    Files that do not validate as JSON Lines are silently skipped.
    """
    if not self._is_valid_jsonl_file(src):
        return
    with jsonlines.open(src) as reader:
        for record in reader:
            out.write(record)
def _process_batches(self, out, reader):
import numpy as np

Check warning on line 731 in flytekit/types/directory/types.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/directory/types.py#L731

Added line #L731 was not covered by tests

for batch in reader:
records = (

Check warning on line 734 in flytekit/types/directory/types.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/directory/types.py#L734

Added line #L734 was not covered by tests
{
name: np.array(val.as_py())
if hasattr(val, "as_py") and isinstance(val.as_py(), list)
else val.as_py()
if hasattr(val, "as_py")
else np.array(val)
if isinstance(val, list)
else val
for name, val in zip(batch.schema.names, row)
}
for row in zip(*batch.columns)
)
for record in records:
out.write(record)

Check warning on line 748 in flytekit/types/directory/types.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/directory/types.py#L748

Added line #L748 was not covered by tests

def _write_parquet_or_arrow(self, out, src: str, is_parquet: bool = True, batch_size: int = 5):
    """Stream the rows of a Parquet or Arrow IPC file into the MDS writer *out*.

    Files that fail format validation are silently skipped.

    Args:
        out: an open ``MDSWriter``.
        src: path to the source file.
        is_parquet: True to read *src* as Parquet, False as an Arrow IPC stream.
        batch_size: rows per Parquet record batch (previously hard-coded to 5).
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    if is_parquet:
        if self._is_valid_parquet_file(src):
            self._process_batches(out, pq.ParquetFile(src).iter_batches(batch_size=batch_size))
    elif self._is_valid_arrow_file(src):
        # Memory-map the IPC stream so batches are read without copying the file.
        with pa.memory_map(src, "r") as mmap, pa.RecordBatchStreamReader(mmap) as reader:
            self._process_batches(out, reader)
    # (The original's trailing bare `return` statements were dead code.)

def _create_shards(
    self, ctx: FlyteContext, uri: str, fd: FlyteDirectory = None, aa: StreamingKwargs = None
) -> typing.Optional[str]:
    """Convert the directory's files into MDS shards and return the shard directory.

    Args:
        ctx: the current FlyteContext (used for local scratch directories).
        uri: source directory URI; globbed for ``*.<format>`` files when *fd* is None.
        fd: optional FlyteDirectory to crawl/download instead of globbing *uri*.
        aa: the StreamingKwargs annotation carrying shard/stream configuration.

    Returns:
        The local output directory written by ``MDSWriter``, or ``None`` when no
        ``shards_config`` was supplied. (The original annotation said
        ``Union[bool, str]``; no bool is ever returned.)

    Raises:
        ValueError: unsupported data format, or no matching files found.
    """
    if not aa.shards_config:
        return None

    # Default the shard output location to a fresh local scratch directory.
    aa.shards_config.setdefault("out", ctx.file_access.get_random_local_directory())

    if aa.data_format == DataFormat.JSONL:
        writer_func = self._write_jsonl
    elif aa.data_format == DataFormat.PARQUET:
        writer_func = self._write_parquet_or_arrow
    elif aa.data_format == DataFormat.ARROW:
        writer_func = partial(self._write_parquet_or_arrow, is_parquet=False)
    else:
        raise ValueError(f"Unsupported data format: {aa.data_format}")

    with MDSWriter(**aa.shards_config) as out:
        if fd:
            # Crawl the (possibly remote) FlyteDirectory, downloading each file lazily.
            sources = (FlyteFile.from_source(os.path.join(base, x)).download() for base, x in fd.crawl())
        else:
            sources = Path(uri).rglob(f"*.{aa.data_format.name.lower()}")

        wrote_any = False
        for src in sources:
            # Normalize every source to str (the original only converted the
            # first one, passing Path objects for the rest).
            writer_func(out, str(src))
            wrote_any = True
        if not wrote_any:
            raise ValueError(f"No {aa.data_format.name.lower()} files found in {uri}")

    return aa.shards_config["out"]

def _process_directory_with_streaming(
    self,
    ctx: FlyteContext,
    uri: str,
    fd: FlyteDirectory = None,
    base_type: typing.Type = None,
    annotate_args: typing.List = None,
) -> typing.Union[StreamingDataset, FlyteDirectory]:
    """Build a ``StreamingDataset`` from the first ``StreamingKwargs`` annotation.

    If no ``StreamingKwargs`` is present in *annotate_args*, falls back to
    returning *fd* (remote case) or a freshly constructed *base_type* pointing
    at the local *uri*.
    """
    for aa in annotate_args:
        if isinstance(aa, StreamingKwargs):
            # Process shards if configured
            output_path = self._create_shards(ctx, uri, fd, aa) if aa.shards_config else None

            # local/remote selection (note: `a or b if c else d` parses as
            # `a or (b if c else d)`):
            #   - shards were written        -> local=output_path, remote=None
            #   - no fd (local source dir)   -> local=uri,         remote=None
            #   - fd present, no shards      -> local=None,        remote=uri
            return StreamingDataset(
                local=output_path or uri if output_path or not fd else None,
                remote=uri if not output_path and fd else None,
                **(aa.stream_config or {}),  # Make config optional
            )

    # Return appropriate object if we didn't create a StreamingDataset
    return fd or base_type(uri, remote_directory=False)

async def async_to_python_value(
    self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory]
) -> typing.Any:
    """Convert a directory Literal into a FlyteDirectory, or a StreamingDataset
    when the expected type is ``Annotated[FlyteDirectory, StreamingKwargs(...)]``.

    Raises:
        TypeTransformerFailedError: not a multipart blob, no usable uri, or
            StreamingKwargs used without mosaicml-streaming installed.
        FlyteAssertion: a local uri that is not a directory.
    """
    base_type = None
    has_streaming_kwargs = False
    annotate_args = []

    # Extract base type and annotations for StreamingKwargs
    if get_origin(expected_python_type) is Annotated:
        base_type, *annotate_args = get_args(expected_python_type)
        has_streaming_kwargs = any(isinstance(arg, StreamingKwargs) for arg in annotate_args)
        if has_streaming_kwargs and not _has_streaming:
            # Fail fast: the annotation demands streaming but the optional
            # dependency is missing.
            raise TypeTransformerFailedError(
                "In order to use StreamingKwargs, you need to install mosaicml-streaming first."
            )

    # Handle dataclass attribute access; the unwrapped base_type is used so the
    # Annotated wrapper never reaches the IDL converters.
    if lv.scalar:
        if lv.scalar.binary:
            return self.from_binary_idl(lv.scalar.binary, base_type or expected_python_type)
        if lv.scalar.generic:
            return self.from_generic_idl(lv.scalar.generic, base_type or expected_python_type)

    try:
        uri = lv.scalar.blob.uri
    except AttributeError:
        raise TypeTransformerFailedError(f"Cannot convert from {lv} to {base_type or expected_python_type}")

    # Directories are stored as MULTIPART blobs; anything else is not a directory.
    if lv.scalar.blob.metadata.type.dimensionality != BlobType.BlobDimensionality.MULTIPART:
        raise TypeTransformerFailedError(f"{lv.scalar.blob.uri} is not a directory.")

    if not ctx.file_access.is_remote(uri) and not os.path.isdir(uri):
        raise FlyteAssertion(f"Expected a directory, but the given uri '{uri}' is not a directory.")

    # Local file path handling: nothing to download.
    if not ctx.file_access.is_remote(uri):
        if base_type and has_streaming_kwargs:
            # Stream straight from the local directory (fd=None signals local).
            return self._process_directory_with_streaming(
                ctx,
                uri,
                fd=None,
                base_type=base_type,
                annotate_args=annotate_args,
            )
        else:
            return (base_type or expected_python_type)(uri, remote_directory=False)

    # Remote file path handling: build a lazily-downloading FlyteDirectory.
    local_folder = ctx.file_access.get_random_local_directory()

    batch_size = get_batch_size(base_type or expected_python_type)
    _downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size)

    expected_format = self.get_format(base_type or expected_python_type)

    fd = FlyteDirectory.__class_getitem__(expected_format)(local_folder, _downloader)
    fd._remote_source = uri

    if base_type and has_streaming_kwargs:
        # Hand the downloadable fd to the streaming path (remote case).
        return self._process_directory_with_streaming(
            ctx,
            uri,
            fd=fd,
            base_type=base_type,
            annotate_args=annotate_args,
        )

    return fd

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirectory[typing.Any]]:
Expand Down
Loading