     LayerSubgraph,
     LayerType,
     bfs,
-    extract_param_names_from_node,
-    extract_weight_node,
+    extract_weight_name,
+    extract_weight_nodes,
     filtered_nodes,
     get_all_layer_subgraphs,
     get_layer_after_linear_node,
     is_any_moe_op,
     is_any_ssm_op,
     is_op,
-    num_users_of_weight_node,
     shape,
     subgraph,
 )
@@ -268,7 +267,7 @@ class WeightShardingInfo(ShardingTransformInfo):
     min_local_shape: int = 1
     layer_type: LayerType = LayerType.MLP
     # used for TP sharding of fused weights
-    fused_weight_dims: Optional[list] = None
+    fused_weight_dims: Optional[tuple] = None

     def quantization_cb(
         self,
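Side note on the `Optional[list]` → `Optional[tuple]` change that recurs throughout this diff: a plausible motivation (my assumption, not stated in the PR) is that tuples are immutable and hashable, so transform-info records carrying `fused_weight_dims` can be compared, hashed, and deduplicated safely. A minimal sketch with a hypothetical `Info` class:

```python
# Sketch only: why a tuple-typed field is hash-friendly (Info is hypothetical).
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Info:
    fused_weight_dims: Optional[tuple] = None

print(len({Info((512, 512)), Info((512, 512))}))  # 1 -- deduplicates in a set
# A list-valued field would raise "TypeError: unhashable type: 'list'" here.
```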
@@ -437,7 +436,7 @@ def shard_load_hook(


 def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, world_size):
-    assert weight_scale.dim() == 1
+    # assert weight_scale.dim() == 1
     weight_shape_original = list(sharded_uint8_weight_shape)
     weight_shape_original[dim] = weight_shape_original[dim] * world_size
     weight_shape_original[-1] *= 2
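Context for this hunk: FP4 weights are packed two 4-bit values per uint8 byte, so recovering the logical weight shape means undoing the per-rank shard along `dim` and doubling the last (packed) dimension. A standalone sketch of that shape arithmetic (the helper name is mine, not from the code):

```python
import torch

# Hypothetical helper mirroring the shape math in the hunk above.
def logical_fp4_shape(sharded_uint8_shape, dim, world_size):
    shape = list(sharded_uint8_shape)
    shape[dim] *= world_size  # undo the per-rank shard along `dim`
    shape[-1] *= 2            # undo uint8 packing: two FP4 values per byte
    return torch.Size(shape)

# A (512, 128) uint8 shard on 4 ranks (dim=0) corresponds to a (2048, 256) FP4 weight.
assert logical_fp4_shape((512, 128), dim=0, world_size=4) == torch.Size([2048, 256])
```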
@@ -895,13 +894,10 @@ def _load_hook(
     # This is quite a hacky solution. A better solution would be to store extra_state in
     # the state_dict to identify whether the state_dict is sharded or not.
     key = prefix + param_key
-    ad_logger.debug(f"Sharder LOAD hook is called for '{key}'")
     if key not in state_dict:
         return
     p_to_load = state_dict[key]
-
     p_to_load = p_to_load if param_shape == p_to_load.shape else f_split(p_to_load)
-
     state_dict[key] = p_to_load


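The hook pattern above, in isolation: a load-state-dict pre-hook lets a full (unsharded) checkpoint be loaded into an already-sharded module by splitting mismatched tensors on the fly. A minimal, self-contained sketch (all names are illustrative, not from the PR):

```python
import torch
import torch.nn as nn

class ShardedLinear(nn.Module):
    def __init__(self, out_features=8, in_features=4, rank=0, world_size=2):
        super().__init__()
        self.rank, self.world_size = rank, world_size
        # local shard: output rows split across ranks
        self.weight = nn.Parameter(torch.empty(out_features // world_size, in_features))
        self._register_load_state_dict_pre_hook(self._shard_hook)

    def _shard_hook(self, state_dict, prefix, *args):
        key = prefix + "weight"
        if key not in state_dict:
            return
        p = state_dict[key]
        if p.shape != self.weight.shape:  # full tensor -> take this rank's chunk
            state_dict[key] = torch.chunk(p, self.world_size, dim=0)[self.rank]

m = ShardedLinear()
m.load_state_dict({"weight": torch.randn(8, 4)})  # full checkpoint loads cleanly
```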
@@ -1124,6 +1120,7 @@ def init_process_grid_from_config(
         ShardingDim.EP: {"p": ep_rank, "w": ep_size},
         ShardingDim.TP: {"p": tp_rank, "w": tp_size},
     }
+    ad_logger.info(f"EP + TP sharding process grid: {process_grid}")
     config.process_grid = process_grid
     return process_grid

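For readers following along, the grid being logged maps each sharding dimension to this rank's position `p` and group width `w`. A toy factorization of 8 ranks into EP × TP groups (the row-major rank → (ep, tp) layout here is my assumption, not taken from the diff):

```python
# Hypothetical illustration: world of 8 ranks, ep_size=2 groups of tp_size=4.
world_size, tp_size = 8, 4
ep_size = world_size // tp_size
for rank in range(world_size):
    ep_rank, tp_rank = rank // tp_size, rank % tp_size
    print(f"rank {rank}: ep={ep_rank}/{ep_size}, tp={tp_rank}/{tp_size}")
```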
@@ -1187,10 +1184,6 @@ def split_fused_tensor(
         fused_dims: list = fused_weight_dims,
         d: int = dim,
     ) -> torch.Tensor:
-        # dim_d = t.shape[d]
-        # num_parts = 1
-        # part_size = dim_d // num_parts
-        # fused_dims = [part_size] * num_parts
         return torch.cat(
             [split_tensor(w) for w in torch.split(t, fused_dims, dim=d)],
             dim=d,
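What the retained one-liner does, on a toy tensor: split the fused weight into its logical parts (a `torch.split` by sizes, not equal chunks), shard each part for the local rank, and re-concatenate so the fused layout survives sharding. Sketch with a toy `split_tensor`:

```python
import torch

rank, world_size = 0, 2
fused = torch.arange(6.0).reshape(6, 1)  # two fused parts: 4 rows + 2 rows
fused_dims = (4, 2)

def split_tensor(w):  # toy per-part TP shard on dim 0
    return torch.chunk(w, world_size, dim=0)[rank]

local = torch.cat(
    [split_tensor(w) for w in torch.split(fused, fused_dims, dim=0)], dim=0
)
print(local.squeeze())  # tensor([0., 1., 4.]): rows 0-1 of part one, row 4 of part two
```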
@@ -1229,7 +1222,7 @@ def _shard_parameter_node(
     config: ShardingTransformConfig,
     add_dist: bool = False,
     min_local_shape: int = 1,
-    fused_weight_dims: Optional[list] = None,
+    fused_weight_dims: Optional[tuple] = None,
     quantization_cb: Optional[
         Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None]
     ] = None,
@@ -1243,67 +1236,58 @@ def _shard_parameter_node(

     rank, world_size = config.rank, config.world_size
     allreduce_strategy = config.allreduce_strategy.name
-    num_users = num_users_of_weight_node(node)
-    if num_users > 1 or num_users == 0:
-        ad_logger.warning(
-            f"Weight node {node} has {num_users} users. This is not supported for sharding. Skipping."
-        )
-        return
-    # get weight and bias key
-    weight_key, bias_key = extract_param_names_from_node(node)
-
-    modname = weight_key.rpartition(".")[0]
-    submod = gm.get_submodule(modname)

     # Shard weight using the unified function (also updates the parameter)
-    original_weight = gm.get_parameter(weight_key)
-    _, weight_new_shape = shard_weight_tensor(
-        gm=gm,
-        weight_tensor=original_weight,
-        param_key=weight_key,
-        dim=dim,
-        rank=rank,
-        world_size=world_size,
-        min_local_shape=min_local_shape,
-        fused_weight_dims=fused_weight_dims,
-    )
-
-    if bias_key is not None and dim == 0:
-        # update bias for dim 0 --> we can handle it like the weight
-        original_bias = gm.get_parameter(bias_key)
-        shard_weight_tensor(
+    weight_nodes = extract_weight_nodes(node)
+    for weight_node in weight_nodes.weights:
+        _, weight_new_shape = shard_weight_tensor(
             gm=gm,
-            weight_tensor=original_bias,
-            param_key=bias_key,
+            weight_tensor=weight_node.tensor,
+            param_key=weight_node.node_key,
             dim=dim,
             rank=rank,
             world_size=world_size,
             min_local_shape=min_local_shape,
             fused_weight_dims=fused_weight_dims,
         )
-    elif bias_key is not None and rank != world_size - 1:
-        # update the bias for dim 1 --> in this case only the last rank gets the bias to avoid
-        # double counting it. For all other we will delete the bias.
-        args = list(node.args)
-        node_bias = args[2]
-        args[2] = None
-        node.args = tuple(args)
-        gm.graph.erase_node(node_bias)
-        bias_param_name = bias_key.rpartition(".")[-1]
-        setattr(submod, bias_param_name, None)
-        gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key))
-
-    if quantization_cb is not None:
-        quantization_cb(
-            gm=gm,
-            submod=submod,
-            node=node,
-            weight_key=weight_key,
-            weight_new_shape=weight_new_shape,
-            dim=dim,
-            rank=rank,
-            world_size=world_size,
-        )
+        if quantization_cb is not None:
+            quantization_cb(
+                gm=gm,
+                submod=weight_node.submod,
+                node=node,
+                weight_key=weight_node.node_key,
+                weight_new_shape=weight_new_shape,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+            )
+
+    for bias_node in weight_nodes.biases:
+        if dim == 0:
+            # update bias for dim 0 --> we can handle it like the weight
+            shard_weight_tensor(
+                gm=gm,
+                weight_tensor=bias_node.tensor,
+                param_key=bias_node.node_key,
+                dim=dim,
+                rank=rank,
+                world_size=world_size,
+                min_local_shape=min_local_shape,
+                fused_weight_dims=fused_weight_dims,
+            )
+        elif rank != world_size - 1:
+            # update the bias for dim 1 --> in this case only the last rank gets the bias to
+            # avoid double counting it. For all other ranks we delete the bias.
+            args = list(node.args)
+            node_bias = args[2]
+            args[2] = None
+            node.args = tuple(args)
+            gm.graph.erase_node(node_bias)
+            bias_param_name = bias_node.node_key.rpartition(".")[-1]
+            setattr(bias_node.submod, bias_param_name, None)
+            gm._register_load_state_dict_pre_hook(
+                partial(_load_hook_remove, param_key=bias_node.node_key)
+            )

     # # # column shard with no gather: the output is sharded
     if not add_dist:
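The dim-1 bias rule in the new `for bias_node ...` loop deserves a standalone illustration: with row-parallel (dim-1) sharding, partial outputs are summed by a subsequent all-reduce, so a bias replicated on every rank would be counted `world_size` times. Keeping it only on the last rank keeps the math exact. A toy check (the stacked sum stands in for the all-reduce):

```python
import torch

x, w, b = torch.randn(2, 8), torch.randn(4, 8), torch.randn(4)
world_size = 2
full = x @ w.t() + b

partials = []
for rank in range(world_size):
    x_r = torch.chunk(x, world_size, dim=1)[rank]  # activations split on K
    w_r = torch.chunk(w, world_size, dim=1)[rank]  # weight split on dim 1
    b_r = b if rank == world_size - 1 else 0       # bias only on the last rank
    partials.append(x_r @ w_r.t() + b_r)

reduced = torch.stack(partials).sum(0)             # stand-in for the all-reduce
assert torch.allclose(full, reduced, atol=1e-5)
```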
@@ -1633,7 +1617,7 @@ def _process_ssm_sharding(
             config=config,
             dist_op=None,
             min_local_shape=1,
-            fused_weight_dims=fused_weight_dims["in_proj"],
+            fused_weight_dims=tuple(fused_weight_dims["in_proj"]),
             layer_type=LayerType.SSM,
         )
     ):
@@ -1702,7 +1686,7 @@ def _process_ssm_sharding(
     fused_dims = None
     for k, v in fused_weight_dims.items():
         if k in weight_key:
-            fused_dims = v
+            fused_dims = tuple(v)
             break

     # Shard the weight tensor (also updates the parameter in the module)
@@ -1887,7 +1871,7 @@ def _determine_fused_weight_dims(
         ad_logger.warning(
             f"Fused weight dims {fused_weight_dims} do not sum to weight dim {weight_dim}. Skipping."
         )
-        return
+        return None
     chunk_nodes = list(filtered_nodes(linear_node.users, ops=torch.ops.aten.chunk))
     if len(chunk_nodes) > 0:
         assert len(linear_nodes) == 1
@@ -1896,6 +1880,8 @@ def _determine_fused_weight_dims(
         num_chunks = chunk_nodes[0].args[1]
         weight_dim = shape(linear_node)[2]
         fused_weight_dims = [weight_dim // num_chunks] * num_chunks
+    if fused_weight_dims is not None:
+        fused_weight_dims = tuple(fused_weight_dims)
     return fused_weight_dims


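The chunk-based inference above, reduced to its core: when the fused linear's only consumer is `aten.chunk(t, num_chunks)`, the fused sub-dimensions are equal parts of the output dimension, now returned as a tuple to match the `Optional[tuple]` field type. A toy version (helper name is mine):

```python
# Hypothetical distillation of the inference in the hunk above.
def infer_fused_dims(weight_dim: int, num_chunks: int) -> tuple:
    assert weight_dim % num_chunks == 0
    return tuple([weight_dim // num_chunks] * num_chunks)

print(infer_fused_dims(1024, 2))  # (512, 512), e.g. fused gate/up projections
```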
@@ -2046,9 +2032,9 @@ def detect_sharding_from_config(

     for lin_node in linear_nodes:
         # use node's weight name to get the module name
-        module_name = extract_weight_node(lin_node).target
+        weight_name = extract_weight_name(lin_node)

-        if any(attn_name in module_name for attn_name in attn_names):
+        if any(attn_name in weight_name for attn_name in attn_names):
             # find the next attention node and infer the head_dim
             next_attention_node, _ = bfs(
                 lin_node, is_any_attention_op, attr_next="users", include_root=False
@@ -2072,7 +2058,7 @@ def detect_sharding_from_config(
             # Then we escape dots, and finally we replace @ with .*
             pattern_string = pattern_string.replace("*", "@")
             pattern_regex = re.escape(pattern_string).replace("@", ".*")
-            if re.match(pattern_regex, module_name):
+            if re.match(pattern_regex, weight_name):
                 # we have a match. Get the config for this layer
                 config = tp_plan[key]

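The wildcard-escaping trick in this hunk, runnable in isolation: stash `*` as `@` so `re.escape` can make the dots literal, then expand `@` back to `.*` (the example key is hypothetical, not from the PR):

```python
import re

pattern_string = "model.layers.*.mlp.gate_proj.weight"
pattern_regex = re.escape(pattern_string.replace("*", "@")).replace("@", ".*")
assert re.match(pattern_regex, "model.layers.7.mlp.gate_proj.weight")
# The dots are literal, so lookalike separators do not match:
assert not re.match(pattern_regex, "model.layers_7_mlp.gate_proj.weight")
```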
@@ -2111,7 +2097,7 @@ def detect_sharding_from_config(
         elif "local" in config:
             # Check if this applies to shared experts in EP parallelism.
             # If yes, apply the TP col-row shard.
-            if "shared" in module_name:
+            if "shared" in weight_name:
                 col_row_action = config.replace("local_", "")
                 if col_row_action == "colwise":
                     transform_container.add(
@@ -2235,7 +2221,6 @@ def detect_column_row_shard(
     min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism
     splitting, e.g., the individual heads into smaller shards.
     """
-    # test_moe_variants()
     ad_logger.debug("Before sharding graph: " + str(gm))
     config = transform_container.config
     world_size = config.world_size
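On `min_local_shape` from the docstring: it bounds how small a local shard may get, so TP never splits inside an attention head. One plausible way such a constraint can be enforced (illustrative numbers and clamping logic, not lifted from this file):

```python
# Hypothetical clamp: keep each shard at least min_local_shape rows.
num_heads, head_dim, world_size = 32, 128, 64
total_rows = num_heads * head_dim              # 4096 rows in the fused projection
naive_shard = total_rows // world_size         # 64 rows < head_dim: cuts a head in half
min_local_shape = head_dim
effective_ws = min(world_size, total_rows // min_local_shape)
print(naive_shard, total_rows // effective_ws)  # 64 vs. 128 (one full head per shard)
```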
@@ -2340,7 +2325,7 @@ def detect_column_row_shard(
     # simple shard remaining linear nodes
     if config.shard_all_unprocessed:
         num_simple_shards += _process_simple_shard(unprocessed_linear_nodes, transform_container)
-    num_column_row_shards += num_ssm_shards
+    num_column_row_shards += num_ssm_shards + num_mla_shards
     num_shards = num_simple_shards + num_column_row_shards
     ad_logger.info(
         f"Heuristics found {num_shards} TP shards. Simple: {num_simple_shards}, "