1 change: 1 addition & 0 deletions src/optimum/rbln/transformers/modeling_outputs.py
@@ -24,6 +24,7 @@ class RBLNDecoderOnlyOutput(ModelOutput):
logits: torch.FloatTensor = None
generate_idx: torch.Tensor = None
padded_cache_lengths: int = None
hidden_states: Tuple[torch.FloatTensor] = None


@dataclass
@@ -58,6 +58,7 @@ def __init__(
sliding_window_layers: Optional[List[int]] = None,
phases: Optional[List[PhaseType]] = None,
logits_to_keep: Optional[int] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
):
"""
@@ -112,6 +113,7 @@ def __init__(
["prefill", "decode"] if DecoderOnlyModelForCausalLM is used.
logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
output_hidden_states (Optional[bool]): Whether to return the hidden states of all decoder layers (the embedding output followed by each layer's output). Defaults to False.
kwargs: Additional arguments passed to the parent RBLNModelConfig.

Raises:
@@ -232,6 +234,8 @@ def __init__(
if self.logits_to_keep is not None and self.logits_to_keep > 1:
raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")

self.output_hidden_states = output_hidden_states or False

self.decoder_batch_sizes = None
if "decode" in self.phases:
self.decoder_batch_sizes = decoder_batch_sizes
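The new flag is resolved to a plain bool at construction time, so downstream graph-building code can branch on it without None checks. A minimal usage sketch follows; the model class, checkpoint name, and the `rbln_config` keyword plumbing are assumptions for illustration, not taken from this diff.

# Hypothetical usage sketch -- class name, checkpoint, and kwargs are assumptions, not part of this diff.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",                   # illustrative checkpoint
    export=True,                                  # compile for the RBLN NPU
    rbln_config={"output_hidden_states": True},   # assumed to reach this config's __init__
)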
@@ -203,7 +203,7 @@ def forward(self, *args):
rotary_emb,
) = self.prepare_forward_args(*args)

logit = self.model(
logits, all_hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
@@ -215,9 +215,13 @@
global_block_tables=global_block_tables,
local_block_tables=local_block_tables,
lora_int_id=lora_int_id,
output_hidden_states=self.rbln_config.output_hidden_states,
)

return logit
if self.rbln_config.output_hidden_states:
return logits, all_hidden_states
else:
return logits


class DecoderOnlyForCausalLM(nn.Module):
@@ -272,9 +276,10 @@ def forward(
global_block_tables: Optional[torch.Tensor] = None,
local_block_tables: Optional[torch.Tensor] = None,
lora_int_id: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
):
# outputs
hidden_states = self.model(
hidden_states, all_hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
@@ -286,6 +291,7 @@
global_block_tables=global_block_tables,
local_block_tables=local_block_tables,
lora_int_id=lora_int_id,
output_hidden_states=output_hidden_states,
)

if "prefill" in self.phase:
@@ -299,7 +305,7 @@
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping

return logits
return logits, all_hidden_states


class DecoderOnlyModel(nn.Module):
@@ -398,6 +404,7 @@ def forward(
global_block_tables: Optional[torch.Tensor] = None,
local_block_tables: Optional[torch.Tensor] = None,
lora_int_id: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
):
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -460,7 +467,11 @@
if len(self.sliding_window_layers) > 0:
sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)

all_hidden_states = () if output_hidden_states else None
for layer_idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

is_sliding = layer_idx in self.sliding_window_layers
hidden_states = layer(
hidden_states=hidden_states,
@@ -474,7 +485,10 @@
)

hidden_states = self.get_last_layernorm()(hidden_states)
return hidden_states
if output_hidden_states:
all_hidden_states += (hidden_states,)

return hidden_states, all_hidden_states
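The accumulation above follows the Hugging Face `output_hidden_states` convention: the tuple starts with the embedding output and ends with the post-norm output, giving `num_hidden_layers + 1` entries. A self-contained sketch of the same pattern with toy layers (not RBLN code):

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
hidden_states = torch.randn(1, 4, 8)          # stands in for the embedding output

all_hidden_states = ()
for layer in layers:
    all_hidden_states += (hidden_states,)     # state entering the layer
    hidden_states = layer(hidden_states)
all_hidden_states += (hidden_states,)         # final state (post-layernorm in the real model)

assert len(all_hidden_states) == len(layers) + 1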


class DecoderOnlyLayer(nn.Module):
@@ -301,7 +301,7 @@ def decode_forward(

attention_mask = self.dec_attn_mask

logits = super().forward(
outputs = super().forward(
inputs,
cache_position,
block_tables,
@@ -312,7 +312,10 @@
lora_int_ids if self.rbln_config.use_lora else None,
)

return RBLNDecoderOnlyOutput(logits=logits)
if self.rbln_config.output_hidden_states:
return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
else:
return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)

def _prepare_prefill_inputs(
self,
@@ -436,6 +439,7 @@ def prefill_forward(

# Process input in chunks of size `prefill_chunk_size`
output_logits = []
all_hidden_states = [] if self.rbln_config.output_hidden_states else None
for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
s, e = step, step + self.rbln_config.prefill_chunk_size
# Extract the current chunk of inputs, cache positions, position ids, and position embeddings
@@ -468,7 +472,7 @@
query_position = None

# Forward pass for the current chunk
output_logit = super().forward(
outputs = super().forward(
input_chunk,
cache_pos_chunk,
block_tables,
@@ -478,9 +482,23 @@
chunked_attention_mask if self.rbln_config.use_attention_mask else None,
position_ids_chunk,
lora_int_ids if self.rbln_config.use_lora else None,
out=self.out_buffers,
out=None if self.rbln_config.output_hidden_states else self.out_buffers,
)
output_logits.append(output_logit)
if self.rbln_config.output_hidden_states:
output_logits.append(outputs[0])
all_hidden_states.append(tuple(outputs[1:]))
else:
output_logits.append(outputs)

if self.rbln_config.output_hidden_states:
num_hidden_layers = len(all_hidden_states[0]) - 1
concatenated_hidden_states = ()
for l_idx in range(num_hidden_layers + 1):
l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
l_hidden_states = l_hidden_states[:, :query_length, :]
concatenated_hidden_states += (l_hidden_states,)

all_hidden_states = concatenated_hidden_states

# Aggregate output_logits
output_logits = torch.concat(output_logits, dim=-2)
@@ -505,4 +523,6 @@ def prefill_forward(
self.dec_attn_mask[batch_idx].fill_(0)
self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
return RBLNDecoderOnlyOutput(
logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
)
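Each prefill chunk returns its own tuple of per-layer hidden states of chunk length, so the per-layer tensors are concatenated along the sequence dimension and the right-padding of the last chunk is trimmed to `query_length`. A toy, self-contained illustration of that reshaping (shapes are assumptions for clarity):

import torch

prefill_chunk_size, hidden, num_layers = 4, 8, 2
query_length = 6                                  # spans two chunks; the second is padded

chunks = [                                        # one tuple of (1, chunk, hidden) tensors per chunk
    tuple(torch.randn(1, prefill_chunk_size, hidden) for _ in range(num_layers + 1))
    for _ in range(2)
]

concatenated = tuple(
    torch.cat([chunk[l_idx] for chunk in chunks], dim=1)[:, :query_length, :]
    for l_idx in range(num_layers + 1)
)
assert concatenated[0].shape == (1, query_length, hidden)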
@@ -603,6 +603,7 @@ def forward(
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
position_embed: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.FloatTensor]:
inputs = inputs_embeds if inputs_embeds is not None else input_ids
@@ -613,23 +614,52 @@
f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
)
if output_hidden_states != self.rbln_config.output_hidden_states:
raise ValueError(
f"`output_hidden_states` ({output_hidden_states}) does not match the value the model was compiled with "
f"(rbln_config.output_hidden_states={self.rbln_config.output_hidden_states}). "
f"Please recompile the model with the desired setting."
)

all_last_hidden_states = []
all_hidden_states = (
tuple(
torch.zeros(
self.rbln_config.batch_size,
inputs.shape[1],
self.config.hidden_size,
dtype=self.rbln_config.torch_dtype,
)
for _ in range(self.config.num_hidden_layers + 1)
)
if output_hidden_states
else None
)
for b_idx in range(self.rbln_config.batch_size):
query_length = (
attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
)
cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
last_hidden_states = self.prefill_decoder(
output = self.prefill_decoder(
inputs[b_idx : b_idx + 1],
attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
cache_position=cache_position,
batch_idx=b_idx,
).logits
all_last_hidden_states.append(last_hidden_states)
)
all_last_hidden_states.append(output.logits)

if output_hidden_states:
# attention_mask may be None here (see query_length above); compute the scatter indices once per sample.
mask_indices = (
torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
if attention_mask is not None
else torch.arange(query_length, dtype=torch.int64)
)
for l_idx in range(self.config.num_hidden_layers + 1):
all_hidden_states[l_idx][b_idx].index_copy_(
dim=0, index=mask_indices, source=output.hidden_states[l_idx][0]
)

last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)
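The per-sample scatter above can be pictured in isolation: the prefill decoder yields hidden states only for a sample's real tokens, and they are written into a zero-initialized `(batch, seq_len, hidden)` buffer at the positions where `attention_mask` is 1. A minimal sketch with toy values (shapes are assumptions):

import torch

seq_len, hidden = 5, 4
attention_mask = torch.tensor([1, 1, 1, 0, 0])    # 3 real tokens, 2 padding positions
per_sample_states = torch.randn(1, 3, hidden)     # prefill output for the real tokens only

buffer = torch.zeros(1, seq_len, hidden)
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
buffer[0].index_copy_(0, mask_indices, per_sample_states[0])

assert torch.equal(buffer[0, :3], per_sample_states[0])
assert bool(torch.all(buffer[0, 3:] == 0))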


class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
@@ -725,6 +755,7 @@ def forward(
token_type_ids: Optional[torch.Tensor] = None,
lora_int_ids: Optional[torch.Tensor] = None,
return_dict: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.FloatTensor]:
# Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -748,6 +779,15 @@
)
padded_cache_lengths = torch.zeros_like(generate_idx)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
)
if output_hidden_states != self.rbln_config.output_hidden_states:
raise ValueError(
f"`output_hidden_states` ({output_hidden_states}) does not match the value the model was compiled with "
f"(rbln_config.output_hidden_states={self.rbln_config.output_hidden_states}). "
f"Please recompile the model with the desired setting."
)

# Prefill
if cache_position is None:
logits = []
@@ -763,6 +803,14 @@
f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
)

all_hidden_states = (
tuple(
torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
for _ in range(self.config.num_hidden_layers + 1)
)
if self.rbln_config.output_hidden_states
else None
)
for b_idx in range(batch_size):
cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
output = self.prefill_decoder(
@@ -776,6 +824,14 @@
)
padded_cache_lengths[b_idx] += output.padded_cache_lengths
logits.append(output.logits)

if self.rbln_config.output_hidden_states:
# Indices of this sample's non-padded tokens; identical for every layer, so compute once.
mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
for l_idx in range(self.config.num_hidden_layers + 1):
all_hidden_states[l_idx][b_idx].index_copy_(
dim=0, index=mask_indices, source=output.hidden_states[l_idx][0]
)

logits = torch.cat(logits, dim=0)
# Decoder
else:
@@ -796,17 +852,22 @@
f"or `max_length` in the generation config."
)

logits = self.decoders[batch_size](
output = self.decoders[batch_size](
input_ids=input_ids,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids if self.rbln_config.use_position_ids else None,
lora_int_ids=lora_int_ids,
).logits
)
logits = output.logits
all_hidden_states = output.hidden_states

if not return_dict:
return logits, generate_idx, padded_cache_lengths
return logits, generate_idx, padded_cache_lengths, all_hidden_states
else:
return RBLNDecoderOnlyOutput(
logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
logits=logits,
generate_idx=generate_idx,
padded_cache_lengths=padded_cache_lengths,
hidden_states=all_hidden_states,
)
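End to end, callers receive the hidden states through the same output object as the logits. A hypothetical consumption sketch, assuming the `model` compiled in the earlier configuration example and RBLN hardware; names and shapes are illustrative, not verified against this diff:

import torch

input_ids = torch.tensor([[1, 2, 3, 4]])
attention_mask = torch.ones_like(input_ids)

out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

# With output_hidden_states=True at compile time, out.hidden_states is expected to be a tuple of
# (num_hidden_layers + 1) tensors of shape (batch, seq_len, hidden_size); otherwise it is None.
if out.hidden_states is not None:
    print(len(out.hidden_states), out.hidden_states[-1].shape)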