1 change: 1 addition & 0 deletions src/optimum/rbln/transformers/modeling_outputs.py
@@ -24,6 +24,7 @@ class RBLNDecoderOnlyOutput(ModelOutput):
logits: torch.FloatTensor = None
generate_idx: torch.Tensor = None
padded_cache_lengths: int = None
hidden_states: Tuple[torch.FloatTensor] = None


@dataclass
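
The new `hidden_states` field follows the usual HF layout: a tuple with one tensor per decoder layer plus the embedding output (`num_hidden_layers + 1` entries), each shaped [batch, seq_len, hidden_size]. A minimal sketch of consuming it; `model` and `input_ids` are placeholders, not part of this diff:

# Hypothetical caller code for an RBLN decoder-only model compiled with
# output_hidden_states=True; the tuple layout is (embeddings, layer 1, ..., layer N).
outputs = model(input_ids, return_dict=True, output_hidden_states=True)
if outputs.hidden_states is not None:
    assert len(outputs.hidden_states) == model.config.num_hidden_layers + 1
    last_layer_states = outputs.hidden_states[-1]  # [batch, seq_len, hidden_size]
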
@@ -58,7 +58,6 @@ def __init__(
visual: Optional[RBLNModelConfig] = None,
batch_size: Optional[int] = None,
use_inputs_embeds: bool = True,
output_hidden_states: Optional[bool] = False,
**kwargs,
):
super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
@@ -71,4 +70,3 @@ def __init__(
raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")

self.visual = visual
self.output_hidden_states = output_hidden_states
@@ -58,6 +58,7 @@ def __init__(
sliding_window_layers: Optional[List[int]] = None,
phases: Optional[List[PhaseType]] = None,
logits_to_keep: Optional[int] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
):
"""
@@ -112,6 +113,7 @@ def __init__(
["prefill", "decode"] if DecoderOnlyModelForCausalLM is used.
logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
kwargs: Additional arguments passed to the parent RBLNModelConfig.

Raises:
@@ -232,6 +234,8 @@ def __init__(
if self.logits_to_keep is not None and self.logits_to_keep > 1:
raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")

self.output_hidden_states = output_hidden_states or False

self.decoder_batch_sizes = None
if "decode" in self.phases:
self.decoder_batch_sizes = decoder_batch_sizes
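
Because the compiled graph either emits hidden states or it does not, the flag is fixed at configuration time and defaults to False through `output_hidden_states or False`: an unset None and an explicit False behave identically. A stand-alone sketch of that normalization (the real attribute lives on the RBLN decoder-only config class modified above):

def normalize_output_hidden_states(value):
    # None (unset) and False both normalize to False; True stays True.
    return value or False

assert normalize_output_hidden_states(None) is False
assert normalize_output_hidden_states(False) is False
assert normalize_output_hidden_states(True) is True
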
@@ -203,7 +203,7 @@ def forward(self, *args):
rotary_emb,
) = self.prepare_forward_args(*args)

logit = self.model(
logits, all_hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
@@ -215,9 +215,13 @@
global_block_tables=global_block_tables,
local_block_tables=local_block_tables,
lora_int_id=lora_int_id,
output_hidden_states=self.rbln_config.output_hidden_states,
)

return logit
if self.rbln_config.output_hidden_states:
return logits, all_hidden_states
else:
return logits


class DecoderOnlyForCausalLM(nn.Module):
@@ -272,9 +276,10 @@ def forward(
global_block_tables: Optional[torch.Tensor] = None,
local_block_tables: Optional[torch.Tensor] = None,
lora_int_id: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
):
# outputs
hidden_states = self.model(
hidden_states, all_hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
@@ -286,6 +291,7 @@
global_block_tables=global_block_tables,
local_block_tables=local_block_tables,
lora_int_id=lora_int_id,
output_hidden_states=output_hidden_states,
)

if "prefill" in self.phase:
@@ -299,7 +305,7 @@
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping

return logits
return logits, all_hidden_states


class DecoderOnlyModel(nn.Module):
@@ -398,6 +404,7 @@ def forward(
global_block_tables: Optional[torch.Tensor] = None,
local_block_tables: Optional[torch.Tensor] = None,
lora_int_id: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
):
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -460,7 +467,11 @@ def forward(
if len(self.sliding_window_layers) > 0:
sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)

all_hidden_states = () if output_hidden_states else None
for layer_idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

is_sliding = True if layer_idx in self.sliding_window_layers else False
hidden_states = layer(
hidden_states=hidden_states,
@@ -474,7 +485,10 @@
)

hidden_states = self.get_last_layernorm()(hidden_states)
return hidden_states
if output_hidden_states:
all_hidden_states += (hidden_states,)

return hidden_states, all_hidden_states


class DecoderOnlyLayer(nn.Module):
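
The collection pattern added to DecoderOnlyModel.forward (and mirrored in the Gemma3 variant at the end of this diff) records the tensor entering each layer and, last of all, the post-layernorm output, giving `num_hidden_layers + 1` entries. A self-contained toy version of that pattern using plain nn.Linear layers rather than the RBLN classes:

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
final_norm = nn.LayerNorm(8)

def toy_forward(hidden_states, output_hidden_states=False):
    all_hidden_states = () if output_hidden_states else None
    for layer in layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)  # state entering the layer
        hidden_states = layer(hidden_states)
    hidden_states = final_norm(hidden_states)
    if output_hidden_states:
        all_hidden_states += (hidden_states,)      # final, normalized state
    return hidden_states, all_hidden_states

last, states = toy_forward(torch.randn(1, 4, 8), output_hidden_states=True)
assert len(states) == len(layers) + 1 and torch.equal(states[-1], last)
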
@@ -301,7 +301,7 @@ def decode_forward(

attention_mask = self.dec_attn_mask

logits = super().forward(
outputs = super().forward(
inputs,
cache_position,
block_tables,
@@ -312,7 +312,10 @@
lora_int_ids if self.rbln_config.use_lora else None,
)

return RBLNDecoderOnlyOutput(logits=logits)
if self.rbln_config.output_hidden_states:
return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
else:
return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)

def _prepare_prefill_inputs(
self,
@@ -436,6 +439,7 @@ def prefill_forward(

# Process input in chunks of size `prefill_chunk_size`
output_logits = []
all_hidden_states = [] if self.rbln_config.output_hidden_states else None
for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
s, e = step, step + self.rbln_config.prefill_chunk_size
# Extract the current chunk of inputs, cache positions, position ids, and position embeddings
@@ -468,7 +472,7 @@
query_position = None

# Forward pass for the current chunk
output_logit = super().forward(
outputs = super().forward(
input_chunk,
cache_pos_chunk,
block_tables,
@@ -478,9 +482,25 @@
chunked_attention_mask if self.rbln_config.use_attention_mask else None,
position_ids_chunk,
lora_int_ids if self.rbln_config.use_lora else None,
out=self.out_buffers,
out=None
if self.rbln_config.output_hidden_states
else self.out_buffers, # TODO(taehoon): add hidden states output
)
output_logits.append(output_logit)
if self.rbln_config.output_hidden_states:
output_logits.append(outputs[0])
all_hidden_states.append(tuple(outputs[1:]))
else:
output_logits.append(outputs)

if self.rbln_config.output_hidden_states:
num_hidden_layers = len(all_hidden_states[0]) - 1
concatenated_hidden_states = ()
for l_idx in range(num_hidden_layers + 1):
l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
l_hidden_states = l_hidden_states[:, :query_length, :]
concatenated_hidden_states += (l_hidden_states,)

all_hidden_states = concatenated_hidden_states

# Aggregate output_logits
output_logits = torch.concat(output_logits, dim=-2)
@@ -505,4 +525,6 @@ def prefill_forward(
self.dec_attn_mask[batch_idx].fill_(0)
self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
return RBLNDecoderOnlyOutput(
logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
)
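
Prefill runs in fixed-size chunks, so the per-chunk hidden states collected above have to be stitched back together: for every layer index, the chunk tensors are concatenated along the sequence dimension and trimmed to the real query_length, since the final chunk is padded up to prefill_chunk_size. A toy reassembly with made-up shapes:

import torch

chunk_size, hidden, num_entries = 4, 8, 3  # num_entries = num layers + 1
query_length = 6                           # real tokens; padded to 2 * chunk_size
chunks = [
    tuple(torch.randn(1, chunk_size, hidden) for _ in range(num_entries))
    for _ in range(2)                      # two prefill chunks
]

concatenated = ()
for l_idx in range(num_entries):
    l_hidden = torch.cat([c[l_idx] for c in chunks], dim=1)  # [1, 8, hidden]
    concatenated += (l_hidden[:, :query_length, :],)         # drop the padding
assert concatenated[0].shape == (1, query_length, hidden)
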
@@ -603,6 +603,7 @@ def forward(
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
position_embed: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.FloatTensor]:
inputs = inputs_embeds if inputs_embeds is not None else input_ids
@@ -613,23 +614,52 @@
f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
)
if output_hidden_states != self.rbln_config.output_hidden_states:
raise ValueError(
f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
f"Please compile again with the correct argument."
)

all_last_hidden_states = []
all_hidden_states = (
tuple(
torch.zeros(
self.rbln_config.batch_size,
inputs.shape[1],
self.config.hidden_size,
dtype=self.rbln_config.torch_dtype,
)
for _ in range(self.config.num_hidden_layers + 1)
)
if output_hidden_states
else None
)
for b_idx in range(self.rbln_config.batch_size):
query_length = (
attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
)
cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
last_hidden_states = self.prefill_decoder(
outputs = self.prefill_decoder(
inputs[b_idx : b_idx + 1],
attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
cache_position=cache_position,
batch_idx=b_idx,
).logits
all_last_hidden_states.append(last_hidden_states)
)
all_last_hidden_states.append(outputs.logits)

if output_hidden_states:
for l_idx in range(self.config.num_hidden_layers + 1):
mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
all_hidden_states[l_idx][b_idx].index_copy_(
dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
)

last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)
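
The prefill decoder processes one sequence at a time, so its hidden states come back without batch padding; the loop above scatters each sequence's valid positions into zero-initialized [batch, seq_len, hidden_size] buffers using the attention-mask indices. A toy version of that index_copy_ step:

import torch

batch, seq_len, hidden = 2, 5, 4
buffer = torch.zeros(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 0]])

b_idx = 0
valid = int(attention_mask[b_idx].sum())                 # 3 real tokens
per_sequence_states = torch.randn(1, valid, hidden)      # prefill output for row 0
mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
buffer[b_idx].index_copy_(0, mask_indices, per_sequence_states[0])
assert torch.all(buffer[b_idx, valid:] == 0)             # padded tail stays zero
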


class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
@@ -725,6 +755,7 @@ def forward(
token_type_ids: Optional[torch.Tensor] = None,
lora_int_ids: Optional[torch.Tensor] = None,
return_dict: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.FloatTensor]:
# Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -748,6 +779,15 @@
)
padded_cache_lengths = torch.zeros_like(generate_idx)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
)
if output_hidden_states != self.rbln_config.output_hidden_states:
raise ValueError(
f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
f"Please compile again with the correct argument."
)

# Prefill
if cache_position is None:
logits = []
@@ -763,9 +803,17 @@
f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
)

all_hidden_states = (
tuple(
torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
for _ in range(self.config.num_hidden_layers + 1)
)
if self.rbln_config.output_hidden_states
else None
)
for b_idx in range(batch_size):
cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
output = self.prefill_decoder(
outputs = self.prefill_decoder(
input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
@@ -774,8 +822,16 @@
token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
)
padded_cache_lengths[b_idx] += output.padded_cache_lengths
logits.append(output.logits)
padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
logits.append(outputs.logits)

if self.rbln_config.output_hidden_states:
for l_idx in range(self.config.num_hidden_layers + 1):
mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
all_hidden_states[l_idx][b_idx].index_copy_(
dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
)

logits = torch.cat(logits, dim=0)
# Decoder
else:
@@ -796,17 +852,22 @@
f"or `max_length` in the generation config."
)

logits = self.decoders[batch_size](
outputs = self.decoders[batch_size](
input_ids=input_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids if self.rbln_config.use_position_ids else None,
lora_int_ids=lora_int_ids,
).logits
)
logits = outputs.logits
all_hidden_states = outputs.hidden_states

if not return_dict:
return logits, generate_idx, padded_cache_lengths
return logits, generate_idx, padded_cache_lengths, all_hidden_states
else:
return RBLNDecoderOnlyOutput(
logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
logits=logits,
generate_idx=generate_idx,
padded_cache_lengths=padded_cache_lengths,
hidden_states=all_hidden_states,
)
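
Both forwards above accept a per-call output_hidden_states argument, but only as a consistency check: the behaviour is baked into the compiled model, so a request that disagrees with rbln_config.output_hidden_states is rejected rather than silently honoured. A stand-alone sketch of that resolution logic:

def resolve_output_hidden_states(requested, compiled):
    # None falls back to the compiled setting; an explicit mismatch is an error.
    requested = requested if requested is not None else compiled
    if requested != compiled:
        raise ValueError("output_hidden_states is fixed at compile time; compile again to change it.")
    return requested

assert resolve_output_hidden_states(None, True) is True
assert resolve_output_hidden_states(False, False) is False
try:
    resolve_output_hidden_states(True, False)  # asks for states the compiled graph cannot emit
except ValueError:
    pass
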
@@ -64,6 +64,7 @@ def forward(
global_block_tables: Optional[torch.Tensor] = None,
local_block_tables: Optional[torch.Tensor] = None,
lora_int_id: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
):
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -96,7 +97,10 @@

sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)

all_hidden_states = () if output_hidden_states else None
for layer_idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
is_sliding = True if layer_idx in self.sliding_window_layers else False
hidden_states = layer(
hidden_states=hidden_states,
@@ -110,7 +114,9 @@
)

hidden_states = self.get_last_layernorm()(hidden_states)
return hidden_states
if output_hidden_states:
all_hidden_states += (hidden_states,)
return hidden_states, all_hidden_states


class Gemma3DecoderLayer(DecoderOnlyLayer):