
Commit 2c9ba5d

Adding new APIs to avoid passing in allowed special tokens (#27)
* Adding new APIs to avoid passing in allowed special tokens
* Update version to 1.3.3
1 parent 512d432 commit 2c9ba5d

File tree

4 files changed: +206 -56 lines


Tokenizer_C#/TokenizerLib/ITokenizer.cs (+16)
@@ -22,6 +22,22 @@ public interface ITokenizer
     /// </summary>
     public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount);
 
+    /// <summary>
+    /// Encode a string with or without special tokens set through constructor.
+    /// </summary>
+    public List<int> Encode(string text, bool applySpecialTokens = true);
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming suffix, with or without special tokens set through constructor.
+    /// </summary>
+    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, int maxTokenCount, bool applySpecialTokens = true);
+
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming prefix, with or without special tokens set through constructor.
+    /// </summary>
+    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, int maxTokenCount, bool applySpecialTokens = true);
+
 
     /// <summary>
     /// Decode an array of integer token ids
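
Taken together, these overloads let callers rely on the special tokens registered at construction time instead of passing an allowedSpecial set on every call. A minimal usage sketch, assuming tokenizer is an ITokenizer instance configured with the <|im_start|> and <|im_end|> special tokens (the variable name is illustrative, not part of this commit):

    // Special tokens from the constructor are applied by default.
    List<int> ids = tokenizer.Encode("<|im_start|>Hello World<|im_end|>");

    // Opt out of special-token handling for this call only.
    List<int> plain = tokenizer.Encode("<|im_start|>Hello World<|im_end|>", applySpecialTokens: false);

    // The trimming overloads follow the same pattern.
    var (tokenIds, text) = tokenizer.EncodeTrimSuffix("<|im_start|>Hello World<|im_end|>", maxTokenCount: 4);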

Tokenizer_C#/TokenizerLib/TikTokenizer.cs (+169 -54)
@@ -21,6 +21,7 @@ public class TikTokenizer : ITokenizer
 {
 
     private IReadOnlyDictionary<string, int> SpecialTokensEncoder = null!;
+    private IReadOnlyCollection<string> SpecialTokens = null!;
     private Regex Regex = null!;
     private IReadOnlyDictionary<byte[], int> Encoder = null!;
     private IReadOnlyDictionary<int, byte[]> Decoder = null!;
@@ -76,6 +77,7 @@ private void Init(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<
         Regex = new Regex(pattern, RegexOptions.Compiled);
         SpecialTokensRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
         SpecialTokensEncoder = specialTokensEncoder;
+        SpecialTokens = specialTokensEncoder.Keys.ToList();
 
         Decoder = Encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
 
@@ -136,13 +138,7 @@ private Dictionary<byte[], int> LoadTikTokenBpe(Stream tikTokenBpeFileStream)
         return bpeDict;
     }
 
-    /// <summary>
-    /// Encode a string with a set of allowed special tokens that are not broken apart.
-    /// </summary>
-    /// <param name="text">String to be encoded</param>
-    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
-    /// <returns>List of token ids</returns>
-    public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
+    private List<int> EncodeInternal(string text, IReadOnlyCollection<string> allowedSpecial)
     {
         var tokenIds = new List<int>();
         int start = 0;
@@ -173,6 +169,43 @@ public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
         return tokenIds;
     }
 
+    /// <summary>
+    /// Encode a string with a set of allowed special tokens that are not broken apart.
+    /// </summary>
+    /// <param name="text">String to be encoded</param>
+    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
+    /// <returns>List of token ids</returns>
+    public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
+    {
+        if (allowedSpecial is null || allowedSpecial.Count == 0)
+        {
+            return Encode(text, false);
+        }
+        return EncodeInternal(text, allowedSpecial);
+    }
+
+    /// <summary>
+    /// Encode a string with or without special tokens set through constructor.
+    /// </summary>
+    /// <param name="text">String to be encoded</param>
+    /// <param name="applySpecialTokens">Whether to apply special token processing</param>
+    /// <returns></returns>
+    public List<int> Encode(string text, bool applySpecialTokens = true)
+    {
+
+        if (applySpecialTokens && SpecialTokens.Count > 0)
+        {
+            return EncodeInternal(text, SpecialTokens);
+        }
+
+        var tokenIds = new List<int>();
+        int start = 0;
+        Encode(text, tokenIds, start, text.Length);
+
+        return tokenIds;
+
+    }
+
     /// <summary>
     /// Encode a special token matched through regex.
     /// </summary>
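
The bool overload makes the common case terse: with applySpecialTokens left at true, the tokens captured from the constructor (the new SpecialTokens field) are honored; passing false skips the special-token scan and runs plain byte-pair encoding. A short sketch of the difference, using the string and ids asserted in TestEncode1 below (tokenizer is an assumed TikTokenizer instance):

    var text = "<|im_start|>Hello World<|im_end|>";

    // Special tokens collapse to single ids: 4 tokens total, the first
    // two being 100264 and 9906 per TestEncode1.
    List<int> withSpecial = tokenizer.Encode(text);

    // Without special-token handling, "<|im_start|>" is split into
    // ordinary BPE pieces, so the sequence is longer and different.
    List<int> withoutSpecial = tokenizer.Encode(text, false);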
@@ -194,7 +227,7 @@ private int EncodeSpecialToken(List<int> tokenIds, Match nextSpecial)
     /// <param name="start">Start search index in the string</param>
     /// <param name="nextSpecial">The regex match of a special token</param>
     /// <param name="end">The index of the special token matched or the end of the text</param>
-    private void FindNextSpecialToken(string text, IReadOnlyCollection<string> allowedSpecial, int start, out Match nextSpecial, out int end)
+    private void FindNextSpecialToken(string text, IReadOnlyCollection<string>? allowedSpecial, int start, out Match nextSpecial, out int end)
     {
         int startFind = start;
         while (true)
@@ -308,14 +341,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
         return (tokenCount, encodeLength);
     }
 
-    /// <summary>
-    /// Encode a piece of text limited by max token count through trimming suffix
-    /// </summary>
-    /// <param name="text">Text to be encoded</param>
-    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
-    /// <param name="maxTokenCount">The max token count</param>
-    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
-    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    private (List<int> TokenIds, string Text) EncodeTrimSuffixInternal(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
     {
         var tokenIds = new List<int>();
 
@@ -367,21 +393,58 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
     }
 
     /// <summary>
-    /// Encode a piece of text limited by max token count through trimming prefix
+    /// Encode a piece of text limited by max token count through trimming suffix
     /// </summary>
     /// <param name="text">Text to be encoded</param>
     /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
     /// <param name="maxTokenCount">The max token count</param>
-    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
-    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
+    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    {
+        if (allowedSpecial is null || allowedSpecial.Count == 0)
+        {
+            return EncodeTrimSuffix(text, maxTokenCount, false);
+        }
+
+        return EncodeTrimSuffixInternal(text, allowedSpecial, maxTokenCount);
+
+    }
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming suffix, with or without special tokens set through constructor.
+    /// </summary>
+    /// <param name="text">String to be encoded</param>
+    /// <param name="maxTokenCount">The max token count</param>
+    /// <param name="applySpecialTokens">Whether to apply special token processing</param>
+    /// <returns></returns>
+    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, int maxTokenCount, bool applySpecialTokens = true)
+    {
+        if (applySpecialTokens && SpecialTokens.Count > 0)
+        {
+            return EncodeTrimSuffixInternal(text, SpecialTokens, maxTokenCount);
+        }
+
+        var tokenIds = new List<int>();
+        int start = 0;
+        int tokenCount = 0;
+        var encodeLength = 0;
+        (_, encodeLength) = EncodeTrimSuffix(text, tokenIds, start, text.Length, maxTokenCount, tokenCount, encodeLength);
+        var encodedText = encodeLength == text.Length ? text : text[..encodeLength];
+
+        return (tokenIds, encodedText);
+    }
+
+    private (List<int> TokenIds, string Text) EncodeTrimPrefixInternal(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
     {
         var tokenIds = new List<int>();
 
         int start = 0;
         int tokenCount = 0;
         var encodeLength = 0;
-        var tokenCountMap = new SortedDictionary<int, int>();
-        tokenCountMap.Add(tokenCount, encodeLength);
+        var tokenCountMap = new SortedDictionary<int, int>
+        {
+            { tokenCount, encodeLength }
+        };
         while (true)
         {
             Match nextSpecial;
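
EncodeTrimSuffix now has the same two entry points. A sketch of the observable difference, using the values asserted in the updated TestEncodeTrimSuffix (tokenizer again an assumed instance with the im_start/im_end tokens registered):

    var text = "<|im_start|>Hello World<|im_end|>";

    // Default: special tokens apply, so the whole string fits in 4 tokens.
    var trimmed = tokenizer.EncodeTrimSuffix(text, 4);
    // trimmed.Text == text

    // Special tokens off: 4 plain BPE tokens only cover "<|im_start".
    var plain = tokenizer.EncodeTrimSuffix(text, 4, applySpecialTokens: false);
    // plain.Text == "<|im_start"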
@@ -390,39 +453,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
 
             if (end > start)
             {
-                foreach (Match match in Regex.Matches(text[start..end]))
-                {
-                    var piece = match.Value;
-
-                    if (this.Cache.Lookup(match.Value, out int[] tokens))
-                    {
-                        tokenCount += tokens.Length;
-                        encodeLength += piece.Length;
-                        tokenIds.AddRange(tokens);
-                        tokenCountMap[tokenCount] = encodeLength;
-                    }
-                    else
-                    {
-                        var bytes = Encoding.UTF8.GetBytes(piece);
-                        if (Encoder.TryGetValue(bytes, out int token))
-                        {
-                            tokenCount++;
-                            encodeLength += piece.Length;
-                            tokenIds.Add(token);
-                            tokenCountMap[tokenCount] = encodeLength;
-
-                        }
-                        else
-                        {
-                            var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
-                            this.Cache.Add(piece, encodedTokens.ToArray());
-                            tokenCount += encodedTokens.Count;
-                            encodeLength += piece.Length;
-                            tokenIds.AddRange(encodedTokens);
-                            tokenCountMap[tokenCount] = encodeLength;
-                        }
-                    }
-                }
+                Encode(text, tokenIds, start, ref tokenCount, ref encodeLength, tokenCountMap, end);
             }
 
             if (nextSpecial.Success)
@@ -442,6 +473,11 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
             }
         }
 
+        return TrimPrefix(text, maxTokenCount, tokenIds, tokenCount, tokenCountMap);
+    }
+
+    private static (List<int> TokenIds, string Text) TrimPrefix(string text, int maxTokenCount, List<int> tokenIds, int tokenCount, SortedDictionary<int, int> tokenCountMap)
+    {
         if (tokenCount <= maxTokenCount)
         {
             return (tokenIds, text);
@@ -463,6 +499,85 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
         return (tokenIds.Skip(actualPrefixTokenCount).ToList(), text[actualPrefixStrLength..]);
     }
 
+    private void Encode(string text, List<int> tokenIds, int start, ref int tokenCount, ref int encodeLength, SortedDictionary<int, int> tokenCountMap, int end)
+    {
+        foreach (Match match in Regex.Matches(text[start..end]))
+        {
+            var piece = match.Value;
+
+            if (this.Cache.Lookup(match.Value, out int[] tokens))
+            {
+                tokenCount += tokens.Length;
+                encodeLength += piece.Length;
+                tokenIds.AddRange(tokens);
+                tokenCountMap[tokenCount] = encodeLength;
+            }
+            else
+            {
+                var bytes = Encoding.UTF8.GetBytes(piece);
+                if (Encoder.TryGetValue(bytes, out int token))
+                {
+                    tokenCount++;
+                    encodeLength += piece.Length;
+                    tokenIds.Add(token);
+                    tokenCountMap[tokenCount] = encodeLength;
+
+                }
+                else
+                {
+                    var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
+                    this.Cache.Add(piece, encodedTokens.ToArray());
+                    tokenCount += encodedTokens.Count;
+                    encodeLength += piece.Length;
+                    tokenIds.AddRange(encodedTokens);
+                    tokenCountMap[tokenCount] = encodeLength;
+                }
+            }
+        }
+    }
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming prefix
+    /// </summary>
+    /// <param name="text">Text to be encoded</param>
+    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
+    /// <param name="maxTokenCount">The max token count</param>
+    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
+    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    {
+        if (allowedSpecial is null || allowedSpecial.Count == 0)
+        {
+            return EncodeTrimPrefix(text, maxTokenCount, false);
+        }
+        return EncodeTrimPrefixInternal(text, allowedSpecial, maxTokenCount);
+    }
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming prefix, with or without special tokens set through constructor.
+    /// </summary>
+    /// <param name="text">Text to be encoded</param>
+    /// <param name="maxTokenCount">The max token count</param>
+    /// <param name="applySpecialTokens">Whether to apply special token processing</param>
+    /// <returns></returns>
+    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, int maxTokenCount, bool applySpecialTokens = true)
+    {
+        if (applySpecialTokens && SpecialTokens.Count > 0)
+        {
+            return EncodeTrimPrefixInternal(text, SpecialTokens, maxTokenCount);
+        }
+        var tokenIds = new List<int>();
+
+        int start = 0;
+        int tokenCount = 0;
+        var encodeLength = 0;
+        var tokenCountMap = new SortedDictionary<int, int>
+        {
+            { tokenCount, encodeLength }
+        };
+        Encode(text, tokenIds, start, ref tokenCount, ref encodeLength, tokenCountMap, text.Length);
+        return TrimPrefix(text, maxTokenCount, tokenIds, tokenCount, tokenCountMap);
+    }
+
     /// <summary>
     /// Decode an array of integer token ids
     /// </summary>
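
EncodeTrimPrefix gains the matching pair of overloads, and the foreach body shared by both prefix paths is extracted into the private Encode helper above instead of being duplicated. Mirroring the suffix sketch with the values from TestEncodeTrimPrefix:

    var text = "<|im_start|>Hello World<|im_end|>";

    // Default: everything fits, nothing is trimmed from the front.
    var kept = tokenizer.EncodeTrimPrefix(text, 4);
    // kept.Text == text

    // Special tokens off: keeping the last 4 BPE tokens leaves "im_end|>".
    var tail = tokenizer.EncodeTrimPrefix(text, 4, applySpecialTokens: false);
    // tail.Text == "im_end|>"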

Tokenizer_C#/TokenizerLib/TokenizerLib.csproj (+1 -1)
@@ -8,7 +8,7 @@
     <Title>Tokenizer</Title>
     <Description>Tokenizer for OpenAI large language models.</Description>
     <LangVersion>8.0</LangVersion>
-    <AssemblyVersion>1.3.2</AssemblyVersion>
+    <AssemblyVersion>1.3.3</AssemblyVersion>
     <FileVersion>$(AssemblyVersion)</FileVersion>
     <Version>$(AssemblyVersion)</Version>
     <Authors>Microsoft</Authors>

Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs (+20 -1)
@@ -53,7 +53,7 @@ public void TestEncode0()
     public void TestEncode1()
     {
        var text = "<|im_start|>Hello World<|im_end|>";
-       var encoded = Tokenizer.Encode(text, new HashSet<string>(SpecialTokens.Keys));
+       var encoded = Tokenizer.Encode(text);
        Assert.AreEqual(4, encoded.Count);
        Assert.AreEqual(100264, encoded[0]);
        Assert.AreEqual(9906, encoded[1]);
@@ -70,6 +70,9 @@ public void TestEncode2()
        var encoded = Tokenizer.Encode(text, new HashSet<string>(SpecialTokens.Keys));
        Assert.AreEqual(5584, encoded.Count);
 
+       encoded = Tokenizer.Encode(text, false);
+       Assert.AreEqual(5584, encoded.Count);
+
        string json = File.ReadAllText("./testData/tokens.json");
        var expected = JsonConvert.DeserializeObject<int[]>(json);
 
@@ -131,6 +134,14 @@ public void TestEncodeTrimSuffix()
        Assert.AreEqual(4, encoded.TokenIds.Count);
        Assert.AreEqual(text, encoded.Text);
 
+       encoded = Tokenizer.EncodeTrimSuffix(text, 4, false);
+       Assert.AreEqual(4, encoded.TokenIds.Count);
+       Assert.AreEqual("<|im_start", encoded.Text);
+
+       encoded = Tokenizer.EncodeTrimSuffix(text, 4);
+       Assert.AreEqual(4, encoded.TokenIds.Count);
+       Assert.AreEqual(text, encoded.Text);
+
        encoded = Tokenizer.EncodeTrimSuffix(text, new HashSet<string>(SpecialTokens.Keys), 5);
        Assert.AreEqual(4, encoded.TokenIds.Count);
        Assert.AreEqual(text, encoded.Text);
@@ -173,6 +184,14 @@ public void TestEncodeTrimPrefix()
        Assert.AreEqual(4, encoded.TokenIds.Count);
        Assert.AreEqual(text, encoded.Text);
 
+       encoded = Tokenizer.EncodeTrimPrefix(text, 4, false);
+       Assert.AreEqual(4, encoded.TokenIds.Count);
+       Assert.AreEqual("im_end|>", encoded.Text);
+
+       encoded = Tokenizer.EncodeTrimPrefix(text, 4);
+       Assert.AreEqual(4, encoded.TokenIds.Count);
+       Assert.AreEqual(text, encoded.Text);
+
        encoded = Tokenizer.EncodeTrimPrefix(text, new HashSet<string>(SpecialTokens.Keys), 5);
        Assert.AreEqual(4, encoded.TokenIds.Count);
        Assert.AreEqual(text, encoded.Text);
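
One detail the wrappers imply but these tests do not exercise directly: the older allowedSpecial overloads now route a null or empty set to the applySpecialTokens: false path (hence the nullable parameter added to FindNextSpecialToken). A hypothetical illustration:

    // Both calls fall through to the no-special-tokens path.
    var a = Tokenizer.Encode(text, new HashSet<string>());   // same as Tokenizer.Encode(text, false)
    var b = Tokenizer.EncodeTrimPrefix(text, null, 4);       // same as Tokenizer.EncodeTrimPrefix(text, 4, false)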
