Skip to content

Commit fa3542f

Browse files
committed
fix: handle dtype
1 parent 2f0806a commit fa3542f

1 file changed

Lines changed: 28 additions & 18 deletions

File tree

src/datasets/formatting/torch_formatter.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -116,61 +116,71 @@ def _tensorize(self, value):
116116
if isinstance(value, np.ndarray):
117117
# Handle integer types with smart casting
118118
if np.issubdtype(value.dtype, np.integer):
119-
target_dtype = torch.int64
119+
# Check if user specified a dtype, otherwise default to int64
120+
kwargs = self.torch_tensor_kwargs.copy()
121+
target_dtype = kwargs.get("dtype", torch.int64)
120122

121123
# Safe casting for unsigned types
122124
if value.dtype in (np.uint16, np.uint32):
123125
# Cast to int64 in numpy (fast) then convert to torch
124126
value = value.astype(np.int64)
125-
return torch.from_numpy(value)
127+
if target_dtype == torch.int64:
128+
return torch.from_numpy(value)
129+
else:
130+
kwargs.setdefault("dtype", target_dtype)
131+
return torch.as_tensor(value, **kwargs)
126132
elif value.dtype == np.uint64:
127133
# Check if values fit in int64 range
128134
if np.all(value <= np.iinfo(np.int64).max):
129135
value = value.astype(np.int64)
130-
return torch.from_numpy(value)
136+
if target_dtype == torch.int64:
137+
return torch.from_numpy(value)
138+
else:
139+
kwargs.setdefault("dtype", target_dtype)
140+
return torch.as_tensor(value, **kwargs)
131141
else:
132142
# Fallback to safe conversion via Python ints
133-
kwargs = self.torch_tensor_kwargs.copy()
134143
kwargs.setdefault("dtype", target_dtype)
135144
return torch.tensor(value, **kwargs)
136145
else:
137146
# Use zero-copy conversion for compatible integer types
138-
if value.dtype != np.int64:
147+
if value.dtype == np.int64 and target_dtype == torch.int64:
148+
# Perfect match, zero-copy conversion
149+
return torch.from_numpy(value)
150+
else:
139151
# Need dtype conversion, use as_tensor for efficiency
140-
kwargs = self.torch_tensor_kwargs.copy()
141152
kwargs.setdefault("dtype", target_dtype)
142153
return torch.as_tensor(value, **kwargs)
143-
else:
144-
# Perfect match, zero-copy conversion
145-
return torch.from_numpy(value)
146154

147155
# Handle floating point types
148156
elif np.issubdtype(value.dtype, np.floating):
149-
if value.dtype != np.float32:
150-
# Need dtype conversion
151-
kwargs = self.torch_tensor_kwargs.copy()
152-
kwargs.setdefault("dtype", torch.float32)
153-
return torch.as_tensor(value, **kwargs)
154-
else:
157+
# Check if user specified a dtype, otherwise default to float32
158+
kwargs = self.torch_tensor_kwargs.copy()
159+
target_dtype = kwargs.get("dtype", torch.float32)
160+
161+
if value.dtype == np.float32 and target_dtype == torch.float32:
155162
# Zero-copy conversion
156163
return torch.from_numpy(value)
164+
else:
165+
# Need dtype conversion
166+
kwargs.setdefault("dtype", target_dtype)
167+
return torch.as_tensor(value, **kwargs)
157168
else:
158169
# Other numpy types, use zero-copy when possible
159170
return torch.from_numpy(value)
160171

161172
# Handle numpy scalars
162173
elif isinstance(value, np.number):
174+
kwargs = self.torch_tensor_kwargs.copy()
163175
if np.issubdtype(value.dtype, np.integer):
164176
# Use torch.as_tensor for scalar conversion with dtype control
165-
kwargs = self.torch_tensor_kwargs.copy()
166177
kwargs.setdefault("dtype", torch.int64)
167178
return torch.as_tensor(value, **kwargs)
168179
elif np.issubdtype(value.dtype, np.floating):
169-
kwargs = self.torch_tensor_kwargs.copy()
170180
kwargs.setdefault("dtype", torch.float32)
171181
return torch.as_tensor(value, **kwargs)
172182
else:
173-
return torch.as_tensor(value, **self.torch_tensor_kwargs)
183+
return torch.as_tensor(value, **kwargs)
174184

175185
# Handle Python lists/tuples of numbers efficiently
176186
elif isinstance(value, (list, tuple)):

0 commit comments

Comments
 (0)