1515_TRIED_LOADING_BLOCKWISE_3RD_KERNEL = False
1616_BLOCKWISE_3RD_TUNED_VALUE_CLS = None
1717_BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = None
18+ _BLOCKWISE_3RD_TUNED_BLOCK_SIZES = None
1819_TRIED_LOADING_BLOCKWISE_3RD_TUNING = False
1920
2021
@@ -39,22 +40,33 @@ def _get_blockwise_3rd_tuning_api():
3940 """Lazily load third-party tuned-size helpers for blockwise kernel."""
4041 global _BLOCKWISE_3RD_TUNED_VALUE_CLS
4142 global _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES
43+ global _BLOCKWISE_3RD_TUNED_BLOCK_SIZES
4244 global _TRIED_LOADING_BLOCKWISE_3RD_TUNING
4345
4446 if _TRIED_LOADING_BLOCKWISE_3RD_TUNING :
45- return _BLOCKWISE_3RD_TUNED_VALUE_CLS , _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES
47+ return (
48+ _BLOCKWISE_3RD_TUNED_VALUE_CLS ,
49+ _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES ,
50+ _BLOCKWISE_3RD_TUNED_BLOCK_SIZES ,
51+ )
4652 _TRIED_LOADING_BLOCKWISE_3RD_TUNING = True
4753
4854 try :
4955 package = __package__ or "sgl_jax.srt.kernels.quantized_matmul"
5056 module = importlib .import_module (f"{ package } .3rd_quantized_matmul.tuned_block_sizes" )
5157 _BLOCKWISE_3RD_TUNED_VALUE_CLS = getattr (module , "TunedValue" , None )
5258 _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = getattr (module , "get_tuned_block_sizes" , None )
59+ _BLOCKWISE_3RD_TUNED_BLOCK_SIZES = getattr (module , "TUNED_BLOCK_SIZES" , None )
5360 except Exception :
5461 _BLOCKWISE_3RD_TUNED_VALUE_CLS = None
5562 _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES = None
63+ _BLOCKWISE_3RD_TUNED_BLOCK_SIZES = None
5664
57- return _BLOCKWISE_3RD_TUNED_VALUE_CLS , _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES
65+ return (
66+ _BLOCKWISE_3RD_TUNED_VALUE_CLS ,
67+ _BLOCKWISE_3RD_GET_TUNED_BLOCK_SIZES ,
68+ _BLOCKWISE_3RD_TUNED_BLOCK_SIZES ,
69+ )
5870
5971
6072def _next_multiple (x : int , m : int ) -> int :
@@ -63,6 +75,72 @@ def _next_multiple(x: int, m: int) -> int:
6375 return ((x + m - 1 ) // m ) * m
6476
6577
78+ def _floor_multiple (x : int , m : int ) -> int :
79+ if m <= 0 :
80+ return x
81+ return max (m , (x // m ) * m )
82+
83+
84+ def _nearest_power_of_two_multiple (x : int , base : int , upper_bound : int ) -> int :
85+ if base <= 0 :
86+ return x
87+
88+ x = max (base , x )
89+ units = max (1 , x // base )
90+ lower_units = 1 << (units .bit_length () - 1 )
91+ upper_units = lower_units if lower_units == units else lower_units << 1
92+
93+ def _candidate (units_value : int ) -> int :
94+ return units_value * base
95+
96+ lower = _candidate (lower_units )
97+ upper = _candidate (upper_units )
98+ candidates = [value for value in (lower , upper ) if value <= upper_bound ]
99+ if not candidates :
100+ candidates = [lower ]
101+
102+ return min (candidates , key = lambda value : (abs (value - x ), - value ))
103+
104+
105+ def _iter_blockwise_tuned_candidates (
106+ tuned_block_sizes : dict | None ,
107+ n_batch : int ,
108+ n_out : int ,
109+ n_in : int ,
110+ x_q_dtype : jnp .dtype ,
111+ w_q_dtype : jnp .dtype ,
112+ ):
113+ if not tuned_block_sizes :
114+ return []
115+
116+ x_q_dtype_name = jnp .dtype (x_q_dtype ).name
117+ w_q_dtype_name = jnp .dtype (w_q_dtype ).name
118+ compatible_x_dtype_names = [x_q_dtype_name ]
119+ if jnp .issubdtype (w_q_dtype , jnp .integer ) and x_q_dtype_name != "int8" :
120+ compatible_x_dtype_names .append ("int8" )
121+
122+ candidates = []
123+ for key , value in tuned_block_sizes .items ():
124+ if key .w_q_dtype != w_q_dtype_name :
125+ continue
126+ if key .x_q_dtype not in compatible_x_dtype_names :
127+ continue
128+
129+ score = (
130+ compatible_x_dtype_names .index (key .x_q_dtype ),
131+ key .n_in != n_in ,
132+ abs (key .n_in - n_in ),
133+ key .n_batch != n_batch ,
134+ abs (key .n_batch - n_batch ),
135+ key .n_out != n_out ,
136+ abs (key .n_out - n_out ),
137+ )
138+ candidates .append ((score , value ))
139+
140+ candidates .sort (key = lambda item : item [0 ])
141+ return [value for _ , value in candidates ]
142+
143+
66144def _get_safe_blockwise_tuned_value (
67145 n_batch : int ,
68146 n_out : int ,
@@ -72,12 +150,22 @@ def _get_safe_blockwise_tuned_value(
72150 block_size_in : int ,
73151):
74152 """Build a safe tuned value for third-party blockwise kernel on TPU."""
75- tuned_value_cls , get_tuned_block_sizes = _get_blockwise_3rd_tuning_api ()
153+ tuned_value_cls , get_tuned_block_sizes , tuned_block_sizes = _get_blockwise_3rd_tuning_api ()
76154 if tuned_value_cls is None :
77155 return None
78156
79157 tuned = None
80- if get_tuned_block_sizes is not None :
158+ compatible_candidates = _iter_blockwise_tuned_candidates (
159+ tuned_block_sizes = tuned_block_sizes ,
160+ n_batch = n_batch ,
161+ n_out = n_out ,
162+ n_in = n_in ,
163+ x_q_dtype = x_q_dtype ,
164+ w_q_dtype = w_q_dtype ,
165+ )
166+ if compatible_candidates :
167+ tuned = compatible_candidates [0 ]
168+ elif get_tuned_block_sizes is not None :
81169 try :
82170 tuned = get_tuned_block_sizes (
83171 n_batch = n_batch ,
@@ -94,10 +182,17 @@ def _get_safe_blockwise_tuned_value(
94182 n_lane_multiplier = max (1 , int (tuned .n_lane_multiplier ))
95183 compute_tile_n = 256 * n_lane_multiplier
96184
97- batch_block_size = max (1 , int (tuned .batch_block_size ))
185+ batch_block_size = max (1 , min ( int (tuned .batch_block_size ), int ( n_batch ) ))
98186 out_block_size = _next_multiple (max (int (tuned .out_block_size ), compute_tile_n ), compute_tile_n )
187+ out_block_size = min (out_block_size , _floor_multiple (int (n_out ), compute_tile_n ))
188+ out_block_size = _nearest_power_of_two_multiple (
189+ out_block_size ,
190+ compute_tile_n ,
191+ _floor_multiple (int (n_out ), compute_tile_n ),
192+ )
99193 in_block_size = max (int (tuned .in_block_size ), int (block_size_in ))
100194 in_block_size = _next_multiple (in_block_size , int (block_size_in ))
195+ in_block_size = min (in_block_size , _floor_multiple (int (n_in ), int (block_size_in )))
101196
102197 return tuned_value_cls (batch_block_size , out_block_size , in_block_size , n_lane_multiplier )
103198
0 commit comments