@@ -116,61 +116,71 @@ def _tensorize(self, value):
116116 if isinstance (value , np .ndarray ):
117117 # Handle integer types with smart casting
118118 if np .issubdtype (value .dtype , np .integer ):
119- target_dtype = torch .int64
119+ # Check if user specified a dtype, otherwise default to int64
120+ kwargs = self .torch_tensor_kwargs .copy ()
121+ target_dtype = kwargs .get ("dtype" , torch .int64 )
120122
121123 # Safe casting for unsigned types
122124 if value .dtype in (np .uint16 , np .uint32 ):
123125 # Cast to int64 in numpy (fast) then convert to torch
124126 value = value .astype (np .int64 )
125- return torch .from_numpy (value )
127+ if target_dtype == torch .int64 :
128+ return torch .from_numpy (value )
129+ else :
130+ kwargs .setdefault ("dtype" , target_dtype )
131+ return torch .as_tensor (value , ** kwargs )
126132 elif value .dtype == np .uint64 :
127133 # Check if values fit in int64 range
128134 if np .all (value <= np .iinfo (np .int64 ).max ):
129135 value = value .astype (np .int64 )
130- return torch .from_numpy (value )
136+ if target_dtype == torch .int64 :
137+ return torch .from_numpy (value )
138+ else :
139+ kwargs .setdefault ("dtype" , target_dtype )
140+ return torch .as_tensor (value , ** kwargs )
131141 else :
132142 # Fallback to safe conversion via Python ints
133- kwargs = self .torch_tensor_kwargs .copy ()
134143 kwargs .setdefault ("dtype" , target_dtype )
135144 return torch .tensor (value , ** kwargs )
136145 else :
137146 # Use zero-copy conversion for compatible integer types
138- if value .dtype != np .int64 :
147+ if value .dtype == np .int64 and target_dtype == torch .int64 :
148+ # Perfect match, zero-copy conversion
149+ return torch .from_numpy (value )
150+ else :
139151 # Need dtype conversion, use as_tensor for efficiency
140- kwargs = self .torch_tensor_kwargs .copy ()
141152 kwargs .setdefault ("dtype" , target_dtype )
142153 return torch .as_tensor (value , ** kwargs )
143- else :
144- # Perfect match, zero-copy conversion
145- return torch .from_numpy (value )
146154
147155 # Handle floating point types
148156 elif np .issubdtype (value .dtype , np .floating ):
149- if value .dtype != np .float32 :
150- # Need dtype conversion
151- kwargs = self .torch_tensor_kwargs .copy ()
152- kwargs .setdefault ("dtype" , torch .float32 )
153- return torch .as_tensor (value , ** kwargs )
154- else :
157+ # Check if user specified a dtype, otherwise default to float32
158+ kwargs = self .torch_tensor_kwargs .copy ()
159+ target_dtype = kwargs .get ("dtype" , torch .float32 )
160+
161+ if value .dtype == np .float32 and target_dtype == torch .float32 :
155162 # Zero-copy conversion
156163 return torch .from_numpy (value )
164+ else :
165+ # Need dtype conversion
166+ kwargs .setdefault ("dtype" , target_dtype )
167+ return torch .as_tensor (value , ** kwargs )
157168 else :
158169 # Other numpy types, use zero-copy when possible
159170 return torch .from_numpy (value )
160171
161172 # Handle numpy scalars
162173 elif isinstance (value , np .number ):
174+ kwargs = self .torch_tensor_kwargs .copy ()
163175 if np .issubdtype (value .dtype , np .integer ):
164176 # Use torch.as_tensor for scalar conversion with dtype control
165- kwargs = self .torch_tensor_kwargs .copy ()
166177 kwargs .setdefault ("dtype" , torch .int64 )
167178 return torch .as_tensor (value , ** kwargs )
168179 elif np .issubdtype (value .dtype , np .floating ):
169- kwargs = self .torch_tensor_kwargs .copy ()
170180 kwargs .setdefault ("dtype" , torch .float32 )
171181 return torch .as_tensor (value , ** kwargs )
172182 else :
173- return torch .as_tensor (value , ** self . torch_tensor_kwargs )
183+ return torch .as_tensor (value , ** kwargs )
174184
175185 # Handle Python lists/tuples of numbers efficiently
176186 elif isinstance (value , (list , tuple )):
0 commit comments