Skip to content

Commit 99c620a

Browse files
authored
Add Span support in tokenizer's Model abstraction (#7035)
* Add Span support in tokenizer's Model abstraction
* Address the feedback
* Use stackalloc instead of the ArrayPool
1 parent c6f5397 commit 99c620a

File tree

12 files changed

+427
-311
lines changed

12 files changed

+427
-311
lines changed

src/Microsoft.ML.Tokenizers/Model/BPE.cs

+71-51
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Buffers;
67
using System.Collections.Generic;
78
using System.IO;
9+
using System.Linq;
810
using System.Runtime.CompilerServices;
911
using System.Text.Json;
1012
using System.Text.Json.Serialization;
@@ -34,20 +36,21 @@ private set
3436
{
3537
_unknownToken = value;
3638

37-
if (value is null)
39+
if (VocabReverse.TryGetValue(0, out string? v))
3840
{
39-
if (VocabReverse.TryGetValue(0, out string? v))
41+
if (v == value)
4042
{
41-
VocabReverse.Remove(0);
42-
if (_vocab.TryGetValue(v, out int id))
43-
{
44-
_vocab.Remove(v);
45-
}
43+
return;
4644
}
45+
46+
VocabReverse.Remove(0);
47+
_vocab.Remove(new StringSpanOrdinalKey(v));
4748
}
48-
else
49+
50+
51+
if (value is not null)
4952
{
50-
_vocab[value] = 0;
53+
_vocab[new StringSpanOrdinalKey(value)] = 0;
5154
VocabReverse[0] = value;
5255
}
5356
}
@@ -68,7 +71,6 @@ private set
6871
/// </summary>
6972
public bool FuseUnknownTokens { get; }
7073

71-
7274
/// <summary>
7375
/// Construct a new Bpe model object to use for text encoding.
7476
/// </summary>
@@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
111113
ContinuingSubwordPrefix = continuingSubwordPrefix;
112114
EndOfWordSuffix = endOfWordSuffix;
113115

114-
(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
115-
_vocab = vocab1 ?? new Dictionary<string, int>();
116-
Cache = new Cache<string, Word>();
116+
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
117+
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
118+
Cache = new StringSpanOrdinalKeyCache<Word>();
117119

118120
VocabReverse = new();
119121

120-
foreach (KeyValuePair<string, int> kvp in Vocab)
122+
foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
121123
{
122-
VocabReverse.Add(kvp.Value, kvp.Key);
124+
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
123125
}
124126

125-
if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
126-
{
127-
unknownToken = unkToken;
128-
}
129127

130-
UnknownToken = unknownToken;
128+
UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);
131129

132130
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
133131

@@ -197,31 +195,23 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
197195
/// <param name="text">The text to split.</param>
198196
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
199197
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
200-
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
198+
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
201199

202200
/// <summary>
203201
/// Get the number of tokens that the input text will be encoded to.
204202
/// </summary>
205203
/// <param name="text">The text to encode.</param>
206204
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
207205
/// <returns>The number of tokens that the input text will be encoded to.</returns>
208-
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
206+
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
209207

210208
/// <summary>
211209
/// Map the token to encoded Id.
212210
/// </summary>
213211
/// <param name="token">The token to map to the Id.</param>
214212
/// <param name="considerSpecialTokens">Indicate whether to consider the special tokens during the encoding.</param>
215213
/// <returns>The mapped Id of the token.</returns>
216-
public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
217-
{
218-
if (_vocab.TryGetValue(token, out int value))
219-
{
220-
return value;
221-
}
222-
223-
return null;
224-
}
214+
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
225215

226216
/// <summary>
227217
/// Map the encoded Id to the token.
@@ -242,24 +232,27 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
242232
/// <summary>
243233
/// Gets the dictionary mapping tokens to Ids.
244234
/// </summary>
245-
public IReadOnlyDictionary<string, int> Vocab => _vocab;
235+
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
246236

247237
/// Read the given files to extract the vocab and merges
248-
internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
238+
internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
249239
{
250-
Dictionary<string, int>? dic = JsonSerializer.Deserialize<Dictionary<string, int>>(vocab) as Dictionary<string, int>;
240+
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
241+
Dictionary<StringSpanOrdinalKey, int>? dic = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;
251242

252243
return (dic, ConvertMergesToHashmap(merges));
253244
}
254245

255246
/// The vocabulary assigns a number to each token.
256-
private readonly Dictionary<string, int> _vocab;
247+
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
248+
249+
private Dictionary<string, int>? _vocabOriginal;
257250

258251
/// Contains the mapping between Pairs and their (rank, newId).
259252
internal Dictionary<Pair<int>, (int, int)> Merges { get; }
260253

261254
/// Contains the cache for optimizing the encoding step.
262-
internal Cache<string, Word>? Cache { get; }
255+
internal StringSpanOrdinalKeyCache<Word>? Cache { get; }
263256

264257
internal static readonly int DefaultCacheCapacity = 10_000;
265258

@@ -309,9 +302,6 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(
309302
return merges;
310303
}
311304

312-
/// Reset the cache.
313-
internal void ClearCache() => Cache?.Clear();
314-
315305
private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();
316306

317307
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -327,38 +317,68 @@ internal string CharToString(char c)
327317
return s;
328318
}
329319

330-
internal Word MergeWord(string w)
320+
internal Word MergeWord(ReadOnlySpan<char> w)
331321
{
332322
Word word = Word.WithCapacity(w.Length);
333323
(int Id, int Len)? unk = null;
334324
int i = 0;
335325

326+
Span<char> buffer = stackalloc char[256];
327+
scoped ReadOnlySpan<char> s;
328+
336329
while (i < w.Length)
337330
{
338331
int length;
339-
string s;
340332

341333
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
342334
{
343335
length = 2;
344-
s = w.Substring(i, length);
336+
s = w.Slice(i, 2);
345337
}
346338
else
347339
{
348340
length = 1;
349-
s = CharToString(w[i]);
341+
s = w.Slice(i, 1);
350342
}
351343

352344
// Add the `continuing_subword_prefix` if relevant
353345
if (i > 0 && ContinuingSubwordPrefix is not null)
354346
{
355-
s = $"{ContinuingSubwordPrefix}{s}";
347+
if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
348+
{
349+
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
350+
s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
351+
s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
352+
}
353+
else
354+
{
355+
#if NETCOREAPP
356+
s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
357+
#else
358+
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
359+
s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
360+
#endif
361+
}
356362
}
357363

358364
// Add the `end_of_word_suffix` if relevant
359365
if (i + length >= w.Length && EndOfWordSuffix is not null)
360366
{
361-
s = $"{s}{EndOfWordSuffix}";
367+
if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
368+
{
369+
s.CopyTo(buffer);
370+
EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
371+
s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
372+
}
373+
else
374+
{
375+
#if NETCOREAPP
376+
s = $"{s}{EndOfWordSuffix}".AsSpan();
377+
#else
378+
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
379+
s = $"{s1}{EndOfWordSuffix}".AsSpan();
380+
#endif
381+
}
362382
}
363383

364384
if (_vocab.TryGetValue(s, out int id))
@@ -419,17 +439,17 @@ internal List<Token> EncodeWithCache(string text)
419439
Word word;
420440
if (Cache is not null)
421441
{
422-
if (Cache.TryGet(text, out word))
442+
if (Cache.TryGetValue(text, out word))
423443
{
424444
return WordToTokens(ref word);
425445
}
426446

427-
word = MergeWord(text);
447+
word = MergeWord(text.AsSpan());
428448
Cache.Set(text, word);
429449
}
430450
else
431451
{
432-
word = MergeWord(text);
452+
word = MergeWord(text.AsSpan());
433453
}
434454

435455
return WordToTokens(ref word);
@@ -445,19 +465,19 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
445465
return word.SymbolsCount;
446466
}
447467

448-
internal int EncodeToIdsWithCache(string text, IList<int>? accumulatedIds)
468+
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
449469
{
450470
Word word;
451471

452472
if (Cache is not null)
453473
{
454-
if (Cache.TryGet(text, out Word hit))
474+
if (Cache.TryGetValue(text, out Word hit))
455475
{
456476
return WordToIds(ref hit, accumulatedIds);
457477
}
458478

459479
word = MergeWord(text);
460-
Cache.Set(text, word);
480+
Cache.Set(text.ToString(), word);
461481
}
462482
else
463483
{

0 commit comments

Comments
 (0)