Skip to content

Commit d47d417

Browse files
authored
Add min_p sampling support (#14)
* feat: added min_p support from request to CUDA logit processor
  fix: masksData offset for min_p handling
  fix: min_p request constraint handling
  chore: adjusted variable casing and comments
  wip: move min_p as parameter to add_request
  refactor: moved min_p log space conversion into async_exec
  fix: added lnMinPs from entries to PostProcessorFn
* fix: dummy grammar, min_p in RequestParams, code cleanups
  refactor: added min_p to RequestParams
  chore: code cleanup (unnecessary import & serde default default)
1 parent bb0dc22 commit d47d417

File tree

8 files changed

+49
-9
lines changed

8 files changed

+49
-9
lines changed

llgtrt/src/async_exec.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ struct ReqData {
6262
// it seems to create them one by one
6363
// this array keeps track of assignment of req_id to llg state
6464
llg_infos: Vec<ConstraintInfo>,
65+
min_p: f32,
6566
prompt_len: usize,
6667
is_run: bool,
6768
}
@@ -90,6 +91,7 @@ struct PendingSeq {
9091
prompt_len: usize,
9192
is_run: bool,
9293
entry: TlcLogitsEntry,
94+
min_p: f32,
9395
stop: bool,
9496
// setting this will stop the sequence with given error
9597
error: Option<String>,
@@ -145,6 +147,11 @@ impl PendingSeq {
145147
let mask = step_res.sample_mask.as_ref().expect("No mask");
146148
self.entry.out_mask_pointer = copy_mask(mask);
147149
self.entry.temperature = llg.temperature;
150+
self.entry.ln_min_p = if self.min_p > 0.0 {
151+
self.min_p.ln()
152+
} else {
153+
-f32::MAX
154+
};
148155

149156
Ok(())
150157
}
@@ -160,6 +167,7 @@ impl PendingSeq {
160167
prompt_len: rd.prompt_len,
161168
entry: entry.clone(),
162169
stop: false,
170+
min_p: rd.min_p,
163171
error: None,
164172
is_run: rd.is_run,
165173
}
@@ -276,6 +284,7 @@ extern "C" fn logits_processor(logits: *mut TlcLogitsEntry, num_logits: u32) {
276284
let entry = &mut entries[ps.entry_idx];
277285
entry.out_mask_pointer = ps.entry.out_mask_pointer;
278286
entry.temperature = ps.entry.temperature;
287+
entry.ln_min_p = ps.entry.ln_min_p;
279288
let mut llg = ps.llg;
280289
if let Some(rd) = exec.req_data.get_mut(&entry.client_req_id()) {
281290
if rd.logs.is_empty() {
@@ -448,6 +457,7 @@ impl AsyncExecutor {
448457
llgs: llgs.into_iter().map(Some).collect(),
449458
llg_infos: vec![],
450459
prompt_len,
460+
min_p: init.params.min_p,
451461
logs: String::new(),
452462
is_run,
453463
},

llgtrt/src/routes/completions.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ fn req_params_from_openai(params: &CommonCreateParams) -> Result<RequestParams>
103103
let mut r = RequestParams {
104104
temperature: params.temperature,
105105
top_p: params.top_p,
106+
min_p: params.min_p,
106107
max_new_tokens: params
107108
.max_completion_tokens
108109
.unwrap_or_else(|| params.max_tokens.unwrap_or(16)) as u32,
@@ -282,7 +283,16 @@ fn llg_grammar(params: &CommonCreateParams) -> Result<Option<TopLevelGrammar>> {
282283
log::debug!("using Lark grammar");
283284
lark_to_llguidance(lark_grammar)?
284285
}
285-
_ => return Ok(None),
286+
_ => {
287+
if params.min_p > 0.0 {
288+
// Return a dummy grammar that matches any text, so that logit processing still runs when min_p is set
289+
let grm = TopLevelGrammar::from_regex(llguidance::api::RegexNode::Regex(
290+
r"(\n|.)*".to_string(),
291+
));
292+
return Ok(Some(grm));
293+
}
294+
return Ok(None);
295+
}
286296
};
287297
Ok(Some(grm))
288298
}

llgtrt/src/routes/openai.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ pub struct CommonCreateParams {
165165
/// tokens comprising the top 10% probability mass are considered.
166166
#[serde(default = "default_top_p")]
167167
pub top_p: f32,
168+
/// Filters out tokens with probability less than min_p multiplied by the probability of the most likely token
169+
#[serde(default)]
170+
pub min_p: f32,
168171
/// A unique identifier representing your end-user, which can help OpenAI to monitor and detect
169172
/// abuse.
170173
#[allow(dead_code)]

trtllm-c/logits.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <cfloat>
12
#include <stdexcept>
23
#include <string>
34
#include <cmath>
@@ -34,7 +35,7 @@ void* tlc_alloc_logit_data(int32_t mask_stride_, int32_t max_batch_size_)
3435
assert(max_batch_size > 0);
3536
assert(mask_stride % 4 == 0);
3637

37-
size_t hd_size = max_batch_size * sizeof(int64_t) * 4;
38+
size_t hd_size = max_batch_size * sizeof(int64_t) * 5;
3839
size_t sz2 = hd_size + max_batch_size * mask_stride;
3940
masks_size = sz2;
4041
if (cudaHostAlloc(&masksData, sz2, cudaHostAllocDefault))
@@ -46,7 +47,7 @@ void* tlc_alloc_logit_data(int32_t mask_stride_, int32_t max_batch_size_)
4647

4748
float* tlc_mask_fraction_ptr()
4849
{
49-
return (float*) ((uint8_t*) masksData + max_batch_size * sizeof(int64_t) * 3);
50+
return (float*) ((uint8_t*) masksData + max_batch_size * sizeof(int64_t) * 4);
5051
}
5152

5253
#define MAX_BATCH_SIZE 128
@@ -109,6 +110,7 @@ static void logitsPostProcessorFn(std::vector<tle::IdType> const& reqIds, std::v
109110
entry._num_tokens = tokens[i].get()[0].size();
110111
entry.out_mask_pointer = nullptr;
111112
entry.temperature = 1.0f;
113+
entry.ln_min_p = -FLT_MAX;
112114
entries.push_back(entry);
113115

114116
// auto shape = logits[i].getShape();
@@ -130,8 +132,10 @@ static void logitsPostProcessorFn(std::vector<tle::IdType> const& reqIds, std::v
130132
int64_t* logitPtrs = (int64_t*) masksData;
131133
int64_t* masksOffsets = logitPtrs + batchSize;
132134
float* temperatures = (float*) (logitPtrs + 2 * batchSize);
135+
float* lnMinPs = (float*) (logitPtrs + 3 * batchSize);
133136

134137
int64_t temperatures_offset = (uint8_t*) temperatures - (uint8_t*) masksData;
138+
int64_t ln_min_p_offset = (uint8_t*) lnMinPs - (uint8_t*) masksData;
135139
int64_t mask_fractions_offset = (uint8_t*) tlc_mask_fraction_ptr() - (uint8_t*) masksData;
136140

137141
int64_t* cudaLogitPtrs = (int64_t*) cudaMasksData;
@@ -187,6 +191,8 @@ static void logitsPostProcessorFn(std::vector<tle::IdType> const& reqIds, std::v
187191

188192
masksOffsets[dp] = mask_offset;
189193
temperatures[dp] = entries[i].temperature;
194+
lnMinPs[dp] = entries[i].ln_min_p;
195+
190196

191197
if (mask_offset > max_offset)
192198
max_offset = mask_offset;
@@ -201,8 +207,8 @@ static void logitsPostProcessorFn(std::vector<tle::IdType> const& reqIds, std::v
201207
if (dp > 0)
202208
{
203209
cudaMemcpyAsync(cudaMasksData, masksData, max_offset + mask_stride, cudaMemcpyHostToDevice, stream);
204-
mask_logits_ext(cudaLogitPtrs, cudaMasksOffsets, mask_fractions_offset, temperatures_offset, dp, nVocab,
205-
mask_stride / 4, tp, stream);
210+
mask_logits_ext(cudaLogitPtrs, cudaMasksOffsets, mask_fractions_offset, temperatures_offset, ln_min_p_offset,
211+
dp, nVocab, mask_stride / 4, tp, stream);
206212
cudaMemcpyAsync((uint8_t*) masksData + mask_fractions_offset, (uint8_t*) cudaMasksData + mask_fractions_offset,
207213
dp * sizeof(float), cudaMemcpyDeviceToHost, stream);
208214

trtllm-c/mask_logits.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ __inline__ __device__ void blockReduceMax2(T& val, int& idx, T flt_max)
6262

6363
template <typename T>
6464
__global__ void mask_logits_kernel(T** logit_ptrs, int64_t* mask_offsets, size_t batch_size, size_t n_vocab,
65-
size_t mask_stride, float* temperatures, T flt_max, float* mask_fractions)
65+
size_t mask_stride, float* temperatures, float* ln_min_p, T flt_max, float* mask_fractions)
6666
{
6767
auto const batch_idx = blockIdx.x;
6868
auto logits_ptr = logit_ptrs[batch_idx];
@@ -135,6 +135,10 @@ __global__ void mask_logits_kernel(T** logit_ptrs, int64_t* mask_offsets, size_t
135135
else
136136
{
137137
logit_adjusted = (logit - s_max_val_allowed) * beta;
138+
if ((float) logit_adjusted < ln_min_p[batch_idx])
139+
{
140+
logit_adjusted = -flt_max;
141+
}
138142
}
139143
}
140144

@@ -154,6 +158,7 @@ void mask_logits_ext(int64_t* d_logit_ptrs, // in,out [batch_size]
154158
int64_t* d_mask_offsets, // in [int32_t,mask_stride], [batch_size]
155159
int64_t mask_fractions_offset, // out, float, [batch_size]
156160
int64_t temperature_offset, // in, float, [batch_size]; can be 0.0f for argmax
161+
int64_t ln_min_p_offset, // in, float, [batch_size]; log_e(min_p) for min_p > 0.0f, -FLT_MAX otherwise
157162
size_t batch_size, // current batch size
158163
size_t n_vocab, // vocab size
159164
size_t mask_stride, // n_vocab / 32 or thereabouts
@@ -167,10 +172,11 @@ void mask_logits_ext(int64_t* d_logit_ptrs, // in,out [batch_size]
167172

168173
float* mask_fractions = reinterpret_cast<float*>((uint8_t*) d_logit_ptrs + mask_fractions_offset);
169174
float* temperatures = reinterpret_cast<float*>((uint8_t*) d_logit_ptrs + temperature_offset);
175+
float* ln_min_ps = reinterpret_cast<float*>((uint8_t*) d_logit_ptrs + ln_min_p_offset);
170176

171177
#define LAUNCH_KERNEL(T, m) \
172-
mask_logits_kernel<T><<<grid, block, 0, stream>>>( \
173-
(T**) d_logit_ptrs, d_mask_offsets, batch_size, n_vocab, mask_stride, temperatures, m, mask_fractions)
178+
mask_logits_kernel<T><<<grid, block, 0, stream>>>((T**) d_logit_ptrs, d_mask_offsets, batch_size, n_vocab, \
179+
mask_stride, temperatures, ln_min_ps, m, mask_fractions)
174180

175181
switch (tp)
176182
{

trtllm-c/mask_logits.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ void mask_logits_ext(int64_t* d_logit_ptrs, // in,out [batch_size]
1010
int64_t* d_mask_offsets, // in [int32_t,mask_stride], [batch_size]
1111
int64_t mask_fractions_offset, // out, float, [batch_size]
1212
int64_t temperature_offset, // in, float, [batch_size]; can be 0.0f for argmax
13+
int64_t ln_min_p_offset, // in, float, [batch_size]; log_e(min_p) for min_p > 0.0f, -FLT_MAX otherwise
1314
size_t batch_size, // current batch size
1415
size_t n_vocab, // vocab size
1516
size_t mask_stride, // n_vocab / 32 or thereabouts

trtllm-c/tlc.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ extern "C"
2323
uint32_t _num_tokens;
2424
// set by the callback (initially 1.0)
2525
float temperature;
26+
// set by the callback (initially -FLT_MAX)
27+
float ln_min_p;
2628
// set by the callback (initially NULL)
2729
uint32_t* out_mask_pointer;
2830
} TlcLogitsEntry;
@@ -116,6 +118,7 @@ extern "C"
116118
uint32_t eos_token_id;
117119
float temperature;
118120
float top_p;
121+
float min_p;
119122
float frequency_penalty;
120123
float presence_penalty;
121124
float priority;
@@ -183,4 +186,4 @@ extern "C"
183186
}
184187
#endif
185188

186-
#endif // TLC_H
189+
#endif // TLC_H

trtllm_rs/src/tlc.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ impl Default for RequestParams {
5555
num_return_sequences: 1,
5656
temperature: f32::NAN,
5757
top_p: 1.0,
58+
min_p: 0.0,
5859
presence_penalty: 0.0,
5960
frequency_penalty: 0.0,
6061
top_k: 0,

0 commit comments

Comments
 (0)