Skip to content

Commit 44ad0ae

Browse files
committed
Support addSpecialTokens parameter
1 parent 840c377 commit 44ad0ae

File tree

3 files changed

+82
-29
lines changed

3 files changed

+82
-29
lines changed

Native/src/lib.rs

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -253,31 +253,46 @@ pub unsafe extern "C" fn free_tokenizer(tokenizer_handle: *mut Tokenizer)
253253
#[no_mangle]
254254
pub unsafe extern "C" fn tokenizer_encode(
255255
tokenizer_ptr: *mut Tokenizer,
256-
text_buffer: ReadOnlyBuffer<u8>) -> TokenizeOutput
256+
text_buffer: ReadOnlyBuffer<u8>,
257+
add_special_tokens: bool)
258+
-> TokenizeOutput
257259
{
258-
return tokenizer_encode_core(tokenizer_ptr, text_buffer, true);
260+
return tokenizer_encode_core(
261+
tokenizer_ptr,
262+
text_buffer,
263+
true,
264+
add_special_tokens
265+
);
259266
}
260267

261268
#[no_mangle]
262269
pub unsafe extern "C" fn tokenizer_encode_non_truncating(
263270
tokenizer_ptr: *mut Tokenizer,
264-
text_buffer: ReadOnlyBuffer<u8>) -> TokenizeOutput
271+
text_buffer: ReadOnlyBuffer<u8>,
272+
add_special_tokens: bool)
273+
-> TokenizeOutput
265274
{
266-
return tokenizer_encode_core(tokenizer_ptr, text_buffer, false);
275+
return tokenizer_encode_core(
276+
tokenizer_ptr,
277+
text_buffer,
278+
false,
279+
add_special_tokens
280+
);
267281
}
268282

269283
#[inline(always)]
270284
pub unsafe extern "C" fn tokenizer_encode_core(
271285
tokenizer_ptr: *mut Tokenizer,
272286
text_buffer: ReadOnlyBuffer<u8>,
273-
truncate: bool)
287+
truncate: bool,
288+
add_special_tokens: bool)
274289
-> TokenizeOutput
275290
{
276291
let tokenizer = &*tokenizer_ptr;
277292

278293
let text = std::str::from_utf8_unchecked(text_buffer.as_slice());
279294

280-
let encoded_result = tokenizer.encode_fast(text, true);
295+
let encoded_result = tokenizer.encode_fast(text, add_special_tokens);
281296

282297
let encoded_tokens = match encoded_result
283298
{
@@ -292,26 +307,41 @@ pub unsafe extern "C" fn tokenizer_encode_core(
292307
pub unsafe extern "C" fn tokenizer_encode_batch(
293308
tokenizer_ptr: *mut Tokenizer,
294309
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
295-
output_buffer: Buffer<TokenizeOutput>)
310+
output_buffer: Buffer<TokenizeOutput>,
311+
add_special_tokens: bool)
296312
{
297-
tokenizer_encode_batch_core(tokenizer_ptr, text_buffers, output_buffer, true);
313+
tokenizer_encode_batch_core(
314+
tokenizer_ptr,
315+
text_buffers,
316+
output_buffer,
317+
true,
318+
add_special_tokens
319+
);
298320
}
299321

300322
#[no_mangle]
301323
pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating(
302324
tokenizer_ptr: *mut Tokenizer,
303325
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
304-
output_buffer: Buffer<TokenizeOutput>)
326+
output_buffer: Buffer<TokenizeOutput>,
327+
add_special_tokens: bool)
305328
{
306-
tokenizer_encode_batch_core(tokenizer_ptr, text_buffers, output_buffer, false);
329+
tokenizer_encode_batch_core(
330+
tokenizer_ptr,
331+
text_buffers,
332+
output_buffer,
333+
false,
334+
add_special_tokens
335+
);
307336
}
308337

309338
#[inline(always)]
310339
pub unsafe extern "C" fn tokenizer_encode_batch_core(
311340
tokenizer_ptr: *mut Tokenizer,
312341
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
313342
output_buffer: Buffer<TokenizeOutput>,
314-
truncate: bool)
343+
truncate: bool,
344+
add_special_tokens: bool)
315345
{
316346
let tokenizer = &*tokenizer_ptr;
317347

@@ -321,7 +351,7 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core(
321351
.map(|text_buffer| std::str::from_utf8_unchecked(text_buffer.as_slice()))
322352
.collect::<Vec<&str>>();
323353

324-
let encoded_result = tokenizer.encode_batch_fast(texts, true);
354+
let encoded_result = tokenizer.encode_batch_fast(texts, add_special_tokens);
325355

326356
let encoded_tokens = match encoded_result
327357
{

Tokenizers.NET/Tokenizer.cs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ public Tokenizer()
196196
}
197197

198198
[SkipLocalsInit]
199-
public TokenizeOutput Tokenize(string input)
199+
public TokenizeOutput Tokenize(string input, bool addSpecialTokens = true)
200200
{
201201
Span<byte> allocation;
202202

@@ -236,7 +236,12 @@ public TokenizeOutput Tokenize(string input)
236236

237237
var u8String = new ReadOnlyNativeBuffer<byte>(ref MemoryMarshal.GetReference(allocation), (nuint) bytesWritten);
238238

239-
var result = TokenizerNativeMethods.TokenizerEncode(TokenizerHandle, u8String, TRUNCATE);
239+
var result = TokenizerNativeMethods.TokenizerEncode(
240+
TokenizerHandle,
241+
u8String,
242+
addSpecialTokens,
243+
TRUNCATE
244+
);
240245

241246
if (useNativeMemory)
242247
{
@@ -246,35 +251,38 @@ public TokenizeOutput Tokenize(string input)
246251
return result;
247252
}
248253

249-
public void TokenizeBatch(ReadOnlySpan<string> inputs, Span<TokenizeOutput> outputs)
254+
public void TokenizeBatch(ReadOnlySpan<string> inputs, Span<TokenizeOutput> outputs, bool addSpecialTokens = true)
250255
{
251256
TokenizeBatchInternal(
252257
inputs,
253258
outputs,
254259
outputsPrePinned: false,
255-
skipLengthCheck: false
260+
skipLengthCheck: false,
261+
addSpecialTokens: addSpecialTokens
256262
);
257263
}
258264

259-
public void TokenizeBatch(ReadOnlySpan<string> inputs, NativeMemory<TokenizeOutput> outputs)
265+
public void TokenizeBatch(ReadOnlySpan<string> inputs, NativeMemory<TokenizeOutput> outputs, bool addSpecialTokens = true)
260266
{
261267
TokenizeBatchInternal(
262268
inputs,
263269
outputs.Buffer.AsSpan(),
264270
outputsPrePinned: true,
265-
skipLengthCheck: false
271+
skipLengthCheck: false,
272+
addSpecialTokens: addSpecialTokens
266273
);
267274
}
268275

269-
public NativeMemory<TokenizeOutput> TokenizeBatch(ReadOnlySpan<string> inputs)
276+
public NativeMemory<TokenizeOutput> TokenizeBatch(ReadOnlySpan<string> inputs, bool addSpecialTokens = true)
270277
{
271278
var outputs = new NativeMemory<TokenizeOutput>((nuint) inputs.Length);
272279

273280
TokenizeBatchInternal(
274281
inputs,
275282
outputs.Buffer.AsSpan(),
276283
outputsPrePinned: true,
277-
skipLengthCheck: true
284+
skipLengthCheck: true,
285+
addSpecialTokens: addSpecialTokens
278286
);
279287

280288
return outputs;
@@ -286,7 +294,8 @@ private void TokenizeBatchInternal(
286294
ReadOnlySpan<string> inputs,
287295
Span<TokenizeOutput> outputs,
288296
bool outputsPrePinned,
289-
bool skipLengthCheck)
297+
bool skipLengthCheck,
298+
bool addSpecialTokens)
290299
{
291300
var numInputs = inputs.Length;
292301

@@ -358,6 +367,7 @@ stackalloc ReadOnlyNativeBuffer<byte>[numInputs]
358367
tokenizerPtr: tokenizerHandle,
359368
textNativeBuffers: readonlyU8Strings,
360369
outputNativeBuffer: new(ref outputStart, outputLengthNative),
370+
addSpecialTokens,
361371
truncate
362372
);
363373
}
@@ -370,6 +380,7 @@ stackalloc ReadOnlyNativeBuffer<byte>[numInputs]
370380
tokenizerPtr: tokenizerHandle,
371381
textNativeBuffers: readonlyU8Strings,
372382
outputNativeBuffer: new(outputsPtr, outputLengthNative),
383+
addSpecialTokens,
373384
truncate
374385
);
375386
}

Tokenizers.NET/TokenizerNativeMethods.cs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,55 +18,67 @@ internal static unsafe class TokenizerNativeMethods
1818
public static TokenizeOutput TokenizerEncode(
1919
nint tokenizerPtr,
2020
ReadOnlyNativeBuffer<byte> textNativeBuffer,
21+
bool addSpecialTokens,
2122
bool truncate)
2223
{
2324
if (truncate)
2425
{
25-
return TokenizerEncode(tokenizerPtr, textNativeBuffer);
26+
return TokenizerEncode(tokenizerPtr, textNativeBuffer, addSpecialTokens);
2627
}
2728

2829
else
2930
{
30-
return TokenizerEncodeNonTruncating(tokenizerPtr, textNativeBuffer);
31+
return TokenizerEncodeNonTruncating(tokenizerPtr, textNativeBuffer, addSpecialTokens);
3132
}
3233
}
3334

3435
[DllImport(DLL_NAME, EntryPoint = "tokenizer_encode", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
35-
private static extern TokenizeOutput TokenizerEncode(nint tokenizerPtr, ReadOnlyNativeBuffer<byte> textNativeBuffer);
36+
private static extern TokenizeOutput TokenizerEncode(
37+
nint tokenizerPtr,
38+
ReadOnlyNativeBuffer<byte> textNativeBuffer,
39+
[MarshalAs(UnmanagedType.U1)] bool addSpecialTokens
40+
);
3641

3742
[DllImport(DLL_NAME, EntryPoint = "tokenizer_encode_non_truncating", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
38-
private static extern TokenizeOutput TokenizerEncodeNonTruncating(nint tokenizerPtr, ReadOnlyNativeBuffer<byte> textNativeBuffer);
43+
private static extern TokenizeOutput TokenizerEncodeNonTruncating(
44+
nint tokenizerPtr,
45+
ReadOnlyNativeBuffer<byte> textNativeBuffer,
46+
[MarshalAs(UnmanagedType.U1)] bool addSpecialTokens
47+
);
3948

4049
[MethodImpl(MethodImplOptions.AggressiveInlining)]
4150
public static void TokenizerEncodeBatch(
4251
nint tokenizerPtr,
4352
ReadOnlyNativeBuffer<ReadOnlyNativeBuffer<byte>> textNativeBuffers,
4453
NativeBuffer<TokenizeOutput> outputNativeBuffer,
54+
bool addSpecialTokens,
4555
bool truncate)
4656
{
4757
if (truncate)
4858
{
49-
TokenizerEncodeBatch(tokenizerPtr, textNativeBuffers, outputNativeBuffer);
59+
TokenizerEncodeBatch(tokenizerPtr, textNativeBuffers, outputNativeBuffer, addSpecialTokens);
5060
}
5161

5262
else
5363
{
54-
TokenizerEncodeBatchNonTruncating(tokenizerPtr, textNativeBuffers, outputNativeBuffer);
64+
TokenizerEncodeBatchNonTruncating(tokenizerPtr, textNativeBuffers, outputNativeBuffer, addSpecialTokens);
5565
}
5666
}
5767

5868
[DllImport(DLL_NAME, EntryPoint = "tokenizer_encode_batch", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
5969
private static extern void TokenizerEncodeBatch(
6070
nint tokenizerPtr,
6171
ReadOnlyNativeBuffer<ReadOnlyNativeBuffer<byte>> textNativeBuffers,
62-
NativeBuffer<TokenizeOutput> outputNativeBuffer
72+
NativeBuffer<TokenizeOutput> outputNativeBuffer,
73+
[MarshalAs(UnmanagedType.U1)] bool addSpecialTokens
6374
);
6475

6576
[DllImport(DLL_NAME, EntryPoint = "tokenizer_encode_batch_non_truncating", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
6677
private static extern void TokenizerEncodeBatchNonTruncating(
6778
nint tokenizerPtr,
6879
ReadOnlyNativeBuffer<ReadOnlyNativeBuffer<byte>> textNativeBuffers,
69-
NativeBuffer<TokenizeOutput> outputNativeBuffer
80+
NativeBuffer<TokenizeOutput> outputNativeBuffer,
81+
[MarshalAs(UnmanagedType.U1)] bool addSpecialTokens
7082
);
7183

7284
[MethodImpl(MethodImplOptions.AggressiveInlining)]

0 commit comments

Comments
 (0)