
[BUG] Bug in Llama inference #369

@duynht

Description

  • There is duplicated code at the wrong indentation level in the part of CausalSelfAttention._forward_inference that computes attention_output; as a result, attention_output is currently never computed. The affected region is quoted below, and a sketch of a possible fix follows the quoted code.

    if self.rope_interleaved:
        query_states = self.rotary_embedding(query_states, position_ids=position_ids)
        key_states = self.rotary_embedding(key_states, position_ids=position_ids)
    else:
        cos, sin = self.rotary_embedding(value_states, position_ids)
        query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Compute rotary embeddings
        # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
        old_rotary_embed_end = self.rotary_embedding.end
        # interleaved version.
        if self.rope_interleaved:
            query_states = self.rotary_embedding(query_states, position_ids=position_ids)
            key_states = self.rotary_embedding(key_states, position_ids=position_ids)
        # non interleaved version.
        else:
            cos, sin = self.rotary_embedding(value_states, position_ids)
            query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )

        if "key" not in store:
            # First inference iteration (Prefill)
            # TODO @nouamane: support custom masking
            # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
            # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
            assert ~(
                sequence_mask[:, :-1] & (~sequence_mask[:, 1:])  # True is never followed by False
            ).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"

            # preallocate k_cache, v_cache to self.prefill_kv_len
            k_cache = torch.zeros(
                (
                    batch_size,
                    self.prefill_kv_len,
                    self.n_local_kv_heads,
                    self.d_qk,
                ),
                dtype=query_states.dtype,
                device=query_states.device,
            )
            v_cache = torch.zeros(
                (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
                dtype=query_states.dtype,
                device=query_states.device,
            )

            # Remove pad tokens from key_states and concatenate samples in key_unpad
            # cu_seqlens_k is the cumulative sequence lengths of key_states
            (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
                query_states,
                sequence_mask,
            )
            (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
                key_states, sequence_mask
            )
            (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)

            # NOTE: this scale is for µTransfer,
            # in SP, we use sqrt(1/d_h)
            softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
            output_unpad = flash_attn_varlen_func(
                q=query_unpad,  # (total_q, n_local_q_heads, d_qk)
                k=key_unpad,  # (total_kv, n_local_kv_heads, d_qk)
                v=value_unpad,  # (total_kv, n_local_kv_heads, d_v)
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                dropout_p=0.0,
                softmax_scale=softmax_scale,
                causal=True,  # True in prefill phase, False in subsequent phases
                return_attn_probs=False,
            )  # (total_unpadded, n_local_q_heads, d_v)

            attention_output = bert_padding.pad_input(
                output_unpad, indices_q, batch_size, q_length
            )  # (batch_size, q_length, n_local_q_heads, d_v)

            pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
            pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
        else:
            # Pull pre-computed key/value states
            # Subsequent inference iterations (q_length=1)
            k_cache = store["key"]
            v_cache = store["value"]

            # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
            # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
            if self.rotary_embedding.end > old_rotary_embed_end:
                k_cache = torch.cat(
                    [
                        k_cache,
                        torch.zeros(
                            (
                                batch_size,
                                self.rotary_embedding.end - old_rotary_embed_end,
                                self.n_local_kv_heads,
                                self.d_qk,
                            ),
                            dtype=query_states.dtype,
                            device=query_states.device,
                        ),
                    ],
                    dim=1,
                )
                v_cache = torch.cat(
                    [
                        v_cache,
                        torch.zeros(
                            (
                                batch_size,
                                self.rotary_embedding.end - old_rotary_embed_end,
                                self.n_local_kv_heads,
                                self.d_v,
                            ),
                            dtype=query_states.dtype,
                            device=query_states.device,
                        ),
                    ],
                    dim=1,
                )

            assert (
                k_cache.shape[1] == self.rotary_embedding.end
            ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
            assert (
                v_cache.shape[1] == self.rotary_embedding.end
            ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"

            # [batch_size, seq_length, num_heads, d_qk]
            query_states = query_states.view(
                batch_size, q_length, self.n_local_q_heads, self.d_qk
            )  # [batch_size, q_length, self.n_heads, d_qk]
            kv_length = key_states.shape[1]
            key_states = key_states.view(
                batch_size, kv_length, self.n_local_kv_heads, self.d_qk
            )  # [batch_size, kv_length, self.n_heads, d_qk]
            value_states = value_states.view(
                batch_size, kv_length, self.n_local_kv_heads, self.d_v
            )  # [batch_size, kv_length, self.n_heads, d_v]

            # NOTE: this scale is for µTransfer,
            # in SP, we use sqrt(1/d_h)
            softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
            attention_output = flash_attn_with_kvcache(
                query_states,
                k_cache,
                v_cache,
                key_states,
                value_states,
                rotary_cos=None,
                rotary_sin=None,
                # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0)
                cache_seqlens=position_offsets.contiguous(),
                softmax_scale=softmax_scale,
                causal=True,
                rotary_interleaved=False,  # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention
            )

        store.update(
            {
                "key": k_cache,  # flash-attn has updated with new key_states using cache_seqlens
                "value": v_cache,
                "position_offsets": position_offsets,
            }
        )
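
    A possible direction for a fix (only a sketch, not a tested patch): drop the stray first copy of the rope_interleaved block and de-indent the rest of the method, so that rotary embeddings are applied exactly once and attention_output is assigned on both the prefill and the decode path. Roughly:

    # Apply rotary embeddings exactly once
    old_rotary_embed_end = self.rotary_embedding.end
    if self.rope_interleaved:
        query_states = self.rotary_embedding(query_states, position_ids=position_ids)
        key_states = self.rotary_embedding(key_states, position_ids=position_ids)
    else:
        cos, sin = self.rotary_embedding(value_states, position_ids)
        query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if "key" not in store:
        # Prefill: unpad, run flash_attn_varlen_func, re-pad into attention_output,
        # and fill k_cache / v_cache (same body as in the snippet above)
        ...
    else:
        # Decode: reuse k_cache / v_cache from the store and compute
        # attention_output with flash_attn_with_kvcache (same body as above)
        ...

    store.update({"key": k_cache, "value": v_cache, "position_offsets": position_offsets})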

  • A wrong argument is passed to parametrizator_cls in init_model_randomly when initializing models for testing:

    parametrizator = parametrizator_cls(config=config.model)
