99from .tensor_details import ragged_tensor as ragged_tensor_details
1010from .tensor_details .layout import BlackwellMXValueLayout , Layout , StridedLayout
1111from .tensor_details .ragged_tensor import RaggedTensorMetadata
12- from .tensor_details .dtype import IntegerType , FloatType , DataType , FP4 , UINT8 , FP8_E4M3FN , FP8_E4M3FNUZ , FP8_E5M2 , FP16 , BF16 , FP32 , FP64
12+ from .tensor_details .dtype import IntegerType , FloatType , DataType
13+ from .tensor_details .dtype import FP4 , UINT8 , FP8_E4M3FN , FP8_E4M3FNUZ , FP8_E5M2 , FP16 , BF16 , FP32 , FP64 , INT16 , INT32 , INT64
1314
1415
1516# storage
@@ -246,6 +247,9 @@ def dtype_to_torch_dtype(dtype: DataType) -> torch.dtype:
246247 FP32 : torch .float32 ,
247248 FP16 : torch .float16 ,
248249 FP64 : torch .float64 ,
250+ INT16 : torch .int16 ,
251+ INT32 : torch .int32 ,
252+ INT64 : torch .int64 ,
249253 }[dtype ]
250254
251255
@@ -262,6 +266,9 @@ def torch_dtype_to_dtype(dtype: torch.dtype) -> DataType:
262266 "bfloat16" : BF16 ,
263267 "float32" : FP32 ,
264268 "float64" : FP64 ,
269+ "int16" : INT16 ,
270+ "int32" : INT32 ,
271+ "int64" : INT64 ,
265272 }
266273 if id in vals :
267274 return vals [id ]
@@ -270,15 +277,13 @@ def torch_dtype_to_dtype(dtype: torch.dtype) -> DataType:
270277 assert False , f"Unknown dtype: { id } "
271278
272279
def empty(shape: tuple[int], dtype: DataType, device: torch.device, layout=None,
          allow_implicit_conversion: bool = False):
    """Allocate an uninitialized tensor of ``shape`` logical elements of ``dtype``.

    Storage is always allocated with a :class:`StridedLayout`; sub-byte dtypes
    (FP4) are packed into ``uint8`` words along the innermost storage
    dimension. A non-strided target ``layout`` is reached by an explicit
    ``convert_layout`` call and therefore requires
    ``allow_implicit_conversion=True``.

    Args:
        shape: logical tensor shape (counted in elements, not storage bytes).
        dtype: element :class:`DataType` (e.g. ``FP4``, ``FP16``, ``INT32``).
        device: torch device for the backing storage.
        layout: optional target layout; ``None`` means plain strided.
        allow_implicit_conversion: permit converting the freshly allocated
            strided tensor to a non-strided ``layout``.

    Returns:
        The wrapped tensor, converted to ``layout`` when a non-strided layout
        was requested and conversion is allowed.

    Raises:
        TypeError: if ``layout`` is neither ``None`` nor a ``StridedLayout``
            and ``allow_implicit_conversion`` is False. (Without this check
            the requested layout would be silently dropped and a strided
            tensor returned instead.)
    """
    # Decide the layout used for the *allocation*; anything non-strided is
    # produced by an explicit conversion at the end.
    if isinstance(layout, StridedLayout):
        initial_layout = layout
    else:
        if layout is not None and not allow_implicit_conversion:
            raise TypeError(
                f"layout {layout!r} is not a StridedLayout; "
                "pass allow_implicit_conversion=True to convert to it"
            )
        initial_layout = StridedLayout()
    storage_shape = list(shape)
    # FP4 has no native torch dtype: its values are stored packed in uint8.
    storage_dtype = torch.uint8 if dtype == FP4 else dtype_to_torch_dtype(dtype)
    # pack sub-byte datatype along last dimension
    order = initial_layout.order(len(storage_shape))
    dim = order[0]
    storage_shape[dim] = storage_shape[dim] // (storage_dtype.itemsize * 8 // dtype.bitwidth)
    # storage strides: contiguous with respect to the layout's dimension order
    # NOTE(review): these three initialization lines fall between diff hunks in
    # the reviewed patch; reconstructed as the standard order-major stride
    # computation implied by the loop body — confirm against the full file.
    strides = [0] * len(storage_shape)
    running = 1
    for d in order:
        strides[d] = running
        running *= storage_shape[d]
    storage = torch.empty_strided(storage_shape, strides, device=device, dtype=storage_dtype)
    ret = wrap_torch_tensor(storage, dtype=dtype, shape=shape, layout=initial_layout)
    assert initial_layout == ret.storage.layout or allow_implicit_conversion
    # Convert only when a distinct, non-strided layout was actually requested;
    # this also guards against calling convert_layout(ret, None) when
    # allow_implicit_conversion=True but no layout was given.
    if allow_implicit_conversion and layout is not None and not isinstance(layout, StridedLayout):
        ret = convert_layout(ret, layout)
    return ret
0 commit comments