@@ -21,6 +21,7 @@ public class TikTokenizer : ITokenizer
21
21
{
22
22
23
23
private IReadOnlyDictionary < string , int > SpecialTokensEncoder = null ! ;
24
+ private IReadOnlyCollection < string > SpecialTokens = null ! ;
24
25
private Regex Regex = null ! ;
25
26
private IReadOnlyDictionary < byte [ ] , int > Encoder = null ! ;
26
27
private IReadOnlyDictionary < int , byte [ ] > Decoder = null ! ;
@@ -76,6 +77,7 @@ private void Init(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<
76
77
Regex = new Regex ( pattern , RegexOptions . Compiled ) ;
77
78
SpecialTokensRegex = new Regex ( string . Join ( "|" , specialTokensEncoder . Keys . Select ( s => Regex . Escape ( s ) ) ) , RegexOptions . Compiled ) ;
78
79
SpecialTokensEncoder = specialTokensEncoder ;
80
+ SpecialTokens = specialTokensEncoder . Keys . ToList ( ) ;
79
81
80
82
Decoder = Encoder . ToDictionary ( kvp => kvp . Value , kvp => kvp . Key ) ;
81
83
@@ -136,13 +138,7 @@ private Dictionary<byte[], int> LoadTikTokenBpe(Stream tikTokenBpeFileStream)
136
138
return bpeDict ;
137
139
}
138
140
139
- /// <summary>
140
- /// Encode a string with a set of allowed special tokens that are not broken apart.
141
- /// </summary>
142
- /// <param name="text">String to be encoded</param>
143
- /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
144
- /// <returns>List of token ids</returns>
145
- public List < int > Encode ( string text , IReadOnlyCollection < string > allowedSpecial )
141
+ private List < int > EncodeInternal ( string text , IReadOnlyCollection < string > allowedSpecial )
146
142
{
147
143
var tokenIds = new List < int > ( ) ;
148
144
int start = 0 ;
@@ -173,6 +169,43 @@ public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
173
169
return tokenIds ;
174
170
}
175
171
172
+ /// <summary>
173
+ /// Encode a string with a set of allowed special tokens that are not broken apart.
174
+ /// </summary>
175
+ /// <param name="text">String to be encoded</param>
176
+ /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
177
+ /// <returns>List of token ids</returns>
178
+ public List < int > Encode ( string text , IReadOnlyCollection < string > allowedSpecial )
179
+ {
180
+ if ( allowedSpecial is null || allowedSpecial . Count == 0 )
181
+ {
182
+ return Encode ( text , false ) ;
183
+ }
184
+ return EncodeInternal ( text , allowedSpecial ) ;
185
+ }
186
+
187
+ /// <summary>
188
+ /// Encode a string with or without special tokens set through constructor.
189
+ /// </summary>
190
+ /// <param name="text">String to be encoded</param>
191
+ /// <param name="applySpecialTokens">Whether to apply special token processing</param>
192
+ /// <returns></returns>
193
+ public List < int > Encode ( string text , bool applySpecialTokens = true )
194
+ {
195
+
196
+ if ( applySpecialTokens && SpecialTokens . Count > 0 )
197
+ {
198
+ return EncodeInternal ( text , SpecialTokens ) ;
199
+ }
200
+
201
+ var tokenIds = new List < int > ( ) ;
202
+ int start = 0 ;
203
+ Encode ( text , tokenIds , start , text . Length ) ;
204
+
205
+ return tokenIds ;
206
+
207
+ }
208
+
176
209
/// <summary>
177
210
/// Encode a special token matched through regex.
178
211
/// </summary>
@@ -194,7 +227,7 @@ private int EncodeSpecialToken(List<int> tokenIds, Match nextSpecial)
194
227
/// <param name="start">Start search index in the string</param>
195
228
/// <param name="nextSpecial">The regex match of a special token</param>
196
229
/// <param name="end">The index of the special token matched or the end of the text</param>
197
- private void FindNextSpecialToken ( string text , IReadOnlyCollection < string > allowedSpecial , int start , out Match nextSpecial , out int end )
230
+ private void FindNextSpecialToken ( string text , IReadOnlyCollection < string > ? allowedSpecial , int start , out Match nextSpecial , out int end )
198
231
{
199
232
int startFind = start ;
200
233
while ( true )
@@ -308,14 +341,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
308
341
return ( tokenCount , encodeLength ) ;
309
342
}
310
343
311
- /// <summary>
312
- /// Encode a piece of text limited by max token count through trimming suffix
313
- /// </summary>
314
- /// <param name="text">Text to be encoded</param>
315
- /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
316
- /// <param name="maxTokenCount">The max token count</param>
317
- /// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
318
- public ( List < int > TokenIds , string Text ) EncodeTrimSuffix ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
344
+ private ( List < int > TokenIds , string Text ) EncodeTrimSuffixInternal ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
319
345
{
320
346
var tokenIds = new List < int > ( ) ;
321
347
@@ -367,21 +393,58 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
367
393
}
368
394
369
395
/// <summary>
370
- /// Encode a piece of text limited by max token count through trimming prefix
396
+ /// Encode a piece of text limited by max token count through trimming suffix
371
397
/// </summary>
372
398
/// <param name="text">Text to be encoded</param>
373
399
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
374
400
/// <param name="maxTokenCount">The max token count</param>
375
- /// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
376
- public ( List < int > TokenIds , string Text ) EncodeTrimPrefix ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
401
+ /// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
402
+ public ( List < int > TokenIds , string Text ) EncodeTrimSuffix ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
403
+ {
404
+ if ( allowedSpecial is null || allowedSpecial . Count == 0 )
405
+ {
406
+ return EncodeTrimSuffix ( text , maxTokenCount , false ) ;
407
+ }
408
+
409
+ return EncodeTrimSuffixInternal ( text , allowedSpecial , maxTokenCount ) ;
410
+
411
+ }
412
+
413
+ /// <summary>
414
+ /// Encode a piece of text limited by max token count through trimming suffix, with or without special tokens set through constructor.
415
+ /// </summary>
416
+ /// <param name="text">String to be encoded</param>
417
+ /// <param name="maxTokenCount">The max token count</param>
418
+ /// <param name="applySpecialTokens">Whether to apply special token processing</param>
419
+ /// <returns></returns>
420
+ public ( List < int > TokenIds , string Text ) EncodeTrimSuffix ( string text , int maxTokenCount , bool applySpecialTokens = true )
421
+ {
422
+ if ( applySpecialTokens && SpecialTokens . Count > 0 )
423
+ {
424
+ return EncodeTrimSuffixInternal ( text , SpecialTokens , maxTokenCount ) ;
425
+ }
426
+
427
+ var tokenIds = new List < int > ( ) ;
428
+ int start = 0 ;
429
+ int tokenCount = 0 ;
430
+ var encodeLength = 0 ;
431
+ ( _ , encodeLength ) = EncodeTrimSuffix ( text , tokenIds , start , text . Length , maxTokenCount , tokenCount , encodeLength ) ;
432
+ var encodedText = encodeLength == text . Length ? text : text [ ..encodeLength ] ;
433
+
434
+ return ( tokenIds , encodedText ) ;
435
+ }
436
+
437
+ private ( List < int > TokenIds , string Text ) EncodeTrimPrefixInternal ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
377
438
{
378
439
var tokenIds = new List < int > ( ) ;
379
440
380
441
int start = 0 ;
381
442
int tokenCount = 0 ;
382
443
var encodeLength = 0 ;
383
- var tokenCountMap = new SortedDictionary < int , int > ( ) ;
384
- tokenCountMap . Add ( tokenCount , encodeLength ) ;
444
+ var tokenCountMap = new SortedDictionary < int , int >
445
+ {
446
+ { tokenCount , encodeLength }
447
+ } ;
385
448
while ( true )
386
449
{
387
450
Match nextSpecial ;
@@ -390,39 +453,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
390
453
391
454
if ( end > start )
392
455
{
393
- foreach ( Match match in Regex . Matches ( text [ start ..end ] ) )
394
- {
395
- var piece = match . Value ;
396
-
397
- if ( this . Cache . Lookup ( match . Value , out int [ ] tokens ) )
398
- {
399
- tokenCount += tokens . Length ;
400
- encodeLength += piece . Length ;
401
- tokenIds . AddRange ( tokens ) ;
402
- tokenCountMap [ tokenCount ] = encodeLength ;
403
- }
404
- else
405
- {
406
- var bytes = Encoding . UTF8 . GetBytes ( piece ) ;
407
- if ( Encoder . TryGetValue ( bytes , out int token ) )
408
- {
409
- tokenCount ++ ;
410
- encodeLength += piece . Length ;
411
- tokenIds . Add ( token ) ;
412
- tokenCountMap [ tokenCount ] = encodeLength ;
413
-
414
- }
415
- else
416
- {
417
- var encodedTokens = BytePairEncoder . BytePairEncode ( bytes , Encoder ) ;
418
- this . Cache . Add ( piece , encodedTokens . ToArray ( ) ) ;
419
- tokenCount += encodedTokens . Count ;
420
- encodeLength += piece . Length ;
421
- tokenIds . AddRange ( encodedTokens ) ;
422
- tokenCountMap [ tokenCount ] = encodeLength ;
423
- }
424
- }
425
- }
456
+ Encode ( text , tokenIds , start , ref tokenCount , ref encodeLength , tokenCountMap , end ) ;
426
457
}
427
458
428
459
if ( nextSpecial . Success )
@@ -442,6 +473,11 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
442
473
}
443
474
}
444
475
476
+ return TrimPrefix ( text , maxTokenCount , tokenIds , tokenCount , tokenCountMap ) ;
477
+ }
478
+
479
+ private static ( List < int > TokenIds , string Text ) TrimPrefix ( string text , int maxTokenCount , List < int > tokenIds , int tokenCount , SortedDictionary < int , int > tokenCountMap )
480
+ {
445
481
if ( tokenCount <= maxTokenCount )
446
482
{
447
483
return ( tokenIds , text ) ;
@@ -463,6 +499,85 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
463
499
return ( tokenIds . Skip ( actualPrefixTokenCount ) . ToList ( ) , text [ actualPrefixStrLength ..] ) ;
464
500
}
465
501
502
+ private void Encode ( string text , List < int > tokenIds , int start , ref int tokenCount , ref int encodeLength , SortedDictionary < int , int > tokenCountMap , int end )
503
+ {
504
+ foreach ( Match match in Regex . Matches ( text [ start ..end ] ) )
505
+ {
506
+ var piece = match . Value ;
507
+
508
+ if ( this . Cache . Lookup ( match . Value , out int [ ] tokens ) )
509
+ {
510
+ tokenCount += tokens . Length ;
511
+ encodeLength += piece . Length ;
512
+ tokenIds . AddRange ( tokens ) ;
513
+ tokenCountMap [ tokenCount ] = encodeLength ;
514
+ }
515
+ else
516
+ {
517
+ var bytes = Encoding . UTF8 . GetBytes ( piece ) ;
518
+ if ( Encoder . TryGetValue ( bytes , out int token ) )
519
+ {
520
+ tokenCount ++ ;
521
+ encodeLength += piece . Length ;
522
+ tokenIds . Add ( token ) ;
523
+ tokenCountMap [ tokenCount ] = encodeLength ;
524
+
525
+ }
526
+ else
527
+ {
528
+ var encodedTokens = BytePairEncoder . BytePairEncode ( bytes , Encoder ) ;
529
+ this . Cache . Add ( piece , encodedTokens . ToArray ( ) ) ;
530
+ tokenCount += encodedTokens . Count ;
531
+ encodeLength += piece . Length ;
532
+ tokenIds . AddRange ( encodedTokens ) ;
533
+ tokenCountMap [ tokenCount ] = encodeLength ;
534
+ }
535
+ }
536
+ }
537
+ }
538
+
539
+ /// <summary>
540
+ /// Encode a piece of text limited by max token count through trimming prefix
541
+ /// </summary>
542
+ /// <param name="text">Text to be encoded</param>
543
+ /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
544
+ /// <param name="maxTokenCount">The max token count</param>
545
+ /// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
546
+ public ( List < int > TokenIds , string Text ) EncodeTrimPrefix ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
547
+ {
548
+ if ( allowedSpecial is null || allowedSpecial . Count == 0 )
549
+ {
550
+ return EncodeTrimPrefix ( text , maxTokenCount , false ) ;
551
+ }
552
+ return EncodeTrimPrefixInternal ( text , allowedSpecial , maxTokenCount ) ;
553
+ }
554
+
555
+ /// <summary>
556
+ /// Encode a piece of text limited by max token count through trimming prefix, with or without special tokens set through constructor.
557
+ /// </summary>
558
+ /// <param name="text">Text to be encoded</param>
559
+ /// <param name="maxTokenCount">The max token count</param>
560
+ /// <param name="applySpecialTokens">Whether to apply special token processing</param>
561
+ /// <returns></returns>
562
+ public ( List < int > TokenIds , string Text ) EncodeTrimPrefix ( string text , int maxTokenCount , bool applySpecialTokens = true )
563
+ {
564
+ if ( applySpecialTokens && SpecialTokens . Count > 0 )
565
+ {
566
+ return EncodeTrimPrefixInternal ( text , SpecialTokens , maxTokenCount ) ;
567
+ }
568
+ var tokenIds = new List < int > ( ) ;
569
+
570
+ int start = 0 ;
571
+ int tokenCount = 0 ;
572
+ var encodeLength = 0 ;
573
+ var tokenCountMap = new SortedDictionary < int , int >
574
+ {
575
+ { tokenCount , encodeLength }
576
+ } ;
577
+ Encode ( text , tokenIds , start , ref tokenCount , ref encodeLength , tokenCountMap , text . Length ) ;
578
+ return TrimPrefix ( text , maxTokenCount , tokenIds , tokenCount , tokenCountMap ) ;
579
+ }
580
+
466
581
/// <summary>
467
582
/// Decode an array of integer token ids
468
583
/// </summary>
0 commit comments