Skip to content

Commit 06b2c60

Browse files
committed
fix: handle non-writable np arrays
1 parent fa3542f commit 06b2c60

1 file changed

Lines changed: 22 additions & 1 deletion

File tree

src/datasets/formatting/torch_formatter.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def _tensorize(self, value):
9393
# Ensure contiguous for zero-copy conversion
9494
if not arr.flags.c_contiguous:
9595
arr = np.ascontiguousarray(arr)
96+
# Ensure array is writable for torch conversion
97+
if not arr.flags.writeable:
98+
arr = arr.copy()
9699
return torch.from_numpy(arr)
97100

98101
# Video/Audio decoder passthrough
@@ -125,17 +128,25 @@ def _tensorize(self, value):
125128
# Cast to int64 in numpy (fast) then convert to torch
126129
value = value.astype(np.int64)
127130
if target_dtype == torch.int64:
131+
if not value.flags.writeable:
132+
value = value.copy()
128133
return torch.from_numpy(value)
129134
else:
135+
if not value.flags.writeable:
136+
value = value.copy()
130137
kwargs.setdefault("dtype", target_dtype)
131138
return torch.as_tensor(value, **kwargs)
132139
elif value.dtype == np.uint64:
133140
# Check if values fit in int64 range
134141
if np.all(value <= np.iinfo(np.int64).max):
135142
value = value.astype(np.int64)
136143
if target_dtype == torch.int64:
144+
if not value.flags.writeable:
145+
value = value.copy()
137146
return torch.from_numpy(value)
138147
else:
148+
if not value.flags.writeable:
149+
value = value.copy()
139150
kwargs.setdefault("dtype", target_dtype)
140151
return torch.as_tensor(value, **kwargs)
141152
else:
@@ -146,9 +157,13 @@ def _tensorize(self, value):
146157
# Use zero-copy conversion for compatible integer types
147158
if value.dtype == np.int64 and target_dtype == torch.int64:
148159
# Perfect match, zero-copy conversion
160+
if not value.flags.writeable:
161+
value = value.copy()
149162
return torch.from_numpy(value)
150163
else:
151164
# Need dtype conversion, use as_tensor for efficiency
165+
if not value.flags.writeable:
166+
value = value.copy()
152167
kwargs.setdefault("dtype", target_dtype)
153168
return torch.as_tensor(value, **kwargs)
154169

@@ -159,14 +174,20 @@ def _tensorize(self, value):
159174
target_dtype = kwargs.get("dtype", torch.float32)
160175

161176
if value.dtype == np.float32 and target_dtype == torch.float32:
162-
# Zero-copy conversion
177+
# Zero-copy conversion, but ensure array is writable
178+
if not value.flags.writeable:
179+
value = value.copy()
163180
return torch.from_numpy(value)
164181
else:
165182
# Need dtype conversion
183+
if not value.flags.writeable:
184+
value = value.copy()
166185
kwargs.setdefault("dtype", target_dtype)
167186
return torch.as_tensor(value, **kwargs)
168187
else:
169188
# Other numpy types, use zero-copy when possible
189+
if not value.flags.writeable:
190+
value = value.copy()
170191
return torch.from_numpy(value)
171192

172193
# Handle numpy scalars

0 commit comments

Comments
 (0)