
Commit 3724180

Merge branch 'main' into refactor-llm
2 parents: 0f31481 + 0b5708c

File tree (4 files changed: +17 -26 lines)

requirements/runtime.txt
xtuner/model/modules/dispatch/__init__.py
xtuner/model/modules/dispatch/llama.py
xtuner/model/sft.py

requirements/runtime.txt

Lines changed: 4 additions & 1 deletion
@@ -19,5 +19,8 @@ tiktoken
 torch<=2.1.2
 torchvision<=0.16.2
 # Minimum 4.36.0 to support `Cache` data structure used by KV Cache
-transformers>=4.36.0
+# Registering a causal mask in `LlamaModel` is not friendly for very large
+# `max_position_embeddings`. Refer to
+# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
+transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2
 transformers_stream_generator
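The tightened pin keeps transformers new enough for the `Cache` data structure (>=4.36.0) while skipping 4.38.0 through 4.38.2, whose registered causal mask in `LlamaModel` grows with `max_position_embeddings`. As a minimal sketch (not part of this commit), a runtime guard mirroring that constraint could look like the following; it only assumes `packaging` and `importlib.metadata` are available.

# Hypothetical runtime guard mirroring the pin in requirements/runtime.txt.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Same constraint as the requirements file: new enough for `Cache`, but
# excluding the 4.38.0-4.38.2 releases with the registered causal mask.
TRANSFORMERS_SPEC = SpecifierSet('>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2')

installed = Version(version('transformers'))
if installed not in TRANSFORMERS_SPEC:
    raise RuntimeError(
        f'transformers=={installed} is not supported; install a version '
        f'matching "{TRANSFORMERS_SPEC}".')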

xtuner/model/modules/dispatch/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -305,12 +305,12 @@ def dispatch_modules(model, use_varlen_attn=False):
         dispatch_internlm2_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm2_rmsnorm_forward(model)
-        # replace_internlm2_rote(model)
+        replace_internlm2_rote(model)
     elif 'internlm' in model_name:
         dispatch_internlm_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm_rmsnorm_forward(model)
-        # replace_internlm_rote(model)
+        replace_internlm_rote(model)
     elif 'llama' in model_name:
         dispatch_llama_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
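The dispatcher now actually performs the rotary-embedding replacement for InternLM and InternLM2 (`replace_internlm2_rote` / `replace_internlm_rote`) instead of leaving the calls commented out. As a rough illustration only, a `replace_*_rote`-style helper generally walks the model and swaps each rotary-embedding submodule for a drop-in variant; `PatchedRotaryEmbedding`, the matched class name, and the copied attributes below are assumptions, not xtuner's actual implementation.

# Rough sketch of a rotary-embedding replacement pass (hypothetical names).
import torch.nn as nn


class PatchedRotaryEmbedding(nn.Module):
    """Hypothetical drop-in rotary embedding."""

    def __init__(self, dim, max_position_embeddings, base):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    def forward(self, x, seq_len=None):
        # A real implementation would compute the cos/sin tables here.
        raise NotImplementedError


def replace_rote(model, rote_cls_name='InternLM2RotaryEmbedding'):
    """Swap every submodule whose class name matches `rote_cls_name`."""
    # Collect targets first so the module tree is not mutated mid-walk.
    targets = [(parent, name, child)
               for parent in model.modules()
               for name, child in parent.named_children()
               if type(child).__name__ == rote_cls_name]
    for parent, name, child in targets:
        setattr(parent, name,
                PatchedRotaryEmbedding(child.dim,
                                       child.max_position_embeddings,
                                       child.base))
    return model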

xtuner/model/modules/dispatch/llama.py

Lines changed: 10 additions & 15 deletions
@@ -234,26 +234,18 @@ def llama_attn_forward_legacy(
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
            Optional[Tuple[torch.Tensor]]]:
-    # LlamaFlashAttention2 attention does not support output_attentions
+    # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331  # noqa:E501
     if 'padding_mask' in kwargs:
         warnings.warn(
-            'Passing `padding_mask` is deprecated and will be removed in v4.37'
-            ' Please make sure use `attention_mask` instead.`')
-
-        # overwrite attention_mask with padding_mask
-        attention_mask = kwargs.pop('padding_mask')
-
-    output_attentions = False
+            'Passing `padding_mask` is deprecated and will be removed in '
+            'v4.37. Please make sure use `attention_mask` instead.`')

     bsz, q_len, _ = hidden_states.size()

     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)

-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
     query_states = query_states.view(bsz, q_len, self.num_heads,
                                      self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -263,6 +255,13 @@ def llama_attn_forward_legacy(

     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                'The cache structure has changed since version v4.36. '
+                f'If you are using {self.__class__.__name__} '
+                'for auto-regressive decoding with k/v caching, '
+                'please make sure to initialize the attention class '
+                'with a layer index.')
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                        self.layer_idx)
     assert position_ids is not None
@@ -282,10 +281,6 @@ def llama_attn_forward_legacy(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

-    # repeat kv for sequence parallel
-    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
-    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
     assert SUPPORT_FLASH2
     query_states = query_states.transpose(1, 2)
     key_states = key_states.transpose(1, 2)
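The legacy Llama attention forward now follows the transformers >=4.36 `Cache` API: it requires the attention layer to carry a `layer_idx`, asks the cache for its usable length, and drops the duplicated `repeat_kv_bshd` calls. The sketch below exercises that cache API on its own (not xtuner's forward), assuming a transformers version where `DynamicCache` provides `update` and `get_usable_length`.

# Minimal, self-contained sketch of the post-4.36 Cache API (assumed env:
# transformers>=4.36 with `DynamicCache` available).
import torch
from transformers.cache_utils import DynamicCache

bsz, num_heads, head_dim = 1, 8, 64
layer_idx = 0  # the index the attention layer must be initialized with
cache = DynamicCache()

for step_len in (16, 1, 1):  # one prefill step followed by two decode steps
    key_states = torch.randn(bsz, num_heads, step_len, head_dim)
    value_states = torch.randn(bsz, num_heads, step_len, head_dim)

    # Same call as in the diff: how many cached positions this layer can
    # reuse, added to the length of the incoming chunk.
    kv_seq_len = step_len + cache.get_usable_length(step_len, layer_idx)

    # Append the new keys/values; the cache hands back the concatenated
    # tensors covering all positions seen so far.
    key_states, value_states = cache.update(key_states, value_states,
                                            layer_idx)
    assert key_states.shape[-2] == kv_seq_len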

xtuner/model/sft.py

Lines changed: 1 addition & 8 deletions
@@ -206,13 +206,6 @@ def _build_from_cfg_or_module(self, cfg_or_mod):
             return cfg_or_mod
         elif isinstance(cfg_or_mod, dict):
             traverse_dict(cfg_or_mod)
-            if SUPPORT_FLASH2:
-                cfg_or_mod.torch_dtype = torch.bfloat16 \
-                    if torch.cuda.is_bf16_supported() else torch.float16
-                cfg_or_mod.attn_implementation = 'flash_attention_2'
-            if max_position_embeddings is not None:
-                cfg_or_mod = self._prepare_for_long_context_training(
-                    cfg_or_mod, max_position_embeddings)
             return BUILDER.build(cfg_or_mod)
         else:
             raise NotImplementedError
@@ -265,4 +258,4 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)
         except AttributeError:
-            return getattr(self.llm, name)
+            return getattr(self.llm, name)
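The `sft.py` hunks drop the FlashAttention/dtype and long-context handling from `_build_from_cfg_or_module` and keep the `__getattr__` fallback that proxies unknown attributes to the wrapped LLM. The stripped-down wrapper below illustrates only that delegation pattern under an assumed name (`LLMWrapper`); it is not the actual class in xtuner/model/sft.py.

# Illustration of the __getattr__ delegation pattern (hypothetical wrapper).
import torch.nn as nn


class LLMWrapper(nn.Module):
    def __init__(self, llm: nn.Module):
        super().__init__()
        self.llm = llm

    def __getattr__(self, name: str):
        try:
            # nn.Module keeps submodules/parameters/buffers in internal
            # dicts, so its own __getattr__ must be consulted first.
            return super().__getattr__(name)
        except AttributeError:
            # Anything the wrapper does not define is proxied to the
            # wrapped model.
            return getattr(self.llm, name)

With this pattern, calls such as `wrapper.generate(...)` or lookups such as `wrapper.config` resolve against the inner model unless the wrapper defines them itself.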
