|
3 | 3 | import logging |
4 | 4 | from typing import TYPE_CHECKING, Any |
5 | 5 |
|
| 6 | +from daft.dependencies import np, torch |
| 7 | +from daft.runners.partitioning import MaterializedResult |
| 8 | + |
6 | 9 | if TYPE_CHECKING: |
7 | 10 | from collections.abc import Iterable, Iterator |
8 | 11 |
|
| 12 | + from daft.dataframe.dataframe import DataFrame |
| 13 | + from daft.recordbatch import MicroPartition |
| 14 | + |
9 | 15 | logger = logging.getLogger(__name__) |
10 | 16 |
|
11 | 17 | try: |
@@ -47,3 +53,84 @@ def __init__(self, iterable: Iterable[dict[str, Any]]): |
47 | 53 |
|
48 | 54 | def __iter__(self) -> Iterator[dict[str, Any]]: |
49 | 55 | return iter(self.iterable) |
| 56 | + |
| 57 | + |
| 58 | +class DaftTorchDataLoader: |
| 59 | + """Streams batched partitions from a Daft DataFrame and yields PyTorch-ready batch dicts. |
| 60 | +
|
| 61 | + Note: |
| 62 | + This simulates the behavior of a PyTorch DataLoader, but does not use the DataLoader class itself. |
| 63 | + If the underlying DataFrame is already materialized, it will reuse the existing data. |
| 64 | + """ |
| 65 | + |
| 66 | + def __init__( |
| 67 | + self, |
| 68 | + df: DataFrame, |
| 69 | + batch_size: int = 1, |
| 70 | + *, |
| 71 | + pin_memory: bool = False, |
| 72 | + pin_memory_device: str = "", |
| 73 | + prefetch_count: int = 0, |
| 74 | + # TODO: Add support for drop_last when we have strict into_batches |
| 75 | + ) -> None: |
| 76 | + if batch_size <= 0: |
| 77 | + raise ValueError("batch_size must be greater than 0") |
| 78 | + |
| 79 | + self._batch_size = batch_size |
| 80 | + self._pin_memory = pin_memory |
| 81 | + self._pin_memory_device = pin_memory_device if pin_memory_device else None |
| 82 | + self._prefetch_count = prefetch_count |
| 83 | + |
| 84 | + self._batched_df = df.into_batches(batch_size) |
| 85 | + |
| 86 | + def __iter__(self) -> Iterator[dict[str, Any]]: |
| 87 | + from daft.runners import get_or_create_runner |
| 88 | + |
| 89 | + results = self._batched_df._result |
| 90 | + if results is not None: |
| 91 | + for _, mat_result in results.items(): |
| 92 | + batch = self._to_torch_batch(mat_result.micropartition()) |
| 93 | + if batch is not None: |
| 94 | + yield batch |
| 95 | + else: |
| 96 | + buffer_size = self._prefetch_count if self._prefetch_count > 0 else None |
| 97 | + partitions_iter: Iterator[MaterializedResult[Any]] = get_or_create_runner().run_iter( |
| 98 | + self._batched_df._builder, results_buffer_size=buffer_size |
| 99 | + ) |
| 100 | + for mat_result in partitions_iter: |
| 101 | + batch = self._to_torch_batch(mat_result.micropartition()) |
| 102 | + if batch is not None: |
| 103 | + yield batch |
| 104 | + |
| 105 | + def _to_torch_batch(self, batch: MicroPartition) -> dict[str, Any]: |
| 106 | + return {key: self._column_to_tensor(values) for key, values in batch.to_pydict().items()} |
| 107 | + |
| 108 | + def _column_to_tensor(self, values: list[Any]) -> Any: |
| 109 | + if len(values) == 0: |
| 110 | + return self._pin(torch.tensor([])) |
| 111 | + |
| 112 | + first = values[0] |
| 113 | + |
| 114 | + if isinstance(first, torch.Tensor): |
| 115 | + return self._pin(torch.stack(values)) |
| 116 | + if hasattr(first, "__array__") and not isinstance(first, (str, bytes)): |
| 117 | + if isinstance(first, np.ndarray) and first.ndim > 0: |
| 118 | + return self._pin(torch.stack([torch.as_tensor(v) for v in values])) |
| 119 | + return self._pin(torch.as_tensor(values)) |
| 120 | + if isinstance(first, (bool, int, float)): |
| 121 | + return self._pin(torch.as_tensor(values)) |
| 122 | + |
| 123 | + return values |
| 124 | + |
| 125 | + def _pin(self, tensor: torch.Tensor) -> torch.Tensor: |
| 126 | + if not self._pin_memory: |
| 127 | + return tensor |
| 128 | + |
| 129 | + # Pinned host memory is only used for async CPU -> CUDA copies. |
| 130 | + if not torch.cuda.is_available(): |
| 131 | + return tensor |
| 132 | + |
| 133 | + # If a specific device is provided, use it. Otherwise, use the default device. |
| 134 | + if self._pin_memory_device: |
| 135 | + return tensor.pin_memory(device=self._pin_memory_device) |
| 136 | + return tensor.pin_memory() |
0 commit comments