
Commit 7bf4dd9

[TRTLLM-10318][feat] Fixing Nemotron sharding: support for sharding buffers (NVIDIA#10319)
Signed-off-by: greg-kwasniewski1 <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Lucas <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Co-authored-by: Lucas <11156568+lucaslie@users.noreply.github.com>
1 parent cef67b4 commit 7bf4dd9

7 files changed: 450 additions & 142 deletions


tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 from ...shim.interface import CachedSequenceInterface
 from ...utils.cuda_mem_tracker import cuda_memory_tracker
 from ...utils.logger import ad_logger
-from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op
+from ...utils.node_utils import extract_weight_name, is_linear_op, is_op
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


@@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
     y2 = y[:, out1:out1+out2]
     """
     # some info we need
-    keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
+    keys_unfused = [extract_weight_name(n) for n in linear_nodes]
     params_unfused = [gm.get_parameter(k) for k in keys_unfused]
     sizes_unfused = [p.size(0) for p in params_unfused]
     key_fused = f"fused_weight_{idx}"
@@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple
     def _insert_fused_quant_gemm(
         self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
     ):
-        keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
+        keys_unfused = [extract_weight_name(n) for n in linear_nodes]
         params_unfused = [gm.get_parameter(k) for k in keys_unfused]
         sizes_unfused = [p.size(0) for p in params_unfused]
         key_fused = f"fused_weight_{idx}"
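Note: both call sites above swap `extract_param_names_from_node(n)[0]` for the new `extract_weight_name(n)` helper from `node_utils.py` (one of the 7 changed files, though its diff is not shown in this excerpt). A minimal sketch of the contract these call sites imply; the body below is an assumption, not the repository's implementation:

```python
from torch.fx import Node


def extract_weight_name(node: Node) -> str:
    """Assumed contract: return the fully-qualified name of a linear node's weight.

    Equivalent to the first element of the old extract_param_names_from_node(node).
    For aten linear-style ops the weight is args[1], a get_attr node whose target
    is the dotted parameter name (e.g. "model.layers.0.mlp.up_proj.weight").
    """
    weight_node = node.args[1]
    assert weight_node.op == "get_attr", "expected a get_attr weight input"
    return weight_node.target
```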

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 12 additions & 10 deletions
@@ -17,7 +17,7 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
-    extract_param_names_from_node,
+    extract_weight_nodes,
     get_quantization_params_from_linear_node,
     is_bmm_op,
     is_linear_op,
@@ -139,13 +139,13 @@ def _insert_quantized_linear(

         The state_dict is also updated to contain the sharded weights.
         """
-        param_name, _ = extract_param_names_from_node(node)
-        original_weight = gm.get_parameter(param_name)
-        new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
-        modname, _, attrname = param_name.rpartition(".")
+        weight_nodes = extract_weight_nodes(node)
+        assert len(weight_nodes.weights) == 1, "Expected exactly one weight node"
+        lin_weight = weight_nodes.weights[0]
+        new_param = nn.Parameter(self.quantize_weight(lin_weight.tensor), requires_grad=False)
+        modname, _, attrname = lin_weight.node_key.rpartition(".")

-        submod = gm.get_submodule(modname)
-        setattr(submod, attrname, new_param)
+        setattr(lin_weight.submod, attrname, new_param)

         # check modelopt quantizers from graph
         if is_quantized_graph:
@@ -171,10 +171,12 @@ def _insert_quantized_linear(
         )
         # Note: canonicalize_graph() will remove input/weight/output quantizer

-        for scale_name, scale in self.default_scales(original_weight.shape).items():
-            submod.register_buffer(scale_name, scale)
+        for scale_name, scale in self.default_scales(lin_weight.tensor.shape).items():
+            lin_weight.submod.register_buffer(scale_name, scale)

-        gm._register_load_state_dict_pre_hook(partial(self.load_hook, weight_name=param_name))
+        gm._register_load_state_dict_pre_hook(
+            partial(self.load_hook, weight_name=lin_weight.node_key)
+        )

         with gm.graph.inserting_before(node):
             scales = {}
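Note: the new code reads `weight_nodes.weights` and each entry's `.tensor`, `.node_key`, and `.submod` (the sharding diff below also iterates `weight_nodes.biases`). A sketch of the container `extract_weight_nodes` appears to return, inferred purely from these call sites; the class and field names beyond the visible attributes are assumptions:

```python
from dataclasses import dataclass, field
from typing import List

import torch
import torch.nn as nn


@dataclass
class WeightEntry:  # hypothetical name
    node_key: str         # fully-qualified name, e.g. "model.layers.0.mlp.up_proj.weight"
    tensor: torch.Tensor  # the underlying parameter (or buffer) tensor
    submod: nn.Module     # owning submodule, so setattr()/register_buffer() apply directly


@dataclass
class WeightNodes:  # hypothetical name
    weights: List[WeightEntry] = field(default_factory=list)
    biases: List[WeightEntry] = field(default_factory=list)
```

Carrying the owning submodule in each entry is what removes the `modname` / `gm.get_submodule(modname)` round-trip the old code needed.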

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 59 additions & 74 deletions
@@ -39,8 +39,8 @@
     LayerSubgraph,
     LayerType,
     bfs,
-    extract_param_names_from_node,
-    extract_weight_node,
+    extract_weight_name,
+    extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
     get_layer_after_linear_node,
@@ -49,7 +49,6 @@
     is_any_moe_op,
     is_any_ssm_op,
     is_op,
-    num_users_of_weight_node,
     shape,
     subgraph,
 )
@@ -268,7 +267,7 @@ class WeightShardingInfo(ShardingTransformInfo):
     min_local_shape: int = 1
     layer_type: LayerType = LayerType.MLP
     # used for TP sharding of fused weights
-    fused_weight_dims: Optional[list] = None
+    fused_weight_dims: Optional[tuple] = None

     def quantization_cb(
         self,
@@ -437,7 +436,7 @@ def shard_load_hook(


 def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, world_size):
-    assert weight_scale.dim() == 1
+    # assert weight_scale.dim() == 1
     weight_shape_original = list(sharded_uint8_weight_shape)
     weight_shape_original[dim] = weight_shape_original[dim] * world_size
     weight_shape_original[-1] *= 2
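Note on the shape arithmetic in `_shard_fp4_weight_scale`: NVFP4 weights are stored packed, two 4-bit values per uint8 byte, so the logical pre-sharding shape is recovered by multiplying the shard dim back by `world_size` and the last dim by 2. A standalone illustration with made-up sizes:

```python
# Illustrative numbers only: a per-rank packed uint8 weight of shape
# [512, 1024], sharded on dim 0 across 4 ranks.
sharded_uint8_weight_shape = [512, 1024]
dim, world_size = 0, 4

weight_shape_original = list(sharded_uint8_weight_shape)
weight_shape_original[dim] = weight_shape_original[dim] * world_size  # unshard -> [2048, 1024]
weight_shape_original[-1] *= 2  # unpack two fp4 values per byte -> [2048, 2048]
```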
@@ -895,13 +894,10 @@ def _load_hook(
     # This is quite a hacky solution. A better solution would be to store extra_state in
     # the state_dict to identify whether the state_dict is sharded or not.
     key = prefix + param_key
-    ad_logger.debug(f"Sharder LOAD hook is called for '{key}'")
     if key not in state_dict:
         return
     p_to_load = state_dict[key]
-
     p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load)
-
     state_dict[key] = p_to_load


@@ -1124,6 +1120,7 @@ def init_process_grid_from_config(
         ShardingDim.EP: {"p": ep_rank, "w": ep_size},
         ShardingDim.TP: {"p": tp_rank, "w": tp_size},
     }
+    ad_logger.info(f"EP + TP sharding process grid: {process_grid}")
     config.process_grid = process_grid
     return process_grid

@@ -1187,10 +1184,6 @@ def split_fused_tensor(
         fused_dims: list = fused_weight_dims,
         d: int = dim,
     ) -> torch.Tensor:
-        # dim_d = t.shape[d]
-        # num_parts = 1
-        # part_size = dim_d // num_parts
-        # fused_dims = [part_size] * num_parts
        return torch.cat(
             [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)],
             dim=d,
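Note: the `split_fused_tensor` closure above implements shard-per-component semantics — the fused tensor is first split into its logical parts along `dim`, each part is sharded independently, and the per-rank slices are concatenated back. Without this, a rank's contiguous slice could straddle component boundaries. A self-contained illustration; `split_tensor` here is a stand-in for the repository's per-rank splitter:

```python
import torch


def split_tensor(w: torch.Tensor, dim: int = 0, rank: int = 0, world_size: int = 2) -> torch.Tensor:
    # Stand-in for the repo's splitter: each rank keeps its contiguous slice.
    local = w.shape[dim] // world_size
    return w.narrow(dim, rank * local, local)


t = torch.arange(12.0).unsqueeze(1)  # fused weight of shape (12, 1): parts of 8 + 4 rows
fused_dims = [8, 4]
rank0_shard = torch.cat(
    [split_tensor(w) for w in torch.split(t, fused_dims, dim=0)],
    dim=0,
)
# rank 0 keeps rows 0..3 of the first part and rows 8..9 of the second
# (shape (6, 1)) rather than a naive rows 0..5 slice of the fused tensor.
```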
@@ -1229,7 +1222,7 @@ def _shard_parameter_node(
     config: ShardingTransformConfig,
     add_dist: bool = False,
     min_local_shape: int = 1,
-    fused_weight_dims: Optional[list] = None,
+    fused_weight_dims: Optional[tuple] = None,
     quantization_cb: Optional[
         Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None]
     ] = None,
@@ -1243,67 +1236,58 @@ def _shard_parameter_node(

     rank, world_size = config.rank, config.world_size
     allreduce_strategy = config.allreduce_strategy.name
-    num_users = num_users_of_weight_node(node)
-    if num_users > 1 or num_users == 0:
-        ad_logger.warning(
-            f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping."
-        )
-        return
-    # get weight and bias key
-    weight_key, bias_key = extract_param_names_from_node(node)
-
-    modname = weight_key.rpartition(".")[0]
-    submod = gm.get_submodule(modname)

     # Shard weight using the unified function (also updates the parameter)
-    original_weight = gm.get_parameter(weight_key)
-    _, weight_new_shape = shard_weight_tensor(
-        gm=gm,
-        weight_tensor=original_weight,
-        param_key=weight_key,
-        dim=dim,
-        rank=rank,
-        world_size=world_size,
-        min_local_shape=min_local_shape,
-        fused_weight_dims=fused_weight_dims,
-    )
-
-    if bias_key is not None and dim == 0:
-        # update bias for dim 0 --> we can handle it like the weight
-        original_bias = gm.get_parameter(bias_key)
-        shard_weight_tensor(
+    weight_nodes = extract_weight_nodes(node)
+    for weight_node in weight_nodes.weights:
+        _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=original_bias,
-            param_key=bias_key,
+            weight_tensor=weight_node.tensor,
+            param_key=weight_node.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
             min_local_shape=min_local_shape,
             fused_weight_dims=fused_weight_dims,
         )
-    elif bias_key is not None and rank != world_size - 1:
-        # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
-        # double counting it. For all other we will delete the bias.
-        args = list(node.args)
-        node_bias = args[2]
-        args[2] = None
-        node.args = tuple(args)
-        gm.graph.erase_node(node_bias)
-        bias_param_name = bias_key.rpartition(".")[-1]
-        setattr(submod, bias_param_name, None)
-        gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key))
-
-    if quantization_cb is not None:
-        quantization_cb(
-            gm=gm,
-            submod=submod,
-            node=node,
-            weight_key=weight_key,
-            weight_new_shape=weight_new_shape,
-            dim=dim,
-            rank=rank,
-            world_size=world_size,
-        )
+        if quantization_cb is not None:
+            quantization_cb(
+                gm=gm,
+                submod=weight_node.submod,
+                node=node,
+                weight_key=weight_node.node_key,
+                weight_new_shape=weight_new_shape,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+            )
+
+    for bias_node in weight_nodes.biases:
+        if dim == 0:
+            # update bias for dim 0 --> we can handle it like the weight
+            shard_weight_tensor(
+                gm=gm,
+                weight_tensor=bias_node.tensor,
+                param_key=bias_node.node_key,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+                min_local_shape=min_local_shape,
+                fused_weight_dims=fused_weight_dims,
+            )
+        elif bias_node is not None and rank != world_size - 1:
+            # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
+            # double counting it. For all other we will delete the bias.
+            args = list(node.args)
+            node_bias = args[2]
+            args[2] = None
+            node.args = tuple(args)
+            gm.graph.erase_node(node_bias)
+            bias_param_name = bias_node.node_key.rpartition(".")[-1]
+            setattr(bias_node.submod, bias_param_name, None)
+            gm._register_load_state_dict_pre_hook(
+                partial(_load_hook_remove, param_key=bias_node.node_key)
+            )

     # # # column shard with no gather: the output is sharded
     if not add_dist:
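Note: this rewrite is the heart of the commit. Instead of bailing out unless the node has exactly one weight parameter (the old `num_users_of_weight_node` guard), `_shard_parameter_node` now iterates every weight and bias entry that `extract_weight_nodes` reports, which is presumably what lets tensors registered as buffers rather than `nn.Parameter`s flow through the same sharding path, per the commit title. The per-tensor split itself reduces to a narrow along the shard dim; a minimal standalone sketch (not the repository's `shard_weight_tensor`):

```python
import torch


def shard_along_dim(t: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Even split along `dim`; each rank keeps its contiguous slice.
    assert t.shape[dim] % world_size == 0, "shard dim must divide evenly across ranks"
    local = t.shape[dim] // world_size
    return t.narrow(dim, rank * local, local).contiguous()
```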
@@ -1633,7 +1617,7 @@ def _process_ssm_sharding(
             config=config,
             dist_op=None,
             min_local_shape=1,
-            fused_weight_dims=fused_weight_dims["in_proj"],
+            fused_weight_dims=tuple(fused_weight_dims["in_proj"]),
             layer_type=LayerType.SSM,
         )
     ):
@@ -1702,7 +1686,7 @@ def _process_ssm_sharding(
     fused_dims = None
     for k, v in fused_weight_dims.items():
         if k in weight_key:
-            fused_dims = v
+            fused_dims = tuple(v)
             break

     # Shard the weight tensor (also updates the parameter in the module)
@@ -1887,7 +1871,7 @@ def _determine_fused_weight_dims(
         ad_logger.warning(
             f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
         )
-        return
+        return None
     chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
     if len(chunk_nodes) > 0:
         assert len(linear_nodes) == 1
@@ -1896,6 +1880,8 @@ def _determine_fused_weight_dims(
         num_chunks = chunk_nodes[0].args[1]
         weight_dim = shape(linear_node)[2]
         fused_weight_dims = [weight_dim // num_chunks] * num_chunks
+    if fused_weight_dims is not None:
+        fused_weight_dims = tuple(fused_weight_dims)
     return fused_weight_dims

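Note: the repeated `tuple(...)` coercions here, together with the `Optional[list] -> Optional[tuple]` annotation change on `WeightShardingInfo` above, are consistent with the transform-info objects being hashed or deduplicated — a list-valued field would make them unhashable. That motivation is an inference from the diff, not stated in it:

```python
fused = (128, 64, 64)
hash(fused)  # fine: tuples are immutable and hashable

try:
    hash([128, 64, 64])
except TypeError as e:
    print(e)  # unhashable type: 'list'
```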
@@ -2046,9 +2032,9 @@ def detect_sharding_from_config(

     for lin_node in linear_nodes:
         # use node's weight name to get the module name
-        module_name = extract_weight_node(lin_node).target
+        weight_name = extract_weight_name(lin_node)

-        if any(attn_name in module_name for attn_name in attn_names):
+        if any(attn_name in weight_name for attn_name in attn_names):
             # find the next attention node and infer the head_dim
             next_attention_node, _ = bfs(
                 lin_node, is_any_attention_op, attr_next="users", include_root=False
@@ -2072,7 +2058,7 @@ def detect_sharding_from_config(
             # Then we escape dots, and finally we replace @ with .*
             pattern_string = pattern_string.replace("*", "@")
             pattern_regex = re.escape(pattern_string).replace("@", ".*")
-            if re.match(pattern_regex, module_name):
+            if re.match(pattern_regex, weight_name):
                 # we have a match. Get the config for this layer
                 config = tp_plan[key]

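Worked example of the wildcard-to-regex conversion used in `detect_sharding_from_config`; the `tp_plan` key and weight name below are illustrative, not from the repository:

```python
import re

pattern_string = "model.layers.*.mlp.up_proj"  # hypothetical tp_plan key
pattern_string = pattern_string.replace("*", "@")  # protect the wildcard from re.escape
pattern_regex = re.escape(pattern_string).replace("@", ".*")
# pattern_regex == r"model\.layers\..*\.mlp\.up_proj"
assert re.match(pattern_regex, "model.layers.7.mlp.up_proj.weight")
```

Since `re.match` anchors only at the start, the switch from matching module names to matching weight names (which carry a trailing `.weight`) keeps existing patterns working.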
@@ -2111,7 +2097,7 @@ def detect_sharding_from_config(
             elif "local" in config:
                 # Check if this applies to shared experts in EP parallelism.
                 # If yes, apply the TP col-row shard.
-                if "shared" in module_name:
+                if "shared" in weight_name:
                     col_row_action = config.replace("local_", "")
                     if col_row_action == "colwise":
                         transform_container.add(
@@ -2235,7 +2221,6 @@ def detect_column_row_shard(
     min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism
     splitting, e.g., the individual heads into smaller shards.
     """
-    # test_moe_variants()
     ad_logger.debug("Before sharding graph: " + str(gm))
     config = transform_container.config
     world_size = config.world_size
@@ -2340,7 +2325,7 @@ def detect_column_row_shard(
     # simple shard remaining linear nodes
     if config.shard_all_unprocessed:
         num_simple_shards += _process_simple_shard(unprocessed_linear_nodes, transform_container)
-    num_column_row_shards += num_ssm_shards
+    num_column_row_shards += num_ssm_shards + num_mla_shards
     num_shards = num_simple_shards + num_column_row_shards
     ad_logger.info(
         f"Heuristics found {num_shards} TP shards. Simple: {num_simple_shards}, "

tensorrt_llm/_torch/auto_deploy/utils/_graph.py

Lines changed: 2 additions & 2 deletions
@@ -354,7 +354,7 @@ def get_input_embeddings(model: nn.Module) -> torch.Tensor:
         op="call_function", target=torch.ops.aten.embedding.default
     )
     for node in found_nodes:
-        embedding_weights.append(get_weight_tensor(gm, node))
+        embedding_weights.append(get_weight_tensor(node))

     if hasattr(model, "get_input_embeddings"):
         embedding_weights.append(model.get_input_embeddings())
@@ -400,7 +400,7 @@ def get_lm_head_node(gm: GraphModule, output_node: Optional[Node] = None) -> Nod
 def get_lm_head_weights(model: nn.Module) -> torch.Tensor:
     gm, output_node = get_output_node(model)
     lm_head_node = get_lm_head_node(gm, output_node)
-    return get_weight_tensor(gm, lm_head_node)
+    return get_weight_tensor(lm_head_node)


 def get_attr_by_name(obj, name):
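Note: both call sites drop the `gm` argument, implying `get_weight_tensor` now recovers the owning `GraphModule` from the node itself (its new definition lives in one of the changed files not shown here). A sketch of the assumed shape of that change; the body is inferred from the call sites, with buffer handling in line with the commit title:

```python
import torch
from torch.fx import GraphModule, Node


def get_weight_tensor(node: Node) -> torch.Tensor:
    # The GraphModule is reachable from the node, so callers no longer pass it in.
    gm: GraphModule = node.graph.owning_module
    # Embedding ops carry the weight as args[0] and linear ops as args[1]; either
    # way it is the first get_attr input of the op node.
    weight_node = next(n for n in node.all_input_nodes if n.op == "get_attr")
    try:
        return gm.get_parameter(weight_node.target)
    except AttributeError:
        # non-parameter tensors (buffers) are resolved too
        return gm.get_buffer(weight_node.target)
```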
