@@ -125,7 +125,7 @@ def _fused_kernel_quantize_into_fp8(
125125 # be written
126126 o_curr_ptr = o_ptr + o_offset
127127 o_scale_ptr = o_curr_ptr .to (tl .pointer_type (SCALE_TL_DTYPE ))
128- o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES ).to (tl .pointer_type (TL_FP8_TYPE ))
128+ o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES ).to (tl .pointer_type (TL_FP8_TYPE )) # type: ignore
129129
130130 # Compute maximum for the current row block by block
131131 col_offsets = tl .arange (0 , BLOCK_SIZE )
@@ -233,7 +233,7 @@ def _fused_kernel_dequantize_from_fp8(
233233 # written
234234 o_curr_ptr = o_ptr + o_offset
235235 o_scale_ptr = o_curr_ptr .to (tl .pointer_type (SCALE_TL_DTYPE ))
236- o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES ).to (tl .pointer_type (TL_FP8_TYPE ))
236+ o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES ).to (tl .pointer_type (TL_FP8_TYPE )) # type: ignore
237237
238238 # Load row scale
239239 i_row_scale = tl .load (o_scale_ptr )
@@ -342,7 +342,7 @@ def _fused_kernel_reduce_fp8(
342342 o_rank_row_ptr = o_ptr + all_reduce_rank * o_size_bytes_per_rank + o_offset
343343 o_rank_scale_ptr = o_rank_row_ptr .to (tl .pointer_type (SCALE_TL_DTYPE ))
344344 o_rank_quant_ptr = (o_rank_row_ptr + SCALE_TL_DTYPE_BYTES ).to (
345- tl .pointer_type (TL_FP8_TYPE )
345+ tl .pointer_type (TL_FP8_TYPE ) # type: ignore
346346 )
347347
348348 col_offsets = tl .arange (0 , BLOCK_SIZE )
@@ -411,7 +411,7 @@ def _fused_kernel_accumulate_block(
411411 # Load row scale and block of quantized row
412412 o_scale_ptr = o_row_ptr .to (tl .pointer_type (tl .float32 ))
413413 o_quant_ptr = (o_row_ptr + SCALE_TL_DTYPE_BYTES ).to (
414- tl .pointer_type (TL_FP8_TYPE )
414+ tl .pointer_type (TL_FP8_TYPE ) # type: ignore
415415 )
416416
417417 o_row_scale = tl .load (o_scale_ptr )
@@ -580,7 +580,7 @@ def fused_quantize_into_fp8(
580580 output ,
581581 output_size // all_reduce_group_size ,
582582 all_reduce_group_size ,
583- BLOCK_SIZE = BLOCK_SIZE_T ,
583+ BLOCK_SIZE = BLOCK_SIZE_T , # type: ignore
584584 TL_FP8_TYPE = _get_fp8_type (),
585585 TL_FP8_MAX = _get_fp8_max (),
586586 )
@@ -630,7 +630,7 @@ def fused_dequantize_from_fp8(
630630 output ,
631631 output_size // all_reduce_group_size ,
632632 all_reduce_group_size ,
633- BLOCK_SIZE = BLOCK_SIZE_T ,
633+ BLOCK_SIZE = BLOCK_SIZE_T , # type: ignore
634634 TL_FP8_TYPE = _get_fp8_type (),
635635 )
636636
@@ -680,7 +680,7 @@ def fused_reduce_fp8(
680680 all_reduce_group_size ,
681681 all_reduce_rank ,
682682 1.0 if reduce_op == ReduceOp .SUM else float (all_reduce_group_size ),
683- BLOCK_SIZE = BLOCK_SIZE_T ,
683+ BLOCK_SIZE = BLOCK_SIZE_T , # type: ignore
684684 TL_FP8_TYPE = _get_fp8_type (),
685685 TL_FP8_MAX = _get_fp8_max (),
686686 )
0 commit comments