21 | 21 | import pyarrow as pa |
22 | 22 |
23 | 23 | from .. import config |
| 24 | +from ..utils.py_utils import map_nested |
24 | 25 | from .formatting import TensorFormatter |
25 | 26 |
26 | 27 |
27 | 28 | if TYPE_CHECKING: |
28 | 29 | import torch |
29 | 30 |
30 | | -# Import torch once at module level once |
31 | | -try: |
32 | | - import torch |
33 | | - |
34 | | - _torch_available = True |
35 | | -except ImportError: |
36 | | - _torch_available = False |
37 | | - torch = None |
38 | | - |
39 | 31 |
40 | 32 | class TorchFormatter(TensorFormatter[Mapping, "torch.Tensor", Mapping]): |
41 | 33 | def __init__(self, features=None, token_per_repo_id=None, **torch_tensor_kwargs): |
42 | 34 | super().__init__(features=features, token_per_repo_id=token_per_repo_id) |
43 | 35 | self.torch_tensor_kwargs = torch_tensor_kwargs |
44 | | - |
45 | | - if not _torch_available: |
46 | | - raise ImportError("PyTorch is required but not available") |
| 36 | + import torch # noqa import torch at initialization |
47 | 37 |
48 | 38 | def _consolidate(self, column): |
49 | | - """Smarter consolidation that only stacks when safe and beneficial.""" |
50 | | - if not isinstance(column, list) or not column: |
51 | | - return column |
52 | | - |
53 | | - # Check if all items are tensors with matching properties |
54 | | - first = column[0] |
55 | | - if not isinstance(first, torch.Tensor): |
56 | | - return column |
57 | | - |
58 | | - # Fast check: if all tensors have same shape, dtype, and device, we can stack |
59 | | - if all( |
60 | | - isinstance(x, torch.Tensor) |
61 | | - and x.shape == first.shape |
62 | | - and x.dtype == first.dtype |
63 | | - and x.device == first.device |
64 | | - for x in column |
65 | | - ): |
66 | | - return torch.stack(column) |
67 | | - |
| 39 | + import torch |
| 40 | + |
| 41 | + if isinstance(column, list) and column: |
| 42 | + if all( |
| 43 | + isinstance(x, torch.Tensor) and x.shape == column[0].shape and x.dtype == column[0].dtype |
| 44 | + for x in column |
| 45 | + ): |
| 46 | + return torch.stack(column) |
68 | 47 | return column |
69 | 48 |
70 | 49 | def _tensorize(self, value): |
71 | | - """Zero/low-copy tensor conversion with smart dtype handling.""" |
72 | | - # Fast path for strings, bytes, None |
| 50 | + import torch |
| 51 | + |
73 | 52 | if isinstance(value, (str, bytes, type(None))): |
74 | 53 | return value |
75 | | - |
76 | | - # Handle string arrays |
77 | | - if isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character): |
| 54 | + elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character): |
78 | 55 | return value.tolist() |
79 | 56 |
80 | | - # PIL Image fast path - avoid extra copies |
| 57 | + default_dtype = {} |
| 58 | + |
| 59 | + if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer): |
| 60 | + default_dtype = {"dtype": torch.int64} |
| 61 | + |
| 62 | + # Convert dtype to np.int64 if it's either np.uint16 or np.uint32 to ensure compatibility. |
| 63 | + # np.uint64 is excluded from this conversion as there is no compatible PyTorch dtype that can handle it without loss. |
| 64 | + if value.dtype in [np.uint16, np.uint32]: |
| 65 | + value = value.astype(np.int64) |
| 66 | + |
| 67 | + elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): |
| 68 | + default_dtype = {"dtype": torch.float32} |
| 69 | + |
81 | 70 | if config.PIL_AVAILABLE and "PIL" in sys.modules: |
82 | 71 | import PIL.Image |
83 | 72 |
84 | 73 | if isinstance(value, PIL.Image.Image): |
85 | | - # Single conversion path: PIL -> numpy -> torch |
86 | | - arr = np.asarray(value) |
87 | | - if arr.ndim == 2: |
88 | | - arr = arr[:, :, np.newaxis] |
89 | | - # Use moveaxis instead of transpose |
90 | | - arr = np.moveaxis(arr, -1, 0) # HWC -> CHW |
91 | | - # Ensure contiguous for zero-copy conversion |
92 | | - if not arr.flags.c_contiguous: |
93 | | - arr = np.ascontiguousarray(arr) |
94 | | - # Ensure array is writable for torch conversion |
95 | | - if not arr.flags.writeable: |
96 | | - arr = arr.copy() |
97 | | - return torch.from_numpy(arr) |
| 74 | + value = np.asarray(value) |
| 75 | + if value.ndim == 2: |
| 76 | + value = value[:, :, np.newaxis] |
98 | 77 |
99 | | - # Video/Audio decoder passthrough |
| 78 | + value = value.transpose((2, 0, 1)) |
100 | 79 | if config.TORCHVISION_AVAILABLE and "torchvision" in sys.modules: |
101 | 80 | from torchvision.io import VideoReader |
102 | 81 |
103 | 82 | if isinstance(value, VideoReader): |
104 | | - return value |
105 | | - |
| 83 | + return value # TODO(QL): set output to torch tensors ? |
106 | 84 | if config.TORCHCODEC_AVAILABLE and "torchcodec" in sys.modules: |
107 | 85 | from torchcodec.decoders import AudioDecoder, VideoDecoder |
108 | 86 |
109 | 87 | if isinstance(value, (VideoDecoder, AudioDecoder)): |
110 | | - return value |
111 | | - |
112 | | - # Support for other tensor libraries via __array__ |
113 | | - if hasattr(value, "__array__") and not isinstance(value, torch.Tensor): |
114 | | - value = value.__array__() |
115 | | - |
116 | | - # Fast numpy conversion paths |
117 | | - if isinstance(value, np.ndarray): |
118 | | - # Handle integer types with smart casting |
119 | | - if np.issubdtype(value.dtype, np.integer): |
120 | | - # Check if user specified a dtype, otherwise default to int64 |
121 | | - kwargs = self.torch_tensor_kwargs.copy() |
122 | | - target_dtype = kwargs.get("dtype", torch.int64) |
123 | | - |
124 | | - # Safe casting for unsigned types |
125 | | - if value.dtype in (np.uint16, np.uint32): |
126 | | - # Cast to int64 in numpy (fast) then convert to torch |
127 | | - value = value.astype(np.int64) |
128 | | - if target_dtype == torch.int64: |
129 | | - if not value.flags.writeable: |
130 | | - value = value.copy() |
131 | | - return torch.from_numpy(value) |
132 | | - else: |
133 | | - if not value.flags.writeable: |
134 | | - value = value.copy() |
135 | | - kwargs.setdefault("dtype", target_dtype) |
136 | | - return torch.as_tensor(value, **kwargs) |
137 | | - elif value.dtype == np.uint64: |
138 | | - # Check if values fit in int64 range |
139 | | - if np.all(value <= np.iinfo(np.int64).max): |
140 | | - value = value.astype(np.int64) |
141 | | - if target_dtype == torch.int64: |
142 | | - if not value.flags.writeable: |
143 | | - value = value.copy() |
144 | | - return torch.from_numpy(value) |
145 | | - else: |
146 | | - if not value.flags.writeable: |
147 | | - value = value.copy() |
148 | | - kwargs.setdefault("dtype", target_dtype) |
149 | | - return torch.as_tensor(value, **kwargs) |
150 | | - else: |
151 | | - # Fallback to safe conversion via Python ints |
152 | | - kwargs.setdefault("dtype", target_dtype) |
153 | | - return torch.tensor(value, **kwargs) |
154 | | - else: |
155 | | - # Use zero-copy conversion for compatible integer types |
156 | | - if value.dtype == np.int64 and target_dtype == torch.int64: |
157 | | - # Perfect match, zero-copy conversion |
158 | | - if not value.flags.writeable: |
159 | | - value = value.copy() |
160 | | - return torch.from_numpy(value) |
161 | | - else: |
162 | | - # Need dtype conversion, use as_tensor for efficiency |
163 | | - if not value.flags.writeable: |
164 | | - value = value.copy() |
165 | | - kwargs.setdefault("dtype", target_dtype) |
166 | | - return torch.as_tensor(value, **kwargs) |
167 | | - |
168 | | - # Handle floating point types |
169 | | - elif np.issubdtype(value.dtype, np.floating): |
170 | | - # Check if user specified a dtype, otherwise default to float32 |
171 | | - kwargs = self.torch_tensor_kwargs.copy() |
172 | | - target_dtype = kwargs.get("dtype", torch.float32) |
173 | | - |
174 | | - if value.dtype == np.float32 and target_dtype == torch.float32: |
175 | | - # Zero-copy conversion, but ensure array is writable |
176 | | - if not value.flags.writeable: |
177 | | - value = value.copy() |
178 | | - return torch.from_numpy(value) |
179 | | - else: |
180 | | - # Need dtype conversion |
181 | | - if not value.flags.writeable: |
182 | | - value = value.copy() |
183 | | - kwargs.setdefault("dtype", target_dtype) |
184 | | - return torch.as_tensor(value, **kwargs) |
185 | | - else: |
186 | | - # Other numpy types, use zero-copy when possible |
187 | | - if not value.flags.writeable: |
188 | | - value = value.copy() |
189 | | - return torch.from_numpy(value) |
190 | | - |
191 | | - # Handle numpy scalars |
192 | | - elif isinstance(value, np.number): |
193 | | - kwargs = self.torch_tensor_kwargs.copy() |
194 | | - if np.issubdtype(value.dtype, np.integer): |
195 | | - # Use torch.as_tensor for scalar conversion with dtype control |
196 | | - kwargs.setdefault("dtype", torch.int64) |
197 | | - return torch.as_tensor(value, **kwargs) |
198 | | - elif np.issubdtype(value.dtype, np.floating): |
199 | | - kwargs.setdefault("dtype", torch.float32) |
200 | | - return torch.as_tensor(value, **kwargs) |
201 | | - else: |
202 | | - return torch.as_tensor(value, **kwargs) |
203 | | - |
204 | | - # Handle Python lists/tuples of numbers efficiently |
205 | | - elif isinstance(value, (list, tuple)): |
206 | | - # Try to convert to numpy first for faster tensor creation |
207 | | - try: |
208 | | - arr = np.array(value) |
209 | | - if arr.dtype.kind in "iuf": # integer, unsigned, float |
210 | | - return self._tensorize(arr) # Recursive call to handle numpy path |
211 | | - except (ValueError, TypeError): |
212 | | - pass # Fall back to torch.tensor |
213 | | - |
214 | | - # Default fallback with dtype defaults |
215 | | - default_dtype = {} |
216 | | - if isinstance(value, (int, float)): |
217 | | - if isinstance(value, int): |
218 | | - default_dtype = {"dtype": torch.int64} |
219 | | - else: |
220 | | - default_dtype = {"dtype": torch.float32} |
| 88 | + return value # TODO(QL): set output to torch tensors ? |
221 | 89 |
222 | 90 | return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs}) |
223 | 91 |
224 | 92 | def _recursive_tensorize(self, data_struct): |
225 | | - """Optimized recursive walker with reduced Python overhead.""" |
226 | | - # Handle tensor-like objects with __array__ interface |
| 93 | + import torch |
| 94 | + |
| 95 | + # support for torch, tf, jax etc. |
227 | 96 | if hasattr(data_struct, "__array__") and not isinstance(data_struct, torch.Tensor): |
228 | 97 | data_struct = data_struct.__array__() |
229 | | - |
230 | | - # Handle object arrays (nested structures) |
| 98 | + # support for nested types like struct of list of struct |
231 | 99 | if isinstance(data_struct, np.ndarray): |
232 | | - if data_struct.dtype == object: |
233 | | - # Use list comprehension instead of map_nested |
234 | | - result = [self._recursive_tensorize(item) for item in data_struct] |
235 | | - return self._consolidate(result) |
236 | | - # Handle lists and tuples |
| 100 | + if data_struct.dtype == object: # torch tensors cannot be instantiated from an array of objects |
| 101 | + return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct]) |
237 | 102 | elif isinstance(data_struct, (list, tuple)): |
238 | | - result = [self._recursive_tensorize(item) for item in data_struct] |
239 | | - return self._consolidate(result) |
240 | | - # Handle dictionaries |
241 | | - elif isinstance(data_struct, dict): |
242 | | - return {key: self._recursive_tensorize(value) for key, value in data_struct.items()} |
243 | | - |
244 | | - # Base case: tensorize the leaf value |
| 103 | + return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct]) |
245 | 104 | return self._tensorize(data_struct) |
246 | 105 |
247 | 106 | def recursive_tensorize(self, data_struct: dict): |
248 | | - """Public interface maintaining compatibility.""" |
249 | | - return self._recursive_tensorize(data_struct) |
| 107 | + return map_nested(self._recursive_tensorize, data_struct, map_list=False) |
250 | 108 |
251 | 109 | def format_row(self, pa_table: pa.Table) -> Mapping: |
252 | 110 | row = self.numpy_arrow_extractor().extract_row(pa_table) |