Skip to content

Commit d5580f3

Browse files
authored
[Fix] extra ops are called in _fsdp_state of pytorch2.8 (InternLM#1237)
* [Fix] extra ops are called in _fsdp_state of pytorch2.8 * [fix] remove the default values of class variables in SequenceContext
1 parent 7d8a1c5 commit d5580f3

1 file changed

Lines changed: 17 additions & 16 deletions

File tree

xtuner/v1/data_proto/sequence_context.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from dataclasses import dataclass
32
from typing import cast
43

54
import torch
@@ -9,7 +8,10 @@
98
from .utils import pad_to_multiple_of, split_for_sequence_parallel
109

1110

12-
@dataclass
11+
# Avoid using dataclass decorator here to get rid of extra ops called in pytorch 2.8 and above
12+
# The extra ops are introduced by the function _apply_to_tensors in
13+
# https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/fsdp/_fully_shard/_fsdp_state.py
14+
# This happens because dataclasses.replace is called in _apply_to_tensors, which triggers SequenceContext.__init__
1315
class SequenceContext:
1416
"""Keyword arguments for Flash Attention with Compile.
1517
@@ -29,26 +31,25 @@ class SequenceContext:
2931
cu_seq_lens_k: torch.IntTensor
3032
max_length_q: torch.Tensor
3133
max_length_k: torch.Tensor
32-
num_padding: int = 0
33-
sequence_parallel_mesh: DeviceMesh | None = None
34-
block_table: torch.Tensor | None = None
35-
device: str | torch.device = "cpu" # TODO: 这个地方有点乱,到处是 device
36-
position_ids: torch.LongTensor | None = None
34+
num_padding: int
35+
sequence_parallel_mesh: DeviceMesh | None
36+
block_table: torch.Tensor | None
37+
device: str | torch.device # TODO: 这个地方有点乱,到处是 device
38+
position_ids: torch.LongTensor | None
3739

3840
# Intern-S1
39-
image_flags: torch.LongTensor | None = None
41+
image_flags: torch.LongTensor | None
4042
# Qwen3VL
41-
image_grid_thw: torch.Tensor | None = None
42-
deepstack_visual_embeds: list[torch.Tensor] | None = None
43-
visual_pos_masks: torch.Tensor | None = None
44-
43+
image_grid_thw: torch.Tensor | None
44+
deepstack_visual_embeds: list[torch.Tensor] | None
45+
visual_pos_masks: torch.Tensor | None
4546
# mllm model
46-
pixel_values: torch.FloatTensor | None = None
47-
inputs_embeds: torch.FloatTensor | None = None
48-
num_img_tokens: list[int] | None = None
47+
pixel_values: torch.FloatTensor | None
48+
inputs_embeds: torch.FloatTensor | None
49+
num_img_tokens: list[int] | None
4950

5051
# moe routed_experts
51-
rollout_routed_experts: torch.LongTensor | None = None
52+
rollout_routed_experts: torch.LongTensor | None
5253

5354
def __init__(
5455
self,

0 commit comments

Comments
 (0)