403 changes: 403 additions & 0 deletions docs/design/load_spec_refactor.md

Large diffs are not rendered by default.

423 changes: 423 additions & 0 deletions tests/utils/test_load_spec.py

Large diffs are not rendered by default.

1,140 changes: 342 additions & 798 deletions xtuner/v1/model/base.py

Large diffs are not rendered by default.
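Although the base.py diff is not rendered, the hunks below show the new interface it introduces: `safetensors_to_params` now takes a single `HFLoadPlan` instead of the loose `param_name`/`start`/`end`/`dim` arguments. A minimal sketch of what that plan object plausibly looks like, inferred purely from how the hunks below use it (`load_plan.name`, plus the four parameters it replaces) — any field beyond those is an assumption:

```python
# Hypothetical reconstruction of xtuner.v1.utils.load_spec.HFLoadPlan; the
# actual definition lives in the unrendered part of this PR.
from dataclasses import dataclass


@dataclass
class HFLoadPlan:
    name: str                  # parameter name, e.g. "...fused_w1w3.weight"
    start: int | None = None   # this rank's slice start along the FSDP shard dim
    end: int | None = None     # slice end (exclusive)
    dim: int | None = None     # concat dim when one param spans several safetensor shards
```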

1 change: 1 addition & 0 deletions xtuner/v1/model/dense/dense.py
@@ -288,6 +288,7 @@ def fully_shard(
         )
         self.set_modules_to_forward_prefetch([self.embed_tokens, self.layers["0"]])  # type: ignore
 
+        self._init_load_spec()
         self._to_empty_meta()
 
         # Make sure it works properly when using fsdp
36 changes: 9 additions & 27 deletions xtuner/v1/model/moe/gpt_oss.py
@@ -12,6 +12,7 @@
 from xtuner.v1.module.decoder_layer.moe_decoder_layer import MoEActFnConfig
 from xtuner.v1.module.rope import RopeScalingConfig
 from xtuner.v1.module.router.greedy import GreedyRouterConfig
+from xtuner.v1.utils.load_spec import HFLoadPlan
 
 from .moe import MoE
 
@@ -44,18 +45,11 @@ def safetensors_to_params(
         self,
         safetensors: list[torch.Tensor],
         local_tensor: torch.Tensor,
-        param_name: str,
-        start: int | None,
-        end: int | None,
-        dim: int | None,
-    ):
-        if len(safetensors) > 1:
-            assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1"
-            loaded_tensor = torch.cat(safetensors, dim=dim)
-        else:
-            loaded_tensor = safetensors[0]
+        load_plan: HFLoadPlan,
+    ) -> None:
+        loaded_tensor = self._cat_safetensors(safetensors, load_plan)
 
-        if "fused_w1w3.weight" in param_name:
+        if "fused_w1w3.weight" in load_plan.name:
             # hf: num_experts, hidden_size, expert_dim * 2
             # xtuner: num_experts * 2 * expert_dim, hidden_size
             num_experts, hidden_size = loaded_tensor.shape[:2]
@@ -64,32 +58,20 @@
             # # num_experts *2 * expert_dim, hidden_size
             loaded_tensor = loaded_tensor.transpose(1, 2).reshape(-1, hidden_size)
 
-        elif "fused_w2.weight" in param_name:
+        elif "fused_w2.weight" in load_plan.name:
             # hf: num_experts, expert_dim, hidden_size
             # xtuner: num_experts * hidden_size, expert_dim
             loaded_tensor = loaded_tensor.transpose(1, 2).flatten(0, 1)
 
-        if "fused_w1w3.bias" in param_name:
+        if "fused_w1w3.bias" in load_plan.name:
             # hf: num_experts, expert_dim * 2
             # xtuner: num_experts, 2 * expert_dim
             num_experts = loaded_tensor.size(0)
             loaded_tensor = loaded_tensor.reshape(num_experts, -1, 2)
             loaded_tensor = loaded_tensor.transpose(1, 2).reshape(num_experts, -1)
 
-        if start is not None and end is not None:
-            start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM])
-            end = min(end, loaded_tensor.shape[self.FSDP_SHARD_DIM])
-            loaded_tensor_slice = loaded_tensor.index_select(
-                dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device)
-            )
-            non_pad_len = end - start
-            local_tensor[:non_pad_len].copy_(loaded_tensor_slice)
-
-            if non_pad_len < local_tensor.shape[self.FSDP_SHARD_DIM]:
-                assert self.config.float8_cfg is not None
-                local_tensor[non_pad_len:].copy_(0.0)  # type: ignore # padded part must be set to 0
-        else:
-            local_tensor.copy_(loaded_tensor)
+        loaded_tensor = self._apply_load_slices(loaded_tensor, load_plan)
+        self._copy_loaded_tensor_to_local(loaded_tensor, local_tensor)
 
     def param_to_safetensor(
         self,
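All three MoE files now delegate to the same trio of helpers; the hunks in qwen3_5_text.py and qwen3vl_text.py below are structurally identical. The helper bodies can be reconstructed almost verbatim from the inline code this hunk deletes, but where they actually live (presumably the refactored base.py, whose diff is not rendered) and their exact signatures are assumptions:

```python
# Sketch only: reconstructed from the deleted inline code above; the host
# class, method placement, and FSDP_SHARD_DIM value are assumptions.
import torch


class _LoadHelpersSketch:
    FSDP_SHARD_DIM = 0  # assumed 0: the deleted code slices local_tensor[:n]

    def _cat_safetensors(self, safetensors: list[torch.Tensor], load_plan: "HFLoadPlan") -> torch.Tensor:
        # A param spread over several safetensor shards is concatenated along plan.dim.
        if len(safetensors) > 1:
            assert load_plan.dim is not None, "dim must not be None when len(safetensors) > 1"
            return torch.cat(safetensors, dim=load_plan.dim)
        return safetensors[0]

    def _apply_load_slices(self, loaded_tensor: torch.Tensor, load_plan: "HFLoadPlan") -> torch.Tensor:
        # No slice bounds: the whole tensor belongs to this rank.
        if load_plan.start is None or load_plan.end is None:
            return loaded_tensor
        # Clamp to the tensor's extent along the shard dim, then take this rank's rows.
        start = min(load_plan.start, loaded_tensor.shape[self.FSDP_SHARD_DIM])
        end = min(load_plan.end, loaded_tensor.shape[self.FSDP_SHARD_DIM])
        index = torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device)
        return loaded_tensor.index_select(dim=self.FSDP_SHARD_DIM, index=index)

    def _copy_loaded_tensor_to_local(self, loaded_tensor: torch.Tensor, local_tensor: torch.Tensor) -> None:
        non_pad_len = loaded_tensor.shape[self.FSDP_SHARD_DIM]
        local_tensor[:non_pad_len].copy_(loaded_tensor)
        if non_pad_len < local_tensor.shape[self.FSDP_SHARD_DIM]:
            # The deleted code asserted float8_cfg is not None here: only float8
            # block padding should make the local shard longer than the data.
            local_tensor[non_pad_len:].copy_(0.0)  # padded tail must be zeroed
```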
1 change: 1 addition & 0 deletions xtuner/v1/model/moe/moe.py
@@ -1052,6 +1052,7 @@ def fully_shard(
             if isinstance(module, nn.Embedding):
                 module.forward = types.MethodType(self.patched_emb_forward, module)  # type: ignore
 
+        self._init_load_spec()
         self._to_empty_meta()
         return self
 
34 changes: 8 additions & 26 deletions xtuner/v1/model/moe/qwen3_5_text.py
@@ -14,6 +14,7 @@
 from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig
 from xtuner.v1.module.rope import RopeScalingConfig
 from xtuner.v1.module.router.greedy import GreedyRouterConfig
+from xtuner.v1.utils.load_spec import HFLoadPlan
 
 from .qwen3vl_text import Qwen3VLTextMoE
 
@@ -126,42 +127,23 @@ def safetensors_to_params(
         self,
         safetensors: list[torch.Tensor],
         local_tensor: torch.Tensor,
-        param_name: str,
-        start: int | None,
-        end: int | None,
-        dim: int | None,
-    ):
-        if len(safetensors) > 1:
-            assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1"
-            loaded_tensor = torch.cat(safetensors, dim=dim)
-        else:
-            loaded_tensor = safetensors[0]
+        load_plan: HFLoadPlan,
+    ) -> None:
+        loaded_tensor = self._cat_safetensors(safetensors, load_plan)
 
-        if "fused_w1w3.weight" in param_name and "mtp" not in param_name:
+        if "fused_w1w3.weight" in load_plan.name and "mtp" not in load_plan.name:
             # hf: num_experts, 2 * expert_dim, hidden_size
             # xtuner: num_experts * 2 * expert_dim, hidden_size
             # num_experts * 2 * expert_dim, hidden_size
             loaded_tensor = loaded_tensor.flatten(0, 1)
 
-        elif "fused_w2.weight" in param_name and "mtp" not in param_name:
+        elif "fused_w2.weight" in load_plan.name and "mtp" not in load_plan.name:
             # hf: num_experts, hidden_size, expert_dim
             # xtuner: num_experts * hidden_size, expert_dim
             loaded_tensor = loaded_tensor.flatten(0, 1)
 
-        if start is not None and end is not None:
-            start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM])
-            end = min(end, loaded_tensor.shape[self.FSDP_SHARD_DIM])
-            loaded_tensor_slice = loaded_tensor.index_select(
-                dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device)
-            )
-            non_pad_len = end - start
-            local_tensor[:non_pad_len].copy_(loaded_tensor_slice)
-
-            if non_pad_len < local_tensor.shape[self.FSDP_SHARD_DIM]:
-                assert self.config.float8_cfg is not None
-                local_tensor[non_pad_len:].copy_(0.0)  # type: ignore # padded part must be set to 0
-        else:
-            local_tensor.copy_(loaded_tensor)
+        loaded_tensor = self._apply_load_slices(loaded_tensor, load_plan)
+        self._copy_loaded_tensor_to_local(loaded_tensor, local_tensor)
 
     def param_to_safetensor(
         self,
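In this file both fused projections reduce to a plain `flatten(0, 1)` because the hf layout already matches xtuner's row order. A toy shape check of the fused_w1w3 transform described in the comments above (the sizes are arbitrary):

```python
# hf layout (num_experts, 2 * expert_dim, hidden_size) flattens directly to
# the xtuner layout (num_experts * 2 * expert_dim, hidden_size).
import torch

num_experts, expert_dim, hidden_size = 4, 3, 8
w = torch.randn(num_experts, 2 * expert_dim, hidden_size)
assert w.flatten(0, 1).shape == (num_experts * 2 * expert_dim, hidden_size)
```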
34 changes: 8 additions & 26 deletions xtuner/v1/model/moe/qwen3vl_text.py
@@ -5,6 +5,7 @@
 
 from xtuner.v1.data_proto import SequenceContext
 from xtuner.v1.utils.activation_offload import async_save_on_cpu
+from xtuner.v1.utils.load_spec import HFLoadPlan
 
 from .moe import MoELossContextDict, MoEModelOutputs
 from .qwen3 import Qwen3MoE, Qwen3MoE30BA3Config, Qwen3MoE235BA22Config
 
@@ -39,44 +40,25 @@ def safetensors_to_params(
         self,
         safetensors: list[torch.Tensor],
         local_tensor: torch.Tensor,
-        param_name: str,
-        start: int | None,
-        end: int | None,
-        dim: int | None,
-    ):
-        if len(safetensors) > 1:
-            assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1"
-            loaded_tensor = torch.cat(safetensors, dim=dim)
-        else:
-            loaded_tensor = safetensors[0]
+        load_plan: HFLoadPlan,
+    ) -> None:
+        loaded_tensor = self._cat_safetensors(safetensors, load_plan)
 
-        if "fused_w1w3.weight" in param_name:
+        if "fused_w1w3.weight" in load_plan.name:
             # hf: num_experts, hidden_size, 2 * expert_dim
             # xtuner: num_experts * 2 * expert_dim, hidden_size
             num_experts, hidden_size = loaded_tensor.shape[:2]
             loaded_tensor = loaded_tensor.transpose(1, 2)  # num_experts, 2 * expert_dim, hidden_size
             # num_experts * 2 * expert_dim, hidden_size
             loaded_tensor = loaded_tensor.reshape(-1, hidden_size)
 
-        elif "fused_w2.weight" in param_name:
+        elif "fused_w2.weight" in load_plan.name:
             # hf: num_experts, expert_dim, hidden_size
             # xtuner: num_experts * hidden_size, expert_dim
             loaded_tensor = loaded_tensor.transpose(1, 2).flatten(0, 1)
 
-        if start is not None and end is not None:
-            start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM])
-            end = min(end, loaded_tensor.shape[self.FSDP_SHARD_DIM])
-            loaded_tensor_slice = loaded_tensor.index_select(
-                dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device)
-            )
-            non_pad_len = end - start
-            local_tensor[:non_pad_len].copy_(loaded_tensor_slice)
-
-            if non_pad_len < local_tensor.shape[self.FSDP_SHARD_DIM]:
-                assert self.config.float8_cfg is not None
-                local_tensor[non_pad_len:].copy_(0.0)  # type: ignore # padded part must be set to 0
-        else:
-            local_tensor.copy_(loaded_tensor)
+        loaded_tensor = self._apply_load_slices(loaded_tensor, load_plan)
+        self._copy_loaded_tensor_to_local(loaded_tensor, local_tensor)
 
     def param_to_safetensor(
         self,
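Unlike qwen3_5, qwen3vl stores fused_w1w3 with hidden_size in the middle, so a transpose must precede the reshape. The same toy check, with arbitrary sizes:

```python
# hf layout (num_experts, hidden_size, 2 * expert_dim): transpose swaps the
# last two dims before collapsing to (num_experts * 2 * expert_dim, hidden_size).
import torch

num_experts, expert_dim, hidden_size = 4, 3, 8
w = torch.randn(num_experts, hidden_size, 2 * expert_dim)
w = w.transpose(1, 2).reshape(-1, hidden_size)
assert w.shape == (num_experts * 2 * expert_dim, hidden_size)
```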