@@ -2581,7 +2581,7 @@ def mock_mtp_forward(*args, **kwargs):
25812581 base_logits [:, :, 0 ] = 100.0 # High probability for token 0
25822582
25832583 # Cache hidden states for serial MTP computation
2584- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
2584+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
25852585 tokens .size (1 ), 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
25862586 )
25872587 if test_config .materialize_only_last_token_logits :
@@ -2720,7 +2720,7 @@ def mock_deterministic_forward(*args, **kwargs):
27202720 base_logits .scatter_ (2 , next_toks .unsqueeze (- 1 ), 100.0 )
27212721
27222722 # Cache hidden states for serial MTP computation
2723- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
2723+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
27242724 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
27252725 )
27262726 if test_config .materialize_only_last_token_logits :
@@ -2815,7 +2815,7 @@ def mock_deterministic_forward(*args, **kwargs):
28152815 base_logits .scatter_ (2 , next_toks .unsqueeze (- 1 ), 100.0 )
28162816
28172817 # Cache hidden states for serial MTP computation
2818- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
2818+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
28192819 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
28202820 )
28212821 if test_config .materialize_only_last_token_logits :
@@ -2911,7 +2911,7 @@ def mock_deterministic_forward(*args, **kwargs):
29112911 base_logits .scatter_ (2 , next_toks .unsqueeze (- 1 ), 100.0 )
29122912
29132913 # Cache hidden states for serial MTP computation
2914- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
2914+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
29152915 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
29162916 )
29172917 if test_config .materialize_only_last_token_logits :
@@ -3187,7 +3187,7 @@ def mock_mtp_forward(*args, **kwargs):
31873187 next_toks = (tokens + 1 ).clamp (max = test_config .vocab_size - 1 )
31883188 base_logits .scatter_ (2 , next_toks .unsqueeze (- 1 ), 100.0 )
31893189
3190- model . _decoder_hidden_states_cache = torch .zeros (
3190+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
31913191 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
31923192 )
31933193 if test_config .materialize_only_last_token_logits :
@@ -3308,7 +3308,7 @@ def mock_safe_forward(*args, **kwargs):
33083308 base_logits [:, :, 0 ] = 100.0 # Force model to deterministically pick token 0
33093309
33103310 # Cache hidden states for serial MTP computation
3311- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
3311+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
33123312 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
33133313 )
33143314 if test_config .materialize_only_last_token_logits :
@@ -3526,7 +3526,7 @@ def mock_mtp_forward(*args, **kwargs):
35263526 dtype = torch .bfloat16 ,
35273527 )
35283528 base_logits [:, :, 0 ] = 100.0
3529- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
3529+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
35303530 tokens .size (1 ), 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
35313531 )
35323532 return base_logits
@@ -3669,7 +3669,7 @@ def mock_deterministic_forward(*args, **kwargs):
36693669 )
36703670 # Make token 0 very likely so speculative tokens get accepted.
36713671 base_logits [:, :, 0 ] = 100.0
3672- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
3672+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
36733673 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
36743674 )
36753675 return base_logits
@@ -3791,7 +3791,7 @@ def mock_deterministic_forward(*args, **kwargs):
37913791 b , s , test_config .vocab_size , device = tokens .device , dtype = torch .bfloat16
37923792 )
37933793 base_logits [:, :, 0 ] = 100.0
3794- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
3794+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
37953795 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
37963796 )
37973797 return base_logits
@@ -3923,7 +3923,7 @@ def mock_deterministic_forward(*args, **kwargs):
39233923 b , s , test_config .vocab_size , device = tokens .device , dtype = torch .bfloat16
39243924 )
39253925 base_logits [:, :, 0 ] = 100.0
3926- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
3926+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
39273927 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
39283928 )
39293929 return base_logits
@@ -4178,7 +4178,7 @@ def mock_deterministic_forward(*args, **kwargs):
41784178 )
41794179 next_toks = (tokens + 1 ).clamp (max = test_config .vocab_size - 1 )
41804180 base_logits .scatter_ (2 , next_toks .unsqueeze (- 1 ), 100.0 )
4181- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
4181+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
41824182 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
41834183 )
41844184 return base_logits
@@ -4276,7 +4276,7 @@ def mock_deterministic_forward(*args, **kwargs):
42764276 )
42774277 next_toks = (tokens + 1 ).clamp (max = test_config .vocab_size - 1 )
42784278 base_logits .scatter_ (2 , next_toks .unsqueeze (- 1 ), 100.0 )
4279- unwrapped_model . _decoder_hidden_states_cache = torch .zeros (
4279+ env . engine . context . mtp_decoder_hidden_states = torch .zeros (
42804280 s , 1 , hidden_size , device = tokens .device , dtype = torch .bfloat16
42814281 )
42824282 return base_logits
0 commit comments