diff --git a/src/optimum/rbln/transformers/modeling_outputs.py b/src/optimum/rbln/transformers/modeling_outputs.py
index a92424f81..38fda29ea 100644
--- a/src/optimum/rbln/transformers/modeling_outputs.py
+++ b/src/optimum/rbln/transformers/modeling_outputs.py
@@ -24,6 +24,7 @@ class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
     generate_idx: torch.Tensor = None
     padded_cache_lengths: int = None
+    hidden_states: Tuple[torch.FloatTensor] = None


 @dataclass
diff --git a/src/optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py b/src/optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py
index 9e3d478a1..c979f80a3 100644
--- a/src/optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py
+++ b/src/optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py
@@ -58,7 +58,6 @@ def __init__(
         visual: Optional[RBLNModelConfig] = None,
         batch_size: Optional[int] = None,
         use_inputs_embeds: bool = True,
-        output_hidden_states: Optional[bool] = False,
         **kwargs,
     ):
         super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
@@ -71,4 +70,3 @@ def __init__(
             raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")

         self.visual = visual
-        self.output_hidden_states = output_hidden_states
diff --git a/src/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py b/src/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
index 1bf81d0c2..baadded18 100644
--- a/src/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
+++ b/src/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -58,6 +58,7 @@ def __init__(
         sliding_window_layers: Optional[List[int]] = None,
         phases: Optional[List[PhaseType]] = None,
         logits_to_keep: Optional[int] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -112,6 +113,7 @@ def __init__(
                 ["prefill", "decode"] if DecoderOnlyModelForCausalLM is used.
             logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder
                 will keep all logits. Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
+            output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
             kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
@@ -232,6 +234,8 @@ def __init__(
         if self.logits_to_keep is not None and self.logits_to_keep > 1:
             raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")

+        self.output_hidden_states = output_hidden_states or False
+
         self.decoder_batch_sizes = None
         if "decode" in self.phases:
             self.decoder_batch_sizes = decoder_batch_sizes
diff --git a/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py b/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
index f33cf9165..6d2605346 100644
--- a/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
+++ b/src/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -203,7 +203,7 @@ def forward(self, *args):
             rotary_emb,
         ) = self.prepare_forward_args(*args)

-        logit = self.model(
+        logits, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -215,9 +215,13 @@ def forward(self, *args):
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=self.rbln_config.output_hidden_states,
         )

-        return logit
+        if self.rbln_config.output_hidden_states:
+            return logits, all_hidden_states
+        else:
+            return logits


 class DecoderOnlyForCausalLM(nn.Module):
@@ -272,9 +276,10 @@ def forward(
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # outputs
-        hidden_states = self.model(
+        hidden_states, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -286,6 +291,7 @@ def forward(
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=output_hidden_states,
         )

         if "prefill" in self.phase:
@@ -299,7 +305,7 @@ def forward(
             logits = torch.tanh(logits)
             logits = logits * self.config.final_logit_softcapping

-        return logits
+        return logits, all_hidden_states


 class DecoderOnlyModel(nn.Module):
@@ -398,6 +404,7 @@ def forward(
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -460,7 +467,11 @@ def forward(
         if len(self.sliding_window_layers) > 0:
             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)

+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
             is_sliding = True if layer_idx in self.sliding_window_layers else False
             hidden_states = layer(
                 hidden_states=hidden_states,
@@ -474,7 +485,10 @@ def forward(
             )

         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return hidden_states, all_hidden_states


 class DecoderOnlyLayer(nn.Module):
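The collection pattern added to DecoderOnlyModel.forward (and mirrored in the Gemma3 variant further down) follows the usual Transformers convention: record the input of every layer, then the post-final-layernorm output, so the tuple ends up with num_hidden_layers + 1 entries. A minimal standalone sketch of that loop, using hypothetical toy layers rather than the RBLN classes:

import torch
import torch.nn as nn

# Toy stand-ins for the compiled decoder stack (illustrative only, not the RBLN modules).
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
final_norm = nn.LayerNorm(8)

def run_layers(hidden_states: torch.Tensor, output_hidden_states: bool = False):
    # Collect the input of every layer, then the final post-norm output,
    # yielding num_hidden_layers + 1 tensors when the flag is on.
    all_hidden_states = () if output_hidden_states else None
    for layer in layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        hidden_states = layer(hidden_states)
    hidden_states = final_norm(hidden_states)
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
    return hidden_states, all_hidden_states

last, hidden = run_layers(torch.randn(1, 4, 8), output_hidden_states=True)
assert len(hidden) == len(layers) + 1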
diff --git a/src/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py b/src/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
index 9daa923b6..2a9f68973 100644
--- a/src/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
+++ b/src/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
@@ -301,7 +301,7 @@ def decode_forward(

             attention_mask = self.dec_attn_mask

-        logits = super().forward(
+        outputs = super().forward(
             inputs,
             cache_position,
             block_tables,
@@ -312,7 +312,10 @@ def decode_forward(
             lora_int_ids if self.rbln_config.use_lora else None,
         )

-        return RBLNDecoderOnlyOutput(logits=logits)
+        if self.rbln_config.output_hidden_states:
+            return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+        else:
+            return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)

     def _prepare_prefill_inputs(
         self,
@@ -436,6 +439,7 @@ def prefill_forward(

         # Process input in chunks of size `prefill_chunk_size`
         output_logits = []
+        all_hidden_states = [] if self.rbln_config.output_hidden_states else None
         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             s, e = step, step + self.rbln_config.prefill_chunk_size
             # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
@@ -468,7 +472,7 @@ def prefill_forward(
                 query_position = None

             # Forward pass for the current chunk
-            output_logit = super().forward(
+            outputs = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
@@ -478,9 +482,25 @@ def prefill_forward(
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
                 lora_int_ids if self.rbln_config.use_lora else None,
-                out=self.out_buffers,
+                out=None
+                if self.rbln_config.output_hidden_states
+                else self.out_buffers,  # TODO(taehoon): add hidden states output
             )
-            output_logits.append(output_logit)
+            if self.rbln_config.output_hidden_states:
+                output_logits.append(outputs[0])
+                all_hidden_states.append(tuple(outputs[1:]))
+            else:
+                output_logits.append(outputs)
+
+        if self.rbln_config.output_hidden_states:
+            num_hidden_layers = len(all_hidden_states[0]) - 1
+            concatenated_hidden_states = ()
+            for l_idx in range(num_hidden_layers + 1):
+                l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+                l_hidden_states = l_hidden_states[:, :query_length, :]
+                concatenated_hidden_states += (l_hidden_states,)
+
+            all_hidden_states = concatenated_hidden_states

         # Aggregate output_logits
         output_logits = torch.concat(output_logits, dim=-2)
@@ -505,4 +525,6 @@ def prefill_forward(
                 self.dec_attn_mask[batch_idx].fill_(0)
                 self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-        return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
+        return RBLNDecoderOnlyOutput(
+            logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
+        )
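Because prefill runs in fixed-size chunks, prefill_forward above collects one hidden-state tuple per chunk and then stitches them back together along the sequence axis, trimming the padding introduced by the last chunk. A self-contained sketch of that aggregation, with made-up sizes:

import torch

num_layers, chunk, hidden, query_length = 2, 4, 8, 6  # toy sizes, not real model dims

# One tuple of (num_layers + 1) tensors per prefill chunk; the final chunk is padded past query_length.
per_chunk = [
    tuple(torch.randn(1, chunk, hidden) for _ in range(num_layers + 1))
    for _ in range(2)
]

concatenated = ()
for l_idx in range(num_layers + 1):
    l_hidden = torch.cat([chunk_states[l_idx] for chunk_states in per_chunk], dim=1)
    concatenated += (l_hidden[:, :query_length, :],)  # drop the chunk padding

assert concatenated[0].shape == (1, query_length, hidden)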
diff --git a/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py b/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
index 676214c4a..532496612 100644
--- a/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
+++ b/src/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -603,6 +603,7 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         inputs = inputs_embeds if inputs_embeds is not None else input_ids
@@ -613,23 +614,52 @@ def forward(
                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
             )

+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         all_last_hidden_states = []
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    self.rbln_config.batch_size,
+                    inputs.shape[1],
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.torch_dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
         for b_idx in range(self.rbln_config.batch_size):
             query_length = (
                 attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
             )
             cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
-            last_hidden_states = self.prefill_decoder(
+            outputs = self.prefill_decoder(
                 inputs[b_idx : b_idx + 1],
                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                 position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
                 cache_position=cache_position,
                 batch_idx=b_idx,
-            ).logits
-            all_last_hidden_states.append(last_hidden_states)
+            )
+            all_last_hidden_states.append(outputs.logits)
+
+            if output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+                    all_hidden_states[l_idx][b_idx].index_copy_(
+                        dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
+                    )

         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
-        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)


 class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
@@ -725,6 +755,7 @@ def forward(
         token_type_ids: Optional[torch.Tensor] = None,
         lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -748,6 +779,15 @@ def forward(
             )
             padded_cache_lengths = torch.zeros_like(generate_idx)

+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         # Prefill
         if cache_position is None:
             logits = []
@@ -763,9 +803,17 @@ def forward(
                     f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
                )

+            all_hidden_states = (
+                tuple(
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.output_hidden_states
+                else None
+            )
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-                output = self.prefill_decoder(
+                outputs = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
@@ -774,8 +822,16 @@ def forward(
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                     lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
-                padded_cache_lengths[b_idx] += output.padded_cache_lengths
-                logits.append(output.logits)
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+                        all_hidden_states[l_idx][b_idx].index_copy_(
+                            dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
+                        )
+
             logits = torch.cat(logits, dim=0)

         # Decoder
         else:
@@ -796,17 +852,22 @@ def forward(
                     f"or `max_length` in the generation config."
                 )

-            logits = self.decoders[batch_size](
+            outputs = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
                 lora_int_ids=lora_int_ids,
-            ).logits
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states

         if not return_dict:
-            return logits, generate_idx, padded_cache_lengths
+            return logits, generate_idx, padded_cache_lengths, all_hidden_states
         else:
             return RBLNDecoderOnlyOutput(
-                logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+                logits=logits,
+                generate_idx=generate_idx,
+                padded_cache_lengths=padded_cache_lengths,
+                hidden_states=all_hidden_states,
             )
diff --git a/src/optimum/rbln/transformers/models/gemma3/gemma3_architecture.py b/src/optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
index 73a70e0db..7babfc126 100644
--- a/src/optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
+++ b/src/optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
@@ -64,6 +64,7 @@ def forward(
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -96,7 +97,10 @@ def forward(
             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)

+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
             is_sliding = True if layer_idx in self.sliding_window_layers else False
             hidden_states = layer(
                 hidden_states=hidden_states,
@@ -110,7 +114,9 @@ def forward(
             )

         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        return hidden_states, all_hidden_states


 class Gemma3DecoderLayer(DecoderOnlyLayer):
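On the modeling side, each sample's unpadded hidden states are scattered back into zero-initialized, batch-shaped buffers, using the attention mask to pick the destination rows so that padded positions stay at zero. A minimal sketch of that index_copy_ pattern with toy shapes (single layer only):

import torch

batch_size, seq_len, hidden = 2, 5, 4  # toy dims
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Zero-initialized buffer shaped like the padded batch, as in the prefill loops above.
buffer = torch.zeros(batch_size, seq_len, hidden)

for b_idx in range(batch_size):
    valid = int(attention_mask[b_idx].sum())
    per_sample = torch.randn(1, valid, hidden)  # hidden states for the unpadded tokens only
    mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
    buffer[b_idx].index_copy_(dim=0, index=mask_indices, source=per_sample[0])

assert torch.all(buffer[0, 3:] == 0)  # padded rows remain zero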
diff --git a/src/optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py b/src/optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py
index 67f27fdff..ad67ad092 100644
--- a/src/optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py
+++ b/src/optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py
@@ -106,6 +106,8 @@ def prefill_forward(
         )

         step = 0
+        output_logits = []
+        all_hidden_states = [] if self.rbln_config.output_hidden_states else None
         while step < query_length:
             if self.rbln_config.use_image_prefill:
                 # Check if the prefill chunk is an image prefill
@@ -146,7 +148,7 @@ def prefill_forward(
                 query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)

             if is_image_prefill:
-                logits = self.image_prefill(
+                outputs = self.image_prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -157,7 +159,7 @@ def prefill_forward(
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )
             else:
-                logits = self.prefill(
+                outputs = self.prefill(
                     input_chunk,
                     cache_pos_chunk,
                     block_tables,
@@ -168,14 +170,51 @@ def prefill_forward(
                     lora_int_ids if self.rbln_config.use_lora else None,
                 )

+            if self.rbln_config.output_hidden_states:
+                output_logits.append(outputs[0])
+                all_hidden_states.append(tuple(outputs[1:]))
+            else:
+                output_logits.append(outputs)
+
             padded_cache_lengths += current_padded_cache_lengths
             step += num_processed_tokens

+        if self.rbln_config.output_hidden_states:
+            num_hidden_layers = len(all_hidden_states[0]) - 1
+            concatenated_hidden_states = ()
+            for l_idx in range(num_hidden_layers + 1):
+                l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+                l_hidden_states = l_hidden_states[:, :query_length, :]
+                concatenated_hidden_states += (l_hidden_states,)
+
+            all_hidden_states = concatenated_hidden_states
+
+        # Aggregate output_logits
+        output_logits = torch.concat(output_logits, dim=-2)
+        if self.rbln_config.logits_to_keep > 0:
+            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
+        else:
+            output_logits = output_logits[:, :query_length, :]
+            # index copy for masked output_logits
+            if attention_mask is not None:
+                new_output_logits = torch.full(
+                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                    fill_value=1e-10,
+                    dtype=output_logits.dtype,
+                )
+                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+
+                output_logits = new_output_logits
+
         if not is_external_block_tables:
             self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask

         return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
+            logits=output_logits,
+            padded_cache_lengths=padded_cache_lengths,
+            attention_mask=chunked_attention_mask,
+            hidden_states=all_hidden_states,
         )

     def decode_forward(
@@ -240,6 +279,9 @@ def decode_forward(
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
             attention_mask = attention_mask[: self.batch_size]

-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
+        outputs = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)

-        return RBLNDecoderOnlyOutput(logits=logits)
+        if self.rbln_config.output_hidden_states:
+            return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+        else:
+            return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)
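When hidden states are compiled in, the runtime call returns a flat sequence (logits, hidden_0, ..., hidden_N) instead of a single logits tensor, and the decode paths above repack it into the output dataclass. A small sketch of that unpacking, assuming a plain tuple stands in for the compiled runtime's return value:

from typing import Optional, Tuple, Union
import torch

def unpack(outputs: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
           output_hidden_states: bool) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
    # With the flag off the runtime yields just the logits tensor; with it on,
    # logits come first and the remaining entries are the per-layer hidden states.
    if output_hidden_states:
        return outputs[0], tuple(outputs[1:])
    return outputs, None

logits, hidden = unpack((torch.randn(1, 1, 32), torch.randn(1, 1, 8), torch.randn(1, 1, 8)), True)
assert hidden is not None and len(hidden) == 2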
diff --git a/src/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py b/src/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
index 176ab9989..efe4bb5c3 100644
--- a/src/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -299,28 +299,60 @@ def forward(
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.rbln_config.language_model.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.language_model.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.language_model.output_hidden_states {self.rbln_config.language_model.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         # prefill
         if cache_position is None:
             logits = []
             inputs_embeds = self._preprocess_prefill(input_ids, inputs_embeds, pixel_values)
             batch_size = inputs_embeds.shape[0]

+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        inputs_embeds.shape[1],
+                        self.config.text_config.hidden_size,
+                        dtype=self.rbln_config.torch_dtype,
+                    )
+                    for _ in range(self.config.text_config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.language_model.output_hidden_states
+                else None
+            )
+
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)

-                output = self.language_model.prefill_decoder(
+                outputs = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
-                padded_cache_lengths[b_idx] += output.padded_cache_lengths
-                logits.append(output.logits)
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.language_model.output_hidden_states:
+                    for l_idx in range(self.config.text_config.num_hidden_layers + 1):
+                        mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+                        all_hidden_states[l_idx][b_idx].index_copy_(
+                            dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
+                        )
             logits = torch.cat(logits, dim=0)

         # decoder
         else:
@@ -334,15 +366,20 @@ def forward(
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                )

-            logits = self.language_model.decoders[batch_size](
+            outputs = self.language_model.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
-            ).logits
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states

         return RBLNDecoderOnlyOutput(
-            logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+            logits=logits,
+            generate_idx=generate_idx,
+            padded_cache_lengths=padded_cache_lengths,
+            hidden_states=all_hidden_states,
         )
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 3cef0ec43..8f1a77951 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -120,6 +120,26 @@ class TestQwen2Model(LLMTest.TestLLMWithoutLMHead):
     HF_CONFIG_KWARGS = {"num_hidden_layers": 1, "layer_types": ["full_attention"], "max_position_embeddings": 1024}


+class TestQwen2ForCausalLM_OutputHiddenStates(TestQwen2ForCausalLM):
+    RBLN_CLASS_KWARGS = {"rbln_config": {"output_hidden_states": True}}
+
+    def get_inputs(self):
+        inputs = super().get_inputs()
+        inputs["return_dict_in_generate"] = True
+        inputs["output_hidden_states"] = True
+        return inputs
+
+
+class TestQwen2Model_OutputHiddenStates(TestQwen2Model):
+    RBLN_CLASS_KWARGS = {"rbln_config": {"output_hidden_states": True}}
+
+    def get_inputs(self):
+        inputs = super().get_inputs()
+        inputs["return_dict_in_generate"] = True
+        inputs["output_hidden_states"] = True
+        return inputs
+
+
 class TestQwen3ForCausalLM(LLMTest.TestLLM):
     RBLN_CLASS = RBLNQwen3ForCausalLM
     HF_MODEL_ID = "trl-internal-testing/tiny-Qwen3ForCausalLM"
@@ -643,6 +663,20 @@ def get_inputs(self):
         return inputs


+class TestGemma3ForConditionalGeneration_OutputHiddenStates(TestGemma3ForConditionalGeneration):
+    RBLN_CLASS_KWARGS = {
+        "rbln_config": {
+            "language_model": {"use_inputs_embeds": True, "kvcache_partition_len": 4096, "output_hidden_states": True}
+        }
+    }
+
+    def get_inputs(self):
+        inputs = super().get_inputs()
+        inputs["return_dict_in_generate"] = True
+        inputs["output_hidden_states"] = True
+        return inputs
+
+
 class TestGemma3ForCausalLM(LLMTest.TestLLM):
     RBLN_CLASS = RBLNGemma3ForCausalLM
     HF_MODEL_ID = "google/gemma-3-1b-it"
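Putting it together, the new tests suggest the end-to-end flow: compile with output_hidden_states in rbln_config, then request hidden states again through generate. The sketch below assumes the usual optimum-rbln from_pretrained(..., export=True) entry point and a tiny Qwen2 checkpoint; the model id and class name are illustrative, not taken from this diff.

import torch
from transformers import AutoTokenizer
from optimum.rbln import RBLNQwen2ForCausalLM  # class name assumed, not shown in this diff

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM"  # hypothetical small checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile with hidden-state outputs baked into the graph; the same flag must be used at runtime.
model = RBLNQwen2ForCausalLM.from_pretrained(
    model_id,
    export=True,
    rbln_config={"output_hidden_states": True},
)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=4,
    return_dict_in_generate=True,
    output_hidden_states=True,
)
# One tuple of (num_hidden_layers + 1) tensors per generated step.
print(len(outputs.hidden_states), len(outputs.hidden_states[0]))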