@@ -149,8 +149,12 @@ def test_append_paged_kv_cache_int4_matches_quantized_layout(
149149 torch .testing .assert_close (
150150 v_cache .scale [page_indices , :, page_positions , :], expected_v .scale
151151 )
152- gathered_k = flashinfer .int4_dequantize (k_cache )[page_indices , :, page_positions ]
153- gathered_v = flashinfer .int4_dequantize (v_cache )[page_indices , :, page_positions ]
152+ gathered_k = flashinfer .int4_dequantize (k_cache )[
153+ page_indices , :, page_positions
154+ ]
155+ gathered_v = flashinfer .int4_dequantize (v_cache )[
156+ page_indices , :, page_positions
157+ ]
154158
155159 torch .testing .assert_close (
156160 gathered_k ,
@@ -177,11 +181,19 @@ def test_single_decode_with_kv_cache_int4(kv_layout, head_dim, use_tensor_cores)
177181
178182 q = torch .randn (num_qo_heads , head_dim , dtype = torch .float16 , device = device )
179183 if kv_layout == "NHD" :
180- k = torch .randn (kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device )
181- v = torch .randn (kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device )
184+ k = torch .randn (
185+ kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device
186+ )
187+ v = torch .randn (
188+ kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device
189+ )
182190 else :
183- k = torch .randn (num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device )
184- v = torch .randn (num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device )
191+ k = torch .randn (
192+ num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device
193+ )
194+ v = torch .randn (
195+ num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device
196+ )
185197
186198 k_int4 = flashinfer .int4_quantize (k )
187199 v_int4 = flashinfer .int4_quantize (v )
@@ -232,11 +244,19 @@ def test_single_prefill_with_kv_cache_int4(kv_layout, head_dim):
232244
233245 q = torch .randn (qo_len , num_qo_heads , head_dim , dtype = torch .float16 , device = device )
234246 if kv_layout == "NHD" :
235- k = torch .randn (kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device )
236- v = torch .randn (kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device )
247+ k = torch .randn (
248+ kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device
249+ )
250+ v = torch .randn (
251+ kv_len , num_kv_heads , head_dim , dtype = torch .float16 , device = device
252+ )
237253 else :
238- k = torch .randn (num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device )
239- v = torch .randn (num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device )
254+ k = torch .randn (
255+ num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device
256+ )
257+ v = torch .randn (
258+ num_kv_heads , kv_len , head_dim , dtype = torch .float16 , device = device
259+ )
240260
241261 k_int4 = flashinfer .int4_quantize (k )
242262 v_int4 = flashinfer .int4_quantize (v )
@@ -666,9 +686,8 @@ def test_int4_paged_kv_cache_cuda_graph_unsupported():
666686 head_dim = 128
667687 device = "cuda:0"
668688
669- kv_indptr = (
670- torch .arange (0 , batch_size + 1 , device = device , dtype = torch .int32 )
671- * ((kv_len + page_size - 1 ) // page_size )
689+ kv_indptr = torch .arange (0 , batch_size + 1 , device = device , dtype = torch .int32 ) * (
690+ (kv_len + page_size - 1 ) // page_size
672691 )
673692 kv_indices = torch .arange (kv_indptr [- 1 ].item (), device = device , dtype = torch .int32 )
674693 kv_last_page_len = torch .full (
0 commit comments