File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -618,7 +618,14 @@ def _forward(
618618 if input_ids is not None:
619619 hidden_states = self.embed_tokens(input_ids)
620620 else:
621- hidden_states = seq_ctx.inputs_embeds
621+ assert seq_ctx.inputs_embeds is not None, "inputs_embeds should not be None when input_ids is None"
622+ # The clone here is mainly for ActivationOffload. The current offload implementation modifies
623+ # the input tensor in-place, causing subsequent accesses to inputs_embeds to get a tensor with
624+ # empty storage and trigger errors. So we clone here to ensure later accesses to inputs_embeds
625+ # won't fail. However, there are two remaining caveats:
626+ # 1. The extra clone may introduce a slight performance overhead.
627+ # 2. hidden_states itself still cannot be reused, as offload will leave it with empty storage.
628+ hidden_states = seq_ctx.inputs_embeds.clone()
622629
623630 # create position embeddings to be shared across the decoder layers
624631 assert position_ids is not None
You can’t perform that action at this time.
0 commit comments