11# Copyright (c) OpenMMLab. All rights reserved.
2- from dataclasses import dataclass
32from typing import cast
43
54import torch
98from .utils import pad_to_multiple_of , split_for_sequence_parallel
109
1110
12- @dataclass
11+ # Avoid using dataclass decorator here to get rid of extra ops called in pytorch 2.8 and above
12+ # The extra ops are introduced by the function _apply_to_tensors in
13+ # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/fsdp/_fully_shard/_fsdp_state.py
14+ # because dataclasses.replace is called inside _apply_to_tensors, which triggers SequenceContext.__init__
1315class SequenceContext :
1416 """Keyword arguments for Flash Attention with Compile.
1517
@@ -29,26 +31,25 @@ class SequenceContext:
2931 cu_seq_lens_k : torch .IntTensor
3032 max_length_q : torch .Tensor
3133 max_length_k : torch .Tensor
32- num_padding : int = 0
33- sequence_parallel_mesh : DeviceMesh | None = None
34- block_table : torch .Tensor | None = None
35- device : str | torch .device = "cpu" # TODO: 这个地方有点乱,到处是 device
36- position_ids : torch .LongTensor | None = None
34+ num_padding : int
35+ sequence_parallel_mesh : DeviceMesh | None
36+ block_table : torch .Tensor | None
37+ device : str | torch .device # TODO: 这个地方有点乱,到处是 device
38+ position_ids : torch .LongTensor | None
3739
3840 # Intern-S1
39- image_flags : torch .LongTensor | None = None
41+ image_flags : torch .LongTensor | None
4042 # Qwen3VL
41- image_grid_thw : torch .Tensor | None = None
42- deepstack_visual_embeds : list [torch .Tensor ] | None = None
43- visual_pos_masks : torch .Tensor | None = None
44-
43+ image_grid_thw : torch .Tensor | None
44+ deepstack_visual_embeds : list [torch .Tensor ] | None
45+ visual_pos_masks : torch .Tensor | None
4546 # mllm model
46- pixel_values : torch .FloatTensor | None = None
47- inputs_embeds : torch .FloatTensor | None = None
48- num_img_tokens : list [int ] | None = None
47+ pixel_values : torch .FloatTensor | None
48+ inputs_embeds : torch .FloatTensor | None
49+ num_img_tokens : list [int ] | None
4950
5051 # moe routed_experts
51- rollout_routed_experts : torch .LongTensor | None = None
52+ rollout_routed_experts : torch .LongTensor | None
5253
5354 def __init__ (
5455 self ,
0 commit comments