@@ -83,6 +83,7 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v
8383void radix_topk_page_table_transform (TensorView input, TensorView output_page_table,
8484 TensorView src_page_table,
8585 Optional<TensorView> maybe_row_to_batch, TensorView lengths,
86+ Optional<TensorView> maybe_row_starts,
8687 Optional<TensorView> maybe_row_states_buffer, int64_t top_k,
8788 bool deterministic, int64_t tie_break, bool dsa_graph_safe) {
8889 CHECK_INPUT (input);
@@ -93,6 +94,10 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
9394 CHECK_DIM (2 , output_page_table); // output_page_table: (num_rows, top_k)
9495 CHECK_DIM (2 , src_page_table); // src_page_table: (batch_size, max_len)
9596 CHECK_DIM (1 , lengths); // lengths: (num_rows,)
97+ if (maybe_row_starts.has_value ()) {
98+ CHECK_INPUT (maybe_row_starts.value ());
99+ CHECK_DIM (1 , maybe_row_starts.value ());
100+ }
96101
97102 unsigned int num_rows = input.size (0 );
98103 unsigned int max_len = input.size (1 );
@@ -118,14 +123,21 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
118123 if (maybe_row_to_batch.has_value ()) {
119124 row_to_batch_ptr = static_cast <int32_t *>(maybe_row_to_batch.value ().data_ptr ());
120125 }
126+ int32_t * row_starts_ptr = nullptr ;
127+ if (maybe_row_starts.has_value ()) {
128+ TVM_FFI_ICHECK (static_cast <unsigned int >(maybe_row_starts.value ().size (0 )) == num_rows)
129+ << " row_starts must have shape (num_rows,)" ;
130+ row_starts_ptr = static_cast <int32_t *>(maybe_row_starts.value ().data_ptr ());
131+ }
121132
122133 // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
123134 DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 (dtype, c_type, [&] {
124135 status = sampling::TopKPageTableTransformDispatch<c_type, int32_t >(
125136 static_cast <c_type*>(input.data_ptr ()), static_cast <int32_t *>(output_page_table.data_ptr ()),
126- static_cast <const int32_t *>(src_page_table.data_ptr ()), src_stride, row_to_batch_ptr,
127- static_cast <int32_t *>(lengths.data_ptr ()), num_rows, static_cast <uint32_t >(top_k), max_len,
128- row_states_ptr, deterministic, tie_break_mode, stream, dsa_graph_safe);
137+ static_cast <const int32_t *>(src_page_table.data_ptr ()), src_stride,
138+ static_cast <int32_t *>(lengths.data_ptr ()), row_starts_ptr, row_to_batch_ptr, num_rows,
139+ static_cast <uint32_t >(top_k), max_len, row_states_ptr, deterministic, tie_break_mode,
140+ stream, dsa_graph_safe);
129141 return true ;
130142 });
131143
@@ -134,9 +146,9 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta
134146}
135147
136148void radix_topk_ragged_transform (TensorView input, TensorView output_indices, TensorView offsets,
137- TensorView lengths, Optional<TensorView> maybe_row_states_buffer ,
138- int64_t top_k, bool deterministic, int64_t tie_break ,
139- bool dsa_graph_safe) {
149+ TensorView lengths, Optional<TensorView> maybe_row_starts ,
150+ Optional<TensorView> maybe_row_states_buffer, int64_t top_k ,
151+ bool deterministic, int64_t tie_break, bool dsa_graph_safe) {
140152 CHECK_INPUT (input);
141153 CHECK_INPUT (output_indices);
142154 CHECK_INPUT (offsets);
@@ -145,6 +157,10 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te
145157 CHECK_DIM (2 , output_indices); // output_indices: (num_rows, top_k)
146158 CHECK_DIM (1 , offsets); // offsets: (num_rows,)
147159 CHECK_DIM (1 , lengths); // lengths: (num_rows,)
160+ if (maybe_row_starts.has_value ()) {
161+ CHECK_INPUT (maybe_row_starts.value ());
162+ CHECK_DIM (1 , maybe_row_starts.value ());
163+ }
148164
149165 unsigned int num_rows = input.size (0 );
150166 unsigned int max_len = input.size (1 );
@@ -164,14 +180,20 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te
164180 row_states_ptr =
165181 static_cast <sampling::RadixRowState*>(maybe_row_states_buffer.value ().data_ptr ());
166182 }
183+ int32_t * row_starts_ptr = nullptr ;
184+ if (maybe_row_starts.has_value ()) {
185+ TVM_FFI_ICHECK (static_cast <unsigned int >(maybe_row_starts.value ().size (0 )) == num_rows)
186+ << " row_starts must have shape (num_rows,)" ;
187+ row_starts_ptr = static_cast <int32_t *>(maybe_row_starts.value ().data_ptr ());
188+ }
167189
168190 // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
169191 DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 (dtype, c_type, [&] {
170192 status = sampling::TopKRaggedTransformDispatch<c_type, int32_t >(
171193 static_cast <c_type*>(input.data_ptr ()), static_cast <int32_t *>(output_indices.data_ptr ()),
172194 static_cast <const int32_t *>(offsets.data_ptr ()), static_cast <int32_t *>(lengths.data_ptr ()),
173- num_rows, static_cast <uint32_t >(top_k), max_len, row_states_ptr, deterministic ,
174- tie_break_mode, stream, dsa_graph_safe);
195+ row_starts_ptr, num_rows, static_cast <uint32_t >(top_k), max_len, row_states_ptr,
196+ deterministic, tie_break_mode, stream, dsa_graph_safe);
175197 return true ;
176198 });
177199
0 commit comments