diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 5b1b33ea64..60d846bea5 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -111,17 +111,21 @@ def _create_scale_bmm2_d_tensor(
     """
     if data_dtype == torch.float16:
         # Create int32 buffer on device, write FP16 value to lower 16 bits via view
-        result = torch.zeros(1, dtype=torch.int32, device=device)
-        result.view(torch.float16)[0] = scale_bmm2
-        return result
+        return (
+            torch.full((1,), scale_bmm2, dtype=torch.float16, device=device)
+            .view(torch.uint16)
+            .to(torch.int32)
+        )
     elif data_dtype == torch.bfloat16:
         # Create int32 buffer on device, write BF16 value to lower 16 bits via view
-        result = torch.zeros(1, dtype=torch.int32, device=device)
-        result.view(torch.bfloat16)[0] = scale_bmm2
-        return result
+        return (
+            torch.full((1,), scale_bmm2, dtype=torch.bfloat16, device=device)
+            .view(torch.uint16)
+            .to(torch.int32)
+        )
     else:
         # FP8, INT8, etc. use FP32 accumulation - create FP32 tensor and view as int32
-        return torch.tensor([scale_bmm2], dtype=torch.float32, device=device).view(
+        return torch.full((1,), scale_bmm2, dtype=torch.float32, device=device).view(
             torch.int32
         )
 