@@ -2066,12 +2066,14 @@ def test_chunked_prefill_hidden_state_prevents_token_bloat(self):
20662066 # 3. Call update_requests
20672067 active_requests_mask = torch .tensor ([1 , 1 ], dtype = torch .int32 , device = 'cuda' )
20682068 new_tokens = torch .tensor ([99 , 199 ], dtype = torch .int32 , device = 'cuda' )
2069- new_spec = torch .tensor ([[100 , 200 ], [101 , 201 ], [102 , 202 ]], dtype = torch .int32 , device = 'cuda' )
2069+ new_spec = torch .tensor (
2070+ [[100 , 200 ], [101 , 201 ], [102 , 202 ]], dtype = torch .int32 , device = 'cuda'
2071+ )
20702072
20712073 ctx .update_requests (
20722074 active_requests_mask = active_requests_mask ,
20732075 new_tokens = new_tokens ,
2074- new_speculative_tokens = new_spec
2076+ new_speculative_tokens = new_spec ,
20752077 )
20762078
20772079 # 4. Verify Hiding Invariants:
@@ -2167,9 +2169,7 @@ def test_chunked_prefill_swap_with_speculative_tokens(self):
21672169
21682170 # 4. Verify that the new_speculative_tokens tensor itself was swapped so that
21692171 # the hidden state perfectly preserves the alignment for subsequent steps.
2170- expected_swapped_spec_tokens = torch .tensor (
2171- [[201 , 101 ], [202 , 102 ]], device = 'cuda'
2172- )
2172+ expected_swapped_spec_tokens = torch .tensor ([[201 , 101 ], [202 , 102 ]], device = 'cuda' )
21732173 assert torch .equal (
21742174 new_speculative_tokens , expected_swapped_spec_tokens
21752175 ), "new_speculative_tokens was not swapped in-place alongside the request metadata!"
0 commit comments