Skip to content

Commit 70d7603

Browse files
Copilot, stephentoub, and tarekgh
authored
Add heap-based BPE merge path for large inputs (>128 bytes) (#7580)
* Initial plan * Implement heap-based byte pair encoding for large inputs Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> * Add tests for large input BPE optimization Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> * Address code review feedback - improve test coverage and optimize heap capacity Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> * Remove timing-based performance test to prevent CI flakiness Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> * Add explanatory comments for threshold and heap capacity choices Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> * Apply suggestion from @stephentoub * Use default capacity for PriorityQueue instead of pre-allocating to max Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> * Apply suggestion from @stephentoub * Update src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix CompareTo ordering, revert BOM changes, and update test comments Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> * Add parameterless PriorityQueue constructor and use it in BytePairEncodeLarge Co-authored-by: tarekgh <10833894+tarekgh@users.noreply.github.com> * Add comment noting that CurRank assumes rank == token Id (Tiktoken-specific) Co-authored-by: tarekgh <10833894+tarekgh@users.noreply.github.com> * Remove stackalloc for State array; always use ArrayPool since method is only called for >128 bytes Co-authored-by: tarekgh <10833894+tarekgh@users.noreply.github.com> * Replace List+ToArray with ArrayPool for result buffer in BytePairEncodeLarge Co-authored-by: tarekgh <10833894+tarekgh@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] 
<198982749+Copilot@users.noreply.github.com> Co-authored-by: stephentoub <2642209+stephentoub@users.noreply.github.com> Co-authored-by: Stephen Toub <stoub@microsoft.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: tarekgh <10833894+tarekgh@users.noreply.github.com>
1 parent 989f5a0 commit 70d7603

File tree

3 files changed

+246
-1
lines changed

3 files changed

+246
-1
lines changed

src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ public static (int Id, int TokenIndex, int TokenLength)[] BytePairEncode(ReadOnl
2020
return [(ranks[mergingBytes], 0, 1)];
2121
}
2222

23+
// For large inputs, use heap-based algorithm to avoid O(n²) behavior.
24+
// Threshold of 128 chosen empirically: linear scan is cache-friendly for small inputs,
25+
// while heap overhead (O(log n) per operation) becomes worthwhile for larger inputs.
26+
// Based on upstream tiktoken using 100, adjusted upward for C#'s efficient span operations.
27+
if (mergingBytes.Length > 128)
28+
{
29+
return BytePairEncodeLarge(mergingBytes, ranks, indexMappingSpan);
30+
}
31+
2332
(int Index, int Rank)[]? arrayPoolArray = null;
2433
int requiredLength = mergingBytes.Length + 1;
2534
Span<(int Index, int Rank)> byteIndicesAndRanks = requiredLength <= 64 ?
@@ -116,6 +125,168 @@ int GetRank(Span<(int Index, int Rank)> byteIndicesAndRanks, int startIndex, int
116125
return result;
117126
}
118127

128+
private struct State
129+
{
130+
public int Prev;
131+
public int End;
132+
public int NextEnd;
133+
public int NextRank;
134+
// Note: In the Tiktoken tokenizer, the rank is also the token Id.
135+
// This field is used to cache the rank/Id after a merge so we don't need to re-look it up.
136+
// Using this code with a different tokenizer where rank != token Id would produce wrong results.
137+
public int CurRank;
138+
}
139+
140+
private struct MergeEntry : IComparable<MergeEntry>
141+
{
142+
public int Rank;
143+
public int Start;
144+
145+
public int CompareTo(MergeEntry other)
146+
{
147+
int rankComparison = Rank.CompareTo(other.Rank);
148+
if (rankComparison != 0)
149+
{
150+
return rankComparison;
151+
}
152+
return Start.CompareTo(other.Start);
153+
}
154+
}
155+
156+
private static (int Id, int TokenIndex, int TokenLength)[] BytePairEncodeLarge(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
157+
{
158+
int stateLength = mergingBytes.Length;
159+
State[] statePoolArray = ArrayPool<State>.Shared.Rent(stateLength);
160+
Span<State> state = statePoolArray.AsSpan(0, stateLength);
161+
162+
state[0] = new State
163+
{
164+
Prev = int.MaxValue,
165+
End = 1,
166+
NextEnd = 2,
167+
NextRank = int.MaxValue,
168+
CurRank = int.MaxValue
169+
};
170+
171+
var heap = new PriorityQueue<MergeEntry>();
172+
173+
for (int i = 0; i < mergingBytes.Length - 1; i++)
174+
{
175+
var slice = mergingBytes.Slice(i, 2);
176+
if (ranks.TryGetValue(slice, out int rank))
177+
{
178+
heap.Enqueue(new MergeEntry { Start = i, Rank = rank });
179+
state[i].NextRank = rank;
180+
}
181+
182+
state[i + 1] = new State
183+
{
184+
Prev = i,
185+
End = i + 2,
186+
NextEnd = i + 3,
187+
NextRank = int.MaxValue,
188+
CurRank = int.MaxValue
189+
};
190+
}
191+
192+
// Local function to add a potential merge to the heap.
193+
void PotentialMerge(Span<State> stateSpan, PriorityQueue<MergeEntry> heapQueue, int start, int nextEndItem)
194+
{
195+
stateSpan[start].NextEnd = nextEndItem;
196+
stateSpan[start].NextRank = int.MaxValue;
197+
198+
if (nextEndItem <= mergingBytes.Length)
199+
{
200+
var slice = mergingBytes.Slice(start, nextEndItem - start);
201+
if (ranks.TryGetValue(slice, out int rank))
202+
{
203+
heapQueue.Enqueue(new MergeEntry { Start = start, Rank = rank });
204+
stateSpan[start].NextRank = rank;
205+
}
206+
}
207+
}
208+
209+
while (heap.Count > 0)
210+
{
211+
MergeEntry left = heap.Dequeue();
212+
213+
if (left.Rank == int.MaxValue)
214+
{
215+
break;
216+
}
217+
218+
if (left.Rank != state[left.Start].NextRank)
219+
{
220+
continue;
221+
}
222+
223+
int leftStart = left.Start;
224+
int rightStart = state[leftStart].End;
225+
int rightEnd = state[leftStart].NextEnd;
226+
int rightNextEnd = state[rightStart].NextEnd;
227+
228+
state[leftStart].CurRank = state[leftStart].NextRank;
229+
state[leftStart].End = rightEnd;
230+
PotentialMerge(state, heap, leftStart, rightNextEnd);
231+
232+
if (rightEnd < state.Length)
233+
{
234+
state[rightEnd].Prev = leftStart;
235+
}
236+
237+
if (leftStart > 0)
238+
{
239+
int prevStart = state[leftStart].Prev;
240+
PotentialMerge(state, heap, prevStart, rightEnd);
241+
}
242+
243+
state[rightStart].NextRank = int.MaxValue;
244+
}
245+
246+
// Use ArrayPool for the result buffer to avoid List<T> overhead.
247+
// The maximum number of tokens is mergingBytes.Length (no merges).
248+
var resultPoolArray = ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Rent(mergingBytes.Length);
249+
int resultCount = 0;
250+
int currentIndex = 0;
251+
252+
while (currentIndex < state.Length)
253+
{
254+
int startIndex = currentIndex;
255+
int endIndex = state[currentIndex].End;
256+
257+
int mappedStartIndex = indexMappingSpan[startIndex];
258+
int mappedEndIndex = indexMappingSpan[endIndex];
259+
260+
int finalEndIndex = endIndex;
261+
262+
// Handle partial characters/elements at token boundaries.
263+
// If the byte at endIndex-1 maps to the same character as endIndex,
264+
// extend the token to include the complete character.
265+
if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
266+
{
267+
finalEndIndex++;
268+
while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
269+
{
270+
finalEndIndex++;
271+
}
272+
}
273+
274+
int tokenId = state[currentIndex].CurRank != int.MaxValue
275+
? state[currentIndex].CurRank
276+
: ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)];
277+
278+
resultPoolArray[resultCount++] = (tokenId, mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex);
279+
280+
currentIndex = state[currentIndex].End;
281+
}
282+
283+
ArrayPool<State>.Shared.Return(statePoolArray);
284+
285+
var result = resultPoolArray.AsSpan(0, resultCount).ToArray();
286+
ArrayPool<(int Id, int TokenIndex, int TokenLength)>.Shared.Return(resultPoolArray);
287+
return result;
288+
}
289+
119290
private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end) => memory.Slice(start, end - start);
120291
}
121292
}

src/Microsoft.ML.Tokenizers/Utils/PriorityQueue.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ internal class PriorityQueue<T> where T : IComparable<T>
1212
{
1313
private readonly List<T> _data;
1414

15+
public PriorityQueue() : this(0)
16+
{
17+
}
18+
1519
public PriorityQueue(int capacity)
1620
{
1721
_data = new List<T>(capacity);

test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Microsoft.DotNet.RemoteExecutor;
65
using System;
76
using System.Buffers;
87
using System.Collections.Generic;
@@ -13,6 +12,7 @@
1312
using System.Text;
1413
using System.Text.Json;
1514
using System.Threading.Tasks;
15+
using Microsoft.DotNet.RemoteExecutor;
1616
using Xunit;
1717

1818
namespace Microsoft.ML.Tokenizers.Tests
@@ -852,6 +852,76 @@ public void TestOss()
852852

853853
private static IReadOnlyDictionary<string, int>? GetVocabulary(TiktokenTokenizer tiktoken)
854854
=> typeof(TiktokenTokenizer).GetProperty("Vocabulary", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(tiktoken) as IReadOnlyDictionary<string, int>;
855+
856+
[Fact]
857+
public void TestLargeInputOptimization()
858+
{
859+
// Verify that large inputs (>128 bytes) and boundary cases round-trip correctly via the public API.
860+
// This exercises the large-input optimization path but does not directly compare it to the small-input path.
861+
862+
// Test with repeated characters - this is the adversarial case that caused O(n^2) behavior
863+
string largeRepeatedInput = new string('a', 1000);
864+
IReadOnlyList<int> ids = GPT4.EncodeToIds(largeRepeatedInput);
865+
string decoded = GPT4.Decode(ids);
866+
Assert.Equal(largeRepeatedInput, decoded);
867+
868+
// Test with a more realistic large input
869+
string largeMixedInput = string.Join(" ", Enumerable.Repeat("Hello World! This is a test.", 50));
870+
IReadOnlyList<int> mixedIds = GPT4.EncodeToIds(largeMixedInput);
871+
string mixedDecoded = GPT4.Decode(mixedIds);
872+
Assert.Equal(largeMixedInput, mixedDecoded);
873+
874+
// Test boundary case - exactly at threshold (128)
875+
string boundaryInput = new string('x', 128);
876+
IReadOnlyList<int> boundaryIds = GPT4.EncodeToIds(boundaryInput);
877+
string boundaryDecoded = GPT4.Decode(boundaryIds);
878+
Assert.Equal(boundaryInput, boundaryDecoded);
879+
880+
// Test just below threshold (127)
881+
string belowThresholdInput = new string('x', 127);
882+
IReadOnlyList<int> belowIds = GPT4.EncodeToIds(belowThresholdInput);
883+
string belowDecoded = GPT4.Decode(belowIds);
884+
Assert.Equal(belowThresholdInput, belowDecoded);
885+
886+
// Test just above threshold (129)
887+
string aboveThresholdInput = new string('x', 129);
888+
IReadOnlyList<int> aboveIds = GPT4.EncodeToIds(aboveThresholdInput);
889+
string aboveDecoded = GPT4.Decode(aboveIds);
890+
Assert.Equal(aboveThresholdInput, aboveDecoded);
891+
}
892+
893+
[Theory]
894+
[InlineData(200)]
895+
[InlineData(500)]
896+
[InlineData(1000)]
897+
[InlineData(2000)]
898+
public void TestLargeInputConsistency(int length)
899+
{
900+
// Verify that large inputs are handled correctly by the public API and round-trip successfully.
901+
// These tests focus on observable behavior (round-tripping and reconstruction), not on comparing internal code paths.
902+
903+
// Test with repeated character
904+
string inputRepeated = new string('z', length);
905+
IReadOnlyList<int> idsRepeated = GPT4.EncodeToIds(inputRepeated);
906+
907+
// Verify round-trip
908+
string decodedRepeated = GPT4.Decode(idsRepeated);
909+
Assert.Equal(inputRepeated, decodedRepeated);
910+
911+
// Test with mixed content (more realistic scenario)
912+
string inputMixed = string.Join(" ", Enumerable.Repeat("Hello World! Test123", length / 20 + 1)).Substring(0, length);
913+
IReadOnlyList<int> idsMixed = GPT4.EncodeToIds(inputMixed);
914+
string decodedMixed = GPT4.Decode(idsMixed);
915+
Assert.Equal(inputMixed, decodedMixed);
916+
917+
// Verify with EncodingToTokens as well
918+
IReadOnlyList<EncodedToken> tokens = GPT4.EncodeToTokens(inputRepeated, out string? normalizedText);
919+
Assert.Null(normalizedText); // No normalization expected
920+
921+
// Reconstruct from tokens
922+
var reconstructed = string.Concat(tokens.Select(t => t.Value));
923+
Assert.Equal(inputRepeated, reconstructed);
924+
}
855925
}
856926
}
857927

0 commit comments

Comments
 (0)