@@ -68,6 +68,7 @@ def quantize(
6868
6969 return dq_output
7070
71+
7172def dynamic_block_quantize (
7273 ctx : ConversionContext ,
7374 target : Target ,
@@ -99,23 +100,29 @@ def dynamic_block_quantize(
99100 raise ValueError (
100101 f"dynamic_block_quantize converter received an input of { input_tensor .shape } shape. Supported shapes: 2D or 3D"
101102 )
102- print (f"input_tensor.shape: { input_tensor .shape } { block_size = } { amax = } { num_bits = } { exponent_bits = } { scale_num_bits = } { scale_exponent_bits = } " )
103103 max_bound = 6
104104 amax = to_torch (amax , None )
105105 scale = torch .divide (amax , max_bound )
106106 scale = get_trt_tensor (ctx , scale , name + "_scale" )
107107
108- output_type = trt .DataType .FP4
109108 # Add Q node
110- dynamic_quantize_layer = ctx .net .add_dynamic_quantize (input_tensor , axis = - 1 , block_size = 16 , output_type = output_type )
111- quantize_layer .set_output_type (0 , output_type )
109+ dynamic_quantize_layer = ctx .net .add_dynamic_quantize (
110+ input_tensor ,
111+ axis = - 1 ,
112+ block_size = 16 ,
113+ output_type = trt .DataType .FP4 ,
114+ scale_type = trt .DataType .FP8 ,
115+ )
116+ dynamic_quantize_layer .set_output_type (0 , trt .DataType .FP4 )
112117
113- set_layer_name (quantize_layer , target , name + "_quantize" , source_ir )
114- q_output = quantize_layer .get_output (0 )
118+ set_layer_name (
119+ dynamic_quantize_layer , target , name + "_dynamic_quantize" , source_ir
120+ )
121+ q_output = dynamic_quantize_layer .get_output (0 )
115122 # Add DQ node
116123 dequantize_layer = ctx .net .add_dequantize (q_output , scale )
117124 set_layer_name (dequantize_layer , target , name + "_dequantize" , source_ir )
118- dequantize_layer .precision = output_type
125+ dequantize_layer .precision = trt . DataType . FP4
119126 dq_output = dequantize_layer .get_output (0 )
120127
121128 return dq_output
0 commit comments