Skip to content

Commit 244ce28

Browse files
committed
Revert "feat: avoid some copies in torch formatter (#7787)"
This reverts commit c412a6f.
1 parent 483ac2e commit 244ce28

1 file changed

Lines changed: 40 additions & 182 deletions

File tree

src/datasets/formatting/torch_formatter.py

Lines changed: 40 additions & 182 deletions
Original file line number | Diff line number | Diff line change
@@ -21,232 +21,90 @@
2121
import pyarrow as pa
2222

2323
from .. import config
24+
from ..utils.py_utils import map_nested
2425
from .formatting import TensorFormatter
2526

2627

2728
if TYPE_CHECKING:
2829
import torch
2930

30-
# Import torch once at module level
31-
try:
32-
import torch
33-
34-
_torch_available = True
35-
except ImportError:
36-
_torch_available = False
37-
torch = None
38-
3931

4032
class TorchFormatter(TensorFormatter[Mapping, "torch.Tensor", Mapping]):
4133
def __init__(self, features=None, token_per_repo_id=None, **torch_tensor_kwargs):
4234
super().__init__(features=features, token_per_repo_id=token_per_repo_id)
4335
self.torch_tensor_kwargs = torch_tensor_kwargs
44-
45-
if not _torch_available:
46-
raise ImportError("PyTorch is required but not available")
36+
import torch # noqa import torch at initialization
4737

4838
def _consolidate(self, column):
49-
"""Smarter consolidation that only stacks when safe and beneficial."""
50-
if not isinstance(column, list) or not column:
51-
return column
52-
53-
# Check if all items are tensors with matching properties
54-
first = column[0]
55-
if not isinstance(first, torch.Tensor):
56-
return column
57-
58-
# Fast check: if all tensors have same shape, dtype, and device, we can stack
59-
if all(
60-
isinstance(x, torch.Tensor)
61-
and x.shape == first.shape
62-
and x.dtype == first.dtype
63-
and x.device == first.device
64-
for x in column
65-
):
66-
return torch.stack(column)
67-
39+
import torch
40+
41+
if isinstance(column, list) and column:
42+
if all(
43+
isinstance(x, torch.Tensor) and x.shape == column[0].shape and x.dtype == column[0].dtype
44+
for x in column
45+
):
46+
return torch.stack(column)
6847
return column
6948

7049
def _tensorize(self, value):
71-
"""Zero/low-copy tensor conversion with smart dtype handling."""
72-
# Fast path for strings, bytes, None
50+
import torch
51+
7352
if isinstance(value, (str, bytes, type(None))):
7453
return value
75-
76-
# Handle string arrays
77-
if isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
54+
elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
7855
return value.tolist()
7956

80-
# PIL Image fast path - avoid extra copies
57+
default_dtype = {}
58+
59+
if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer):
60+
default_dtype = {"dtype": torch.int64}
61+
62+
# Convert dtype to np.int64 if it's either np.uint16 or np.uint32 to ensure compatibility.
63+
# np.uint64 is excluded from this conversion as there is no compatible PyTorch dtype that can handle it without loss.
64+
if value.dtype in [np.uint16, np.uint32]:
65+
value = value.astype(np.int64)
66+
67+
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
68+
default_dtype = {"dtype": torch.float32}
69+
8170
if config.PIL_AVAILABLE and "PIL" in sys.modules:
8271
import PIL.Image
8372

8473
if isinstance(value, PIL.Image.Image):
85-
# Single conversion path: PIL -> numpy -> torch
86-
arr = np.asarray(value)
87-
if arr.ndim == 2:
88-
arr = arr[:, :, np.newaxis]
89-
# Use moveaxis instead of transpose
90-
arr = np.moveaxis(arr, -1, 0) # HWC -> CHW
91-
# Ensure contiguous for zero-copy conversion
92-
if not arr.flags.c_contiguous:
93-
arr = np.ascontiguousarray(arr)
94-
# Ensure array is writable for torch conversion
95-
if not arr.flags.writeable:
96-
arr = arr.copy()
97-
return torch.from_numpy(arr)
74+
value = np.asarray(value)
75+
if value.ndim == 2:
76+
value = value[:, :, np.newaxis]
9877

99-
# Video/Audio decoder passthrough
78+
value = value.transpose((2, 0, 1))
10079
if config.TORCHVISION_AVAILABLE and "torchvision" in sys.modules:
10180
from torchvision.io import VideoReader
10281

10382
if isinstance(value, VideoReader):
104-
return value
105-
83+
return value # TODO(QL): set output to torch tensors ?
10684
if config.TORCHCODEC_AVAILABLE and "torchcodec" in sys.modules:
10785
from torchcodec.decoders import AudioDecoder, VideoDecoder
10886

10987
if isinstance(value, (VideoDecoder, AudioDecoder)):
110-
return value
111-
112-
# Support for other tensor libraries via __array__
113-
if hasattr(value, "__array__") and not isinstance(value, torch.Tensor):
114-
value = value.__array__()
115-
116-
# Fast numpy conversion paths
117-
if isinstance(value, np.ndarray):
118-
# Handle integer types with smart casting
119-
if np.issubdtype(value.dtype, np.integer):
120-
# Check if user specified a dtype, otherwise default to int64
121-
kwargs = self.torch_tensor_kwargs.copy()
122-
target_dtype = kwargs.get("dtype", torch.int64)
123-
124-
# Safe casting for unsigned types
125-
if value.dtype in (np.uint16, np.uint32):
126-
# Cast to int64 in numpy (fast) then convert to torch
127-
value = value.astype(np.int64)
128-
if target_dtype == torch.int64:
129-
if not value.flags.writeable:
130-
value = value.copy()
131-
return torch.from_numpy(value)
132-
else:
133-
if not value.flags.writeable:
134-
value = value.copy()
135-
kwargs.setdefault("dtype", target_dtype)
136-
return torch.as_tensor(value, **kwargs)
137-
elif value.dtype == np.uint64:
138-
# Check if values fit in int64 range
139-
if np.all(value <= np.iinfo(np.int64).max):
140-
value = value.astype(np.int64)
141-
if target_dtype == torch.int64:
142-
if not value.flags.writeable:
143-
value = value.copy()
144-
return torch.from_numpy(value)
145-
else:
146-
if not value.flags.writeable:
147-
value = value.copy()
148-
kwargs.setdefault("dtype", target_dtype)
149-
return torch.as_tensor(value, **kwargs)
150-
else:
151-
# Fallback to safe conversion via Python ints
152-
kwargs.setdefault("dtype", target_dtype)
153-
return torch.tensor(value, **kwargs)
154-
else:
155-
# Use zero-copy conversion for compatible integer types
156-
if value.dtype == np.int64 and target_dtype == torch.int64:
157-
# Perfect match, zero-copy conversion
158-
if not value.flags.writeable:
159-
value = value.copy()
160-
return torch.from_numpy(value)
161-
else:
162-
# Need dtype conversion, use as_tensor for efficiency
163-
if not value.flags.writeable:
164-
value = value.copy()
165-
kwargs.setdefault("dtype", target_dtype)
166-
return torch.as_tensor(value, **kwargs)
167-
168-
# Handle floating point types
169-
elif np.issubdtype(value.dtype, np.floating):
170-
# Check if user specified a dtype, otherwise default to float32
171-
kwargs = self.torch_tensor_kwargs.copy()
172-
target_dtype = kwargs.get("dtype", torch.float32)
173-
174-
if value.dtype == np.float32 and target_dtype == torch.float32:
175-
# Zero-copy conversion, but ensure array is writable
176-
if not value.flags.writeable:
177-
value = value.copy()
178-
return torch.from_numpy(value)
179-
else:
180-
# Need dtype conversion
181-
if not value.flags.writeable:
182-
value = value.copy()
183-
kwargs.setdefault("dtype", target_dtype)
184-
return torch.as_tensor(value, **kwargs)
185-
else:
186-
# Other numpy types, use zero-copy when possible
187-
if not value.flags.writeable:
188-
value = value.copy()
189-
return torch.from_numpy(value)
190-
191-
# Handle numpy scalars
192-
elif isinstance(value, np.number):
193-
kwargs = self.torch_tensor_kwargs.copy()
194-
if np.issubdtype(value.dtype, np.integer):
195-
# Use torch.as_tensor for scalar conversion with dtype control
196-
kwargs.setdefault("dtype", torch.int64)
197-
return torch.as_tensor(value, **kwargs)
198-
elif np.issubdtype(value.dtype, np.floating):
199-
kwargs.setdefault("dtype", torch.float32)
200-
return torch.as_tensor(value, **kwargs)
201-
else:
202-
return torch.as_tensor(value, **kwargs)
203-
204-
# Handle Python lists/tuples of numbers efficiently
205-
elif isinstance(value, (list, tuple)):
206-
# Try to convert to numpy first for faster tensor creation
207-
try:
208-
arr = np.array(value)
209-
if arr.dtype.kind in "iuf": # integer, unsigned, float
210-
return self._tensorize(arr) # Recursive call to handle numpy path
211-
except (ValueError, TypeError):
212-
pass # Fall back to torch.tensor
213-
214-
# Default fallback with dtype defaults
215-
default_dtype = {}
216-
if isinstance(value, (int, float)):
217-
if isinstance(value, int):
218-
default_dtype = {"dtype": torch.int64}
219-
else:
220-
default_dtype = {"dtype": torch.float32}
88+
return value # TODO(QL): set output to torch tensors ?
22189

22290
return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
22391

22492
def _recursive_tensorize(self, data_struct):
225-
"""Optimized recursive walker with reduced Python overhead."""
226-
# Handle tensor-like objects with __array__ interface
93+
import torch
94+
95+
# support for torch, tf, jax etc.
22796
if hasattr(data_struct, "__array__") and not isinstance(data_struct, torch.Tensor):
22897
data_struct = data_struct.__array__()
229-
230-
# Handle object arrays (nested structures)
98+
# support for nested types like struct of list of struct
23199
if isinstance(data_struct, np.ndarray):
232-
if data_struct.dtype == object:
233-
# Use list comprehension instead of map_nested
234-
result = [self._recursive_tensorize(item) for item in data_struct]
235-
return self._consolidate(result)
236-
# Handle lists and tuples
100+
if data_struct.dtype == object: # torch tensors cannot be instantiated from an array of objects
101+
return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
237102
elif isinstance(data_struct, (list, tuple)):
238-
result = [self._recursive_tensorize(item) for item in data_struct]
239-
return self._consolidate(result)
240-
# Handle dictionaries
241-
elif isinstance(data_struct, dict):
242-
return {key: self._recursive_tensorize(value) for key, value in data_struct.items()}
243-
244-
# Base case: tensorize the leaf value
103+
return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
245104
return self._tensorize(data_struct)
246105

247106
def recursive_tensorize(self, data_struct: dict):
248-
"""Public interface maintaining compatibility."""
249-
return self._recursive_tensorize(data_struct)
107+
return map_nested(self._recursive_tensorize, data_struct, map_list=False)
250108

251109
def format_row(self, pa_table: pa.Table) -> Mapping:
252110
row = self.numpy_arrow_extractor().extract_row(pa_table)

0 commit comments

Comments
 (0)