Commit 907997a

- Implemented Tokenizer.IDsToTokens()
- TempFixedAllocator's memory is now 128-byte aligned
- Moved output data structures to the Outputs directory

1 parent: d44b3f0

File tree: 11 files changed (+191 −10 lines)
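
For orientation, here is a minimal sketch of how the new managed API is meant to be used, pieced together from the overloads and the test added in this commit. The tokenizer instance and its configuration are assumed (not shown), and "Hello world" is just an illustrative input:

    using System;
    using Tokenizers.NET;

    // Sketch only: assumes an already-configured Tokenizer named "tokenizer".
    using var tokenizeResult = tokenizer.Tokenize("Hello world", addSpecialTokens: false);

    // Maps each token ID back to its vocabulary token as a managed string.
    string[] tokens = tokenizer.IDsToTokens(tokenizeResult.IDs);

    foreach (var token in tokens)
    {
        // BPE vocabularies typically mark a leading space with 'Ġ' (see Tests/DecodeTests.cs below).
        Console.WriteLine(token.Replace('Ġ', ' '));
    }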

Codegen/Program.cs

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 using System.Threading;
 using Tokenizers.NET;
 using Tokenizers.NET.Collections;
+using Tokenizers.NET.Outputs;
 
 namespace Codegen
 {

Native/src/lib.rs

Lines changed: 29 additions & 1 deletion
@@ -1,9 +1,9 @@
+use std::string::String;
 use std::marker::PhantomData;
 use std::ptr::{ null, null_mut };
 use std::slice;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::Encoding;
-
 // #[inline(always)] is used aggressively - Realistically we only have a few callsites.
 
 #[repr(C)]
@@ -445,6 +445,34 @@ pub unsafe extern "C" fn tokenizer_decode_core(
     return DecodeOutput::from_text(text);
 }
 
+#[no_mangle]
+#[inline(always)]
+pub unsafe extern "C" fn ids_to_tokens(
+    tokenizer_ptr: *mut Tokenizer,
+    id_buffer: NativeBuffer<u32>,
+    token_buffer: NativeBuffer<NativeBuffer<u8>>)
+    -> *mut DropHandle<Vec<String>>
+{
+    let tokenizer = &*tokenizer_ptr;
+
+    let mut token_buffers = Vec::with_capacity(id_buffer.length);
+
+    let mut current_token_ptr = token_buffer.ptr.mutable;
+
+    for id in id_buffer.as_slice()
+    {
+        let mut token = tokenizer.id_to_token(*id).unwrap();
+
+        *current_token_ptr = NativeBuffer::from_mutable_vec(token.as_mut_vec());
+
+        current_token_ptr = current_token_ptr.add(1);
+
+        token_buffers.push(token);
+    }
+
+    return DropHandle::from_value_and_allocate_box(token_buffers);
+}
+
 #[no_mangle]
 #[inline(always)]
 pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>)
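
On the managed side, ids_to_tokens is reached through the FreeHandle-returning IDsToTokens overloads added in Tokenizers.NET/Tokenizer.cs below. The Rust function fills the caller-supplied token_buffer with UTF-8 views into Strings owned by the returned DropHandle, so those views stay valid only until the handle is freed. A rough sketch of that calling pattern (assuming an existing Tokenizer named tokenizer, a NativeBuffer<uint> named ids, and an unsafe context; not code from this commit):

    using System;
    using System.Text;
    using Tokenizers.NET;
    using Tokenizers.NET.Collections;
    using Tokenizers.NET.Outputs;

    // One NativeBuffer<byte> slot per input ID; the native side writes a view into each slot.
    Span<NativeBuffer<byte>> u8Tokens = stackalloc NativeBuffer<byte>[(int) ids.Length];

    // The returned FreeHandle keeps the Rust-side Vec<String> alive while we read the views.
    using (tokenizer.IDsToTokens(ids, u8Tokens))
    {
        foreach (var u8Token in u8Tokens)
        {
            // Decode (or copy) the bytes before Dispose() releases the backing strings.
            Console.WriteLine(Encoding.UTF8.GetString(u8Token.Ptr, (int) u8Token.Length));
        }
    }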

Sample/Program.cs

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ private static void Main(string[] args)
 
             foreach (var token in outputSpan)
             {
-                const bool TEST_OVERFLOW = true;
+                const bool TEST_OVERFLOW = false;
 
                 if (TEST_OVERFLOW)
                 {

Tests/DecodeTests.cs

Lines changed: 29 additions & 0 deletions
@@ -1,3 +1,4 @@
+using System.Text;
 using Allure.NUnit;
 using FluentAssertions;
 using Tokenizers.NET;
@@ -107,5 +108,33 @@ public void DecodeMutatingStressTest()
                 x.Should().Be(text);
             }
         }
+
+        [Test]
+        public void IDsToTokens()
+        {
+            ref var tokenizer = ref FlorenceTokenizer;
+
+            const nuint MAX_VALUE = 500;
+
+            var stringBuilder = new StringBuilder();
+
+            for (nuint i = 1; i <= MAX_VALUE; i++)
+            {
+                var text = AllocateStringWithRandomChars((int) i);
+
+                using var tokenizeResult = tokenizer.Tokenize(text, addSpecialTokens: false);
+
+                var tokens = tokenizer.IDsToTokens(tokenizeResult.IDs);
+
+                foreach (var token in tokens)
+                {
+                    stringBuilder.Append(token.Replace('Ġ', ' '));
+                }
+
+                stringBuilder.ToString().Should().Be(text);
+
+                stringBuilder.Clear();
+            }
+        }
     }
 }

Tests/EncodeTests.cs

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 using FluentAssertions;
 using Tokenizers.NET;
 using Tokenizers.NET.Collections;
+using Tokenizers.NET.Outputs;
 
 namespace Tests
 {

Tokenizers.NET/Helpers/ThrowHelpers.cs

Lines changed: 6 additions & 0 deletions
@@ -32,5 +32,11 @@ public static void UTF8EncodingPirated_GetMaxCharCount_OutOfRange()
         {
             throw new InvalidOperationException("Too many bytes. The resulting number of chars is larger than what can be returned as an int.");
         }
+
+        [DoesNotReturn]
+        public static void IDsToTokens_LengthCheckFailed()
+        {
+            throw new ArgumentException("Output Span / Buffer length must be more than or equal to the input length.");
+        }
     }
 }

Tokenizers.NET/DecodeOutput.cs renamed to Tokenizers.NET/Outputs/DecodeOutput.cs

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 using System.Text;
 using Tokenizers.NET.Collections;
 
-namespace Tokenizers.NET
+namespace Tokenizers.NET.Outputs
 {
     [StructLayout(LayoutKind.Sequential)]
     public readonly struct DecodeOutput: IDisposable

Tokenizers.NET/Outputs/FreeHandle.cs

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+using System;
+using System.Runtime.CompilerServices;
+
+namespace Tokenizers.NET.Outputs
+{
+    public readonly struct FreeHandle(nint handle): IDisposable
+    {
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public void Dispose()
+        {
+            TokenizerNativeMethods.FreeWithHandle(handle);
+        }
+    }
+}

Tokenizers.NET/TokenizeOutput.cs renamed to Tokenizers.NET/Outputs/TokenizeOutput.cs

Lines changed: 1 addition & 4 deletions
@@ -2,11 +2,8 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using Tokenizers.NET.Collections;
-#if DEBUG
-using System.Diagnostics;
-#endif
 
-namespace Tokenizers.NET
+namespace Tokenizers.NET.Outputs
 {
     public interface ITokenizeOutput
     {

Tokenizers.NET/Tokenizer.cs

Lines changed: 99 additions & 3 deletions
@@ -9,6 +9,7 @@
 using Tokenizers.NET.Collections;
 using Tokenizers.NET.Enumerators;
 using Tokenizers.NET.Helpers;
+using Tokenizers.NET.Outputs;
 
 namespace Tokenizers.NET
 {
@@ -171,15 +172,28 @@ private static readonly int
 
         private readonly int Count;
 
+        // Modern cacheline size is either 64 or 128 bytes,
+        // reducing cross-cacheline reads for SIMD instructions.
+        // This should also satisfy the alignment for NativeBuffer<NativeBuffer<byte>>,
+        // enabling us to reinterpret the memory in IDsToTokens() to avoid allocation.
+        private const int ALIGNMENT = 128;
+
+        static TempFixedAllocator()
+        {
+            Debug.Assert(ALIGNMENT % sizeof(NativeBuffer<NativeBuffer<byte>>) == 0);
+        }
+
         public TempFixedAllocator()
         {
             var maxExpectedBatches = Config.ExpectedMaxBatches.ToSignedUnchecked();
 
-            var buffers = Buffers = AllocationHelpers.AllocatePinnedUninitialized<byte>(
-                TOTAL_BUFFER_SIZE
+            var buffers = Buffers = AllocationHelpers.AllocatePinnedUninitializedAligned<byte>(
+                TOTAL_BUFFER_SIZE,
+                ALIGNMENT,
+                out var buffersPtr
             );
 
-            BuffersPtr = buffers.PinnedArrayToPointer();
+            BuffersPtr = buffersPtr;
 
             Count = maxExpectedBatches;
 
@@ -525,6 +539,88 @@ public DecodeOutput DecodeMutating(NativeBuffer<ulong> ids, bool skipSpecialTokens
                 skipSpecialTokens
             );
         }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public FreeHandle IDsToTokens(NativeBuffer<uint> ids, Span<NativeBuffer<byte>> u8Strings)
+        {
+            fixed (NativeBuffer<byte>* ptr = &MemoryMarshal.GetReference(u8Strings))
+            {
+                var u8StringsBuffer = new NativeBuffer<NativeBuffer<byte>>(ptr, (nuint) u8Strings.Length);
+
+                return IDsToTokens(ids, u8StringsBuffer);
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public FreeHandle IDsToTokens(
+            NativeBuffer<uint> ids,
+            NativeBuffer<NativeBuffer<byte>> tokens,
+            bool performSizeCheck = true)
+        {
+            if (performSizeCheck && tokens.Length < ids.Length)
+            {
+                ThrowHelpers.IDsToTokens_LengthCheckFailed();
+            }
+
+            var tokenizerHandle = TokenizerHandle;
+
+            return new(TokenizerNativeMethods.IDsToTokens(tokenizerHandle, ids, tokens));
+        }
+
+        public string[] IDsToTokens(NativeBuffer<uint> ids)
+        {
+            var tokens = new string[ids.Length];
+
+            IDsToTokens(ids, tokens, performSizeCheck: false);
+
+            return tokens;
+        }
+
+        public void IDsToTokens(NativeBuffer<uint> ids, Span<string> tokens, bool performSizeCheck = true)
+        {
+            var inputLength = ids.Length;
+
+            if (performSizeCheck && (nuint) tokens.Length < inputLength)
+            {
+                ThrowHelpers.IDsToTokens_LengthCheckFailed();
+            }
+
+            var allocationSizeInBytes = (int) inputLength * sizeof(NativeBuffer<NativeBuffer<byte>>);
+
+            var allocateNative = allocationSizeInBytes > (Config.ExpectedMaxInputLength * Config.ExpectedMaxBatches);
+
+            NativeBuffer<NativeBuffer<byte>> allocation;
+
+            if (!allocateNative)
+            {
+                var ptr = Allocator.GetFullAllocationUnsafely().Ptr;
+
+                allocation = new((NativeBuffer<byte>*) ptr, inputLength);
+            }
+
+            else
+            {
+                allocation = new NativeMemory<NativeBuffer<byte>>(inputLength).Buffer;
+            }
+
+            using var freeHandle = IDsToTokens(ids, allocation, performSizeCheck: false);
+
+            ref var currentToken = ref MemoryMarshal.GetReference(tokens);
+
+            foreach (var buffer in allocation)
+            {
+                // In theory, we could intern the tokenizer's vocab and greatly reduce string allocs,
+                // but it is what it is for now...
+                currentToken = Encoding.UTF8.GetString(buffer.Ptr, (int) buffer.Length);
+
+                currentToken = ref Unsafe.Add(ref currentToken, 1);
+            }
+
+            if (allocateNative)
+            {
+                NativeMemory<NativeBuffer<byte>>.FreeWithPtrUnsafely(allocation.Ptr);
+            }
+        }
 
         public void Dispose()
         {
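
Worth noting about the Span<string> overload above: it picks its scratch buffer with a simple size check. If inputLength * sizeof(NativeBuffer<NativeBuffer<byte>>) fits within Config.ExpectedMaxInputLength * Config.ExpectedMaxBatches bytes, it reinterprets the TempFixedAllocator's pinned block (the new 128-byte alignment and the Debug.Assert in the static constructor are there to make that reinterpretation safe); otherwise it takes a one-off NativeMemory allocation and frees it once the managed strings have been created. A worked example of that threshold, using hypothetical config values rather than anything from this commit:

    using System;

    // Hypothetical values, for illustration only.
    const int expectedMaxInputLength = 1024;
    const int expectedMaxBatches = 16;

    // Assuming NativeBuffer<NativeBuffer<byte>> is a pointer plus a length,
    // i.e. 16 bytes per entry on a 64-bit runtime.
    const int entrySizeInBytes = 16;

    int fixedBudgetInBytes = expectedMaxInputLength * expectedMaxBatches; // 16,384 bytes
    int maxIdsOnFixedPath = fixedBudgetInBytes / entrySizeInBytes;        // 1,024 IDs

    Console.WriteLine($"IDs served by the fixed allocator path: {maxIdsOnFixedPath}");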

Tokenizers.NET/TokenizerNativeMethods.cs

Lines changed: 9 additions & 0 deletions
@@ -1,6 +1,7 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using Tokenizers.NET.Collections;
+using Tokenizers.NET.Outputs;
 
 namespace Tokenizers.NET
 {
@@ -118,6 +119,14 @@ public static DecodeOutput TokenizerDecode(
         [LibraryImport(DLL_NAME, EntryPoint = "tokenizer_decode_skip_special_tokens")]
         private static partial DecodeOutput TokenizerDecodeSkipSpecialTokens(nint tokenizerPtr, NativeBuffer<uint> idBuffer);
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        [LibraryImport(DLL_NAME, EntryPoint = "ids_to_tokens")]
+        public static partial nint IDsToTokens(
+            nint tokenizerPtr,
+            NativeBuffer<uint> idBuffer,
+            NativeBuffer<NativeBuffer<byte>> tokenBuffer
+        );
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         [LibraryImport(DLL_NAME, EntryPoint = "free_with_handle")]
         public static partial void FreeWithHandle(nint handle);
