@@ -264,10 +264,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         # Quantization-specific variables (INT4, INT8, etc.)
         int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default"))
+        self.int4_block_size = extra_options.get("int4_block_size", 32)
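+        # Saved on self (not only inside quant_attrs) so make_embedding can
+        # reference the block size when reusing the quantized lm_head weight.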
         self.quant_attrs = {
             "int4": {
                 "accuracy_level": int(extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)),
-                "block_size": int(extra_options.get("int4_block_size", 32)),
+                "block_size": int(self.int4_block_size),
                 "is_symmetric": extra_options.get("int4_is_symmetric", True),
                 "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul",)),
                 "nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []),
@@ -280,6 +281,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             self.quant_attrs["config"] = config.quantization_config
             self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False
 
+        self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False
+        self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
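+        # With tied embeddings, the embedding matrix and the LM head share one
+        # weight, so make_embedding can gather directly from the quantized
+        # lm_head weight instead of emitting a separate fp16/fp32 initializer.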
+        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
+        if not self.int8_lm_head:
+            # matmul_nbits_quantizer.py uses a different naming scheme for the default quantization algorithm, so the lm_head.MatMul.weight_Q{}G{} name would not match.
+            self.int4_tied_embeddings = False
+
     def to_str_dtype(self, dtype: ir.DataType) -> str:
         return dtype.name
 
@@ -1069,13 +1077,28 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):
         self.make_add_bias(add, name, root_input, **kwargs)
 
     def make_embedding(self, embedding):
-        weight = "model.embed_tokens.weight"
-        self.make_initializer(embedding, weight, to=self.io_dtype)
-
         basename = "/model/embed_tokens"
-        gather_name = f"{basename}/Gather"
-        gather_output = f"{gather_name}/output_0"
-        self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
+        if self.int4_tied_embeddings:
+            gather_name = f"{basename}/GatherBlockQuantized"
+            gather_output = f"{gather_name}/output_0"
+
+            weight_reshape_name = f"{basename}/Reshape"
+            bits = 8 if self.int8_lm_head else 4
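+            # The tied weight is looked up by the name the quantizer gave the
+            # lm_head initializer, which encodes the bit width (Q{bits}) and the
+            # quantization block size (G{block_size}).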
+            weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]"]
+            weight_reshape_output = f"{weight_reshape_name}/output_0"
+            # quantized weight dtype is uint8, see here
+            # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73
+            self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=['vocab_size', 'hidden_size'])
+
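+            # GatherBlockQuantized (com.microsoft contrib op) looks up rows of the
+            # quantized weight and dequantizes them on the fly using the per-block
+            # scales and zero points.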
+            self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size))
+        else:
+            weight = "model.embed_tokens.weight"
+            self.make_initializer(embedding, weight, to=self.io_dtype)
+
+            gather_name = f"{basename}/Gather"
+            gather_output = f"{gather_name}/output_0"
+            self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
+
         self.make_value(gather_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
 
         if self.embed_attrs["scale"] != 1:
@@ -4172,7 +4195,7 @@ def check_extra_options(kv_pairs):
     """
     bools = [
         "int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph",
-        "use_8bits_moe", "use_qdq", "use_webgpu_fp32", "use_cuda_bf16",
+        "use_8bits_moe", "use_qdq", "use_webgpu_fp32", "use_cuda_bf16", "int4_tied_embeddings",
     ]
     for key in bools:
         if key in kv_pairs:
@@ -4459,6 +4482,8 @@ def get_args():
            Currently supported options are: 'default', 'rtn', 'k_quant_mixed', 'k_quant_last'.
                k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8).
                k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
+        int4_tied_embeddings = Enable weight sharing between the embedding and the LM head for quantization. Default is false.
+            Use this option when the model's embedding (embed_tokens) and unembedding (lm_head) weights are tied.
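+            Note: this only takes effect with int4_algo_config=k_quant_mixed or k_quant_last; otherwise the quantized lm_head weight name does not match and the option is disabled.
+            Example: --extra_options int4_algo_config=k_quant_mixed int4_tied_embeddings=true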
        num_hidden_layers = Manually specify the number of layers in your ONNX model.
            Used for unit testing purposes.
        filename = Filename for ONNX model (default is 'model.onnx').