@@ -104,7 +104,6 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
                 cache_position=cache_position,
                 input_block_ids=valid_block_ids,
             )
-        decoder_input_ids = kwargs.pop("input_ids")
         decoder_cache_position = kwargs.pop("cache_position")
         decoder_block_tables = kwargs.pop("block_tables")
         # FIXME Is it ok generate torch.zero tensor for each forward?
@@ -123,7 +122,11 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
                 decoder_cache_position[batch_idx] = 0
                 decoder_attention_mask[batch_idx, 0] = 1
                 self.dec_lengths[batch_idx] = 1
-
+            decoder_input_ids = torch.full(
+                (self.batch_size, 1),
+                self.model.config.decoder_start_token_id,
+                dtype=torch.long,
+            )
             decoder_output = self.model.decoder(
                 decoder_input_ids=decoder_input_ids.contiguous(),
                 decoder_attention_mask=decoder_attention_mask,
@@ -132,6 +135,7 @@ def forward(self, model_input: ModelInputForRBLN, **kwargs) -> torch.Tensor:
             )

         else:
+            decoder_input_ids = kwargs.pop("input_ids")
            # Generate cache_position using dec_lengths
            for batch_idx in valid_block_ids:
                decoder_cache_position[batch_idx] = self.dec_lengths[batch_idx]
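In short, newly scheduled requests no longer take their first decoder token from kwargs["input_ids"]; that pop now happens only in the decode (else) branch, while the prefill branch seeds every sequence with the model's decoder_start_token_id. A minimal standalone sketch of the two cases follows; batch_size and the start-token value are placeholder assumptions, not values from this PR:

# Hypothetical sketch (not part of the diff): how the first decoder step is
# seeded versus how later steps receive their tokens.
import torch

batch_size = 4              # stands in for self.batch_size in the diff
decoder_start_token_id = 0  # stands in for self.model.config.decoder_start_token_id

# Prefill / new request: every sequence starts from the configured start token.
decoder_input_ids = torch.full((batch_size, 1), decoder_start_token_id, dtype=torch.long)

# Decode steps (the `else` branch): the previously sampled tokens arrive via kwargs.
kwargs = {"input_ids": torch.randint(0, 32000, (batch_size, 1), dtype=torch.long)}
decoder_input_ids = kwargs.pop("input_ids")

Presumably, filling the full (batch_size, 1) tensor keeps the static shape the compiled decoder expects regardless of how many slots are newly scheduled; the per-batch reset loop above it only touches the indices that actually start a new sequence.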