3
3
// See the LICENSE file in the project root for more information.
4
4
5
5
using System ;
6
+ using System . Buffers ;
6
7
using System . Collections . Generic ;
7
8
using System . IO ;
9
+ using System . Linq ;
8
10
using System . Runtime . CompilerServices ;
9
11
using System . Text . Json ;
10
12
using System . Text . Json . Serialization ;
@@ -34,20 +36,21 @@ private set
34
36
{
35
37
_unknownToken = value ;
36
38
37
- if ( value is null )
39
+ if ( VocabReverse . TryGetValue ( 0 , out string ? v ) )
38
40
{
39
- if ( VocabReverse . TryGetValue ( 0 , out string ? v ) )
41
+ if ( v == value )
40
42
{
41
- VocabReverse . Remove ( 0 ) ;
42
- if ( _vocab . TryGetValue ( v , out int id ) )
43
- {
44
- _vocab . Remove ( v ) ;
45
- }
43
+ return ;
46
44
}
45
+
46
+ VocabReverse . Remove ( 0 ) ;
47
+ _vocab . Remove ( new StringSpanOrdinalKey ( v ) ) ;
47
48
}
48
- else
49
+
50
+
51
+ if ( value is not null )
49
52
{
50
- _vocab [ value ] = 0 ;
53
+ _vocab [ new StringSpanOrdinalKey ( value ) ] = 0 ;
51
54
VocabReverse [ 0 ] = value ;
52
55
}
53
56
}
@@ -68,7 +71,6 @@ private set
68
71
/// </summary>
69
72
public bool FuseUnknownTokens { get ; }
70
73
71
-
72
74
/// <summary>
73
75
/// Construct a new Bpe model object to use for text encoding.
74
76
/// </summary>
@@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
111
113
ContinuingSubwordPrefix = continuingSubwordPrefix ;
112
114
EndOfWordSuffix = endOfWordSuffix ;
113
115
114
- ( Dictionary < string , int > ? vocab1 , Vec < ( string , string ) > merges ) = ReadModelData ( vocabStream , mergesStream ) ;
115
- _vocab = vocab1 ?? new Dictionary < string , int > ( ) ;
116
- Cache = new Cache < string , Word > ( ) ;
116
+ ( Dictionary < StringSpanOrdinalKey , int > ? vocab1 , Vec < ( string , string ) > merges ) = ReadModelData ( vocabStream , mergesStream ) ;
117
+ _vocab = vocab1 ?? new Dictionary < StringSpanOrdinalKey , int > ( ) ;
118
+ Cache = new StringSpanOrdinalKeyCache < Word > ( ) ;
117
119
118
120
VocabReverse = new ( ) ;
119
121
120
- foreach ( KeyValuePair < string , int > kvp in Vocab )
122
+ foreach ( KeyValuePair < StringSpanOrdinalKey , int > kvp in _vocab )
121
123
{
122
- VocabReverse . Add ( kvp . Value , kvp . Key ) ;
124
+ VocabReverse . Add ( kvp . Value , kvp . Key . Data ! ) ;
123
125
}
124
126
125
- if ( unknownToken is null && VocabReverse . TryGetValue ( 0 , out string ? unkToken ) )
126
- {
127
- unknownToken = unkToken ;
128
- }
129
127
130
- UnknownToken = unknownToken ;
128
+ UnknownToken = unknownToken ?? ( VocabReverse . TryGetValue ( 0 , out string ? unkToken ) ? unkToken : null ) ;
131
129
132
130
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix . Length ;
133
131
@@ -197,31 +195,23 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
197
195
/// <param name="text">The text to split.</param>
198
196
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
199
197
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
200
- public override void EncodeToIds ( string text , bool isSpecialToken , IList < int > accumulatedIds ) => EncodeToIdsWithCache ( text , accumulatedIds ) ;
198
+ public override void EncodeToIds ( ReadOnlySpan < char > text , bool isSpecialToken , IList < int > accumulatedIds ) => EncodeToIdsWithCache ( text , accumulatedIds ) ;
201
199
202
200
/// <summary>
203
201
/// Get the number of tokens that the input text will be encoded to.
204
202
/// </summary>
205
203
/// <param name="text">The text to encode.</param>
206
204
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
207
205
/// <returns>The number of tokens that the input text will be encoded to.</returns>
208
- public override int CountTokens ( string text , bool isSpecialToken ) => EncodeToIdsWithCache ( text , null ) ;
206
+ public override int CountTokens ( ReadOnlySpan < char > text , bool isSpecialToken ) => EncodeToIdsWithCache ( text , null ) ;
209
207
210
208
/// <summary>
211
209
/// Map the token to encoded Id.
212
210
/// </summary>
213
211
/// <param name="token">The token to map to the Id.</param>
214
212
/// <param name="considerSpecialTokens">Indicate whether to consider the special tokens during the encoding.</param>
215
213
/// <returns>The mapped Id of the token.</returns>
216
- public override int ? MapTokenToId ( string token , bool considerSpecialTokens = true )
217
- {
218
- if ( _vocab . TryGetValue ( token , out int value ) )
219
- {
220
- return value ;
221
- }
222
-
223
- return null ;
224
- }
214
+ public override int ? MapTokenToId ( ReadOnlySpan < char > token , bool considerSpecialTokens = true ) => _vocab . TryGetValue ( token , out int value ) ? value : null ;
225
215
226
216
/// <summary>
227
217
/// Map the encoded Id to the token.
@@ -242,24 +232,27 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
242
232
/// <summary>
243
233
/// Gets the dictionary mapping tokens to Ids.
244
234
/// </summary>
245
- public IReadOnlyDictionary < string , int > Vocab => _vocab ;
235
+ public IReadOnlyDictionary < string , int > Vocab => _vocabOriginal ??= _vocab . ToDictionary ( kvp => kvp . Key . Data ! , kvp => kvp . Value ) ;
246
236
247
237
/// Read the given streams to extract the vocab and merges
248
- internal static ( Dictionary < string , int > ? , Vec < ( string , string ) > ) ReadModelData ( Stream vocab , Stream ? merges )
238
+ internal static ( Dictionary < StringSpanOrdinalKey , int > ? , Vec < ( string , string ) > ) ReadModelData ( Stream vocab , Stream ? merges )
249
239
{
250
- Dictionary < string , int > ? dic = JsonSerializer . Deserialize < Dictionary < string , int > > ( vocab ) as Dictionary < string , int > ;
240
+ JsonSerializerOptions options = new ( ) { Converters = { StringSpanOrdinalKeyConverter . Instance } } ;
241
+ Dictionary < StringSpanOrdinalKey , int > ? dic = JsonSerializer . Deserialize < Dictionary < StringSpanOrdinalKey , int > > ( vocab , options ) as Dictionary < StringSpanOrdinalKey , int > ;
251
242
252
243
return ( dic , ConvertMergesToHashmap ( merges ) ) ;
253
244
}
254
245
255
246
/// The vocabulary assigns a number to each token.
256
- private readonly Dictionary < string , int > _vocab ;
247
+ private readonly Dictionary < StringSpanOrdinalKey , int > _vocab ;
248
+
249
+ private Dictionary < string , int > ? _vocabOriginal ;
257
250
258
251
/// Contains the mapping between Pairs and their (rank, newId).
259
252
internal Dictionary < Pair < int > , ( int , int ) > Merges { get ; }
260
253
261
254
/// Contains the cache for optimizing the encoding step.
262
- internal Cache < string , Word > ? Cache { get ; }
255
+ internal StringSpanOrdinalKeyCache < Word > ? Cache { get ; }
263
256
264
257
internal static readonly int DefaultCacheCapacity = 10_000 ;
265
258
@@ -309,9 +302,6 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(
309
302
return merges ;
310
303
}
311
304
312
- /// Reset the cache.
313
- internal void ClearCache ( ) => Cache ? . Clear ( ) ;
314
-
315
305
private readonly Dictionary < char , string > _charToString = new Dictionary < char , string > ( ) ;
316
306
317
307
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
@@ -327,38 +317,68 @@ internal string CharToString(char c)
327
317
return s ;
328
318
}
329
319
330
- internal Word MergeWord ( string w )
320
+ internal Word MergeWord ( ReadOnlySpan < char > w )
331
321
{
332
322
Word word = Word . WithCapacity ( w . Length ) ;
333
323
( int Id , int Len ) ? unk = null ;
334
324
int i = 0 ;
335
325
326
+ Span < char > buffer = stackalloc char [ 256 ] ;
327
+ scoped ReadOnlySpan < char > s ;
328
+
336
329
while ( i < w . Length )
337
330
{
338
331
int length ;
339
- string s ;
340
332
341
333
if ( Char . IsHighSurrogate ( w [ i ] ) && i < w . Length - 1 && Char . IsLowSurrogate ( w [ i + 1 ] ) )
342
334
{
343
335
length = 2 ;
344
- s = w . Substring ( i , length ) ;
336
+ s = w . Slice ( i , 2 ) ;
345
337
}
346
338
else
347
339
{
348
340
length = 1 ;
349
- s = CharToString ( w [ i ] ) ;
341
+ s = w . Slice ( i , 1 ) ;
350
342
}
351
343
352
344
// Add the `continuing_subword_prefix` if relevant
353
345
if ( i > 0 && ContinuingSubwordPrefix is not null )
354
346
{
355
- s = $ "{ ContinuingSubwordPrefix } { s } ";
347
+ if ( ContinuingSubwordPrefix . Length + s . Length <= buffer . Length )
348
+ {
349
+ ContinuingSubwordPrefix . AsSpan ( ) . CopyTo ( buffer ) ;
350
+ s . CopyTo ( buffer . Slice ( ContinuingSubwordPrefix . Length ) ) ;
351
+ s = buffer . Slice ( 0 , ContinuingSubwordPrefix . Length + s . Length ) ;
352
+ }
353
+ else
354
+ {
355
+ #if NETCOREAPP
356
+ s = $ "{ ContinuingSubwordPrefix } { s } ". AsSpan ( ) ;
357
+ #else
358
+ string s1 = s . Length == 1 ? CharToString ( s [ 0 ] ) : s . ToString ( ) ;
359
+ s = $ "{ ContinuingSubwordPrefix } { s1 } ". AsSpan ( ) ;
360
+ #endif
361
+ }
356
362
}
357
363
358
364
// Add the `end_of_word_suffix` if relevant
359
365
if ( i + length >= w . Length && EndOfWordSuffix is not null )
360
366
{
361
- s = $ "{ s } { EndOfWordSuffix } ";
367
+ if ( s . Length + EndOfWordSuffix . Length <= buffer . Length )
368
+ {
369
+ s . CopyTo ( buffer ) ;
370
+ EndOfWordSuffix . AsSpan ( ) . CopyTo ( buffer . Slice ( s . Length ) ) ;
371
+ s = buffer . Slice ( 0 , s . Length + EndOfWordSuffix . Length ) ;
372
+ }
373
+ else
374
+ {
375
+ #if NETCOREAPP
376
+ s = $ "{ s } { EndOfWordSuffix } ". AsSpan ( ) ;
377
+ #else
378
+ string s1 = s . Length == 1 ? CharToString ( s [ 0 ] ) : s . ToString ( ) ;
379
+ s = $ "{ s1 } { EndOfWordSuffix } ". AsSpan ( ) ;
380
+ #endif
381
+ }
362
382
}
363
383
364
384
if ( _vocab . TryGetValue ( s , out int id ) )
@@ -419,17 +439,17 @@ internal List<Token> EncodeWithCache(string text)
419
439
Word word ;
420
440
if ( Cache is not null )
421
441
{
422
- if ( Cache . TryGet ( text , out word ) )
442
+ if ( Cache . TryGetValue ( text , out word ) )
423
443
{
424
444
return WordToTokens ( ref word ) ;
425
445
}
426
446
427
- word = MergeWord ( text ) ;
447
+ word = MergeWord ( text . AsSpan ( ) ) ;
428
448
Cache . Set ( text , word ) ;
429
449
}
430
450
else
431
451
{
432
- word = MergeWord ( text ) ;
452
+ word = MergeWord ( text . AsSpan ( ) ) ;
433
453
}
434
454
435
455
return WordToTokens ( ref word ) ;
@@ -445,19 +465,19 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
445
465
return word . SymbolsCount ;
446
466
}
447
467
448
- internal int EncodeToIdsWithCache ( string text , IList < int > ? accumulatedIds )
468
+ internal int EncodeToIdsWithCache ( ReadOnlySpan < char > text , IList < int > ? accumulatedIds )
449
469
{
450
470
Word word ;
451
471
452
472
if ( Cache is not null )
453
473
{
454
- if ( Cache . TryGet ( text , out Word hit ) )
474
+ if ( Cache . TryGetValue ( text , out Word hit ) )
455
475
{
456
476
return WordToIds ( ref hit , accumulatedIds ) ;
457
477
}
458
478
459
479
word = MergeWord ( text ) ;
460
- Cache . Set ( text , word ) ;
480
+ Cache . Set ( text . ToString ( ) , word ) ;
461
481
}
462
482
else
463
483
{
0 commit comments