
Commit f976424

First round of perf improvements for tiktoken (#7012)
Parent: 4635a86

13 files changed, +376 -374 lines
New file (+7)

@@ -0,0 +1,7 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#if NET5_0_OR_GREATER
+[module: System.Runtime.CompilerServices.SkipLocalsInit]
+#endif
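The new file applies SkipLocalsInit to the whole module. By default the C# compiler sets the .locals init flag, so the JIT zero-initializes every method's stack locals, including stackalloc buffers, on each call; the attribute drops that flag and saves the memset in hot paths. The #if NET5_0_OR_GREATER guard is needed because the attribute only ships in .NET 5 and later (the netstandard2.0 build compiles without it). A minimal sketch of the kind of code that benefits; this helper is illustrative and not part of the commit:

using System;

static class HexDemo
{
    public static string ToHex(ReadOnlySpan<byte> bytes)
    {
        // Under SkipLocalsInit this buffer starts uninitialized, which is
        // safe here because every element is written before it is read.
        // (The sketch assumes small inputs; real code would bound the stackalloc.)
        Span<char> chars = stackalloc char[bytes.Length * 2];
        for (int i = 0; i < bytes.Length; i++)
        {
            chars[2 * i] = "0123456789ABCDEF"[bytes[i] >> 4];
            chars[2 * i + 1] = "0123456789ABCDEF"[bytes[i] & 0xF];
        }
        return chars.ToString();
    }
}

Skipping zero-initialization is only safe when locals are written before they are read, which is why it is opt-in rather than the default.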

src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj (+1)

@@ -5,6 +5,7 @@
     <TargetFrameworks>netstandard2.0;net8.0</TargetFrameworks>
     <Nullable>enable</Nullable>
     <PackageDescription>Microsoft.ML.Tokenizers contains the implmentation of the tokenization used in the NLP transforms.</PackageDescription>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>

   <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
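The single added property turns on unsafe-code compilation for the project, presumably to support pointer-based fast paths elsewhere in the changeset. A hedged sketch of what AllowUnsafeBlocks permits; this is illustrative, not code from this commit:

static class UnsafeDemo
{
    // Pinning the string gives a raw char* so the loop runs with no
    // per-element bounds checks; fixed keeps the GC from moving the
    // string while the pointer is live.
    public static unsafe int SumChars(string s)
    {
        int sum = 0;
        fixed (char* p = s)
        {
            char* end = p + s.Length;
            for (char* c = p; c < end; c++)
            {
                sum += *c;
            }
        }
        return sum;
    }
}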

src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs (+8 -11)
@@ -8,10 +8,7 @@
 using System.Diagnostics;
 using System.IO;
 using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
 using System.Text.Json;
-using System.Text.Json.Serialization;

 namespace Microsoft.ML.Tokenizers
 {
@@ -27,7 +24,7 @@ public sealed class EnglishRoberta : Model
         private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
         private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
         private readonly string[] _charToString;
-        private readonly Cache<string, IReadOnlyList<Token>> _cache;
+        private readonly Cache<string, List<Token>> _cache;

         /// <summary>
         /// Construct tokenizer object to use with the English Robert model.
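The cache now stores the concrete List<Token> rather than the IReadOnlyList<Token> interface. Iterating through the interface allocates a heap IEnumerator<Token> and makes virtual MoveNext/Current calls, while iterating a List<T> directly uses its struct enumerator: no allocation, and the calls can be devirtualized and inlined. A self-contained comparison of the two shapes (illustrative, with int standing in for Token):

using System.Collections.Generic;

static class EnumerationDemo
{
    // foreach over the interface boxes List<int>'s struct enumerator
    // into a heap-allocated IEnumerator<int>.
    public static int SumViaInterface(IReadOnlyList<int> items)
    {
        int sum = 0;
        foreach (int i in items)
        {
            sum += i;
        }
        return sum;
    }

    // foreach over the concrete type uses List<int>.Enumerator, a struct.
    public static int SumViaList(List<int> items)
    {
        int sum = 0;
        foreach (int i in items)
        {
            sum += i;
        }
        return sum;
    }
}

The same reasoning applies to Count and the indexer, which become direct, inlinable calls on the concrete type.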
@@ -72,7 +69,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
             }

             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, IReadOnlyList<Token>>();
+            _cache = new Cache<string, List<Token>>();
         }

         /// <summary>
@@ -110,7 +107,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
             }

             _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, IReadOnlyList<Token>>();
+            _cache = new Cache<string, List<Token>>();
         }

         //
@@ -226,17 +223,17 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
             {
                 ArrayPool<char>.Shared.Return(token);
                 ArrayPool<int>.Shared.Return(indexMapping);
-                return Bpe.EmptyTokensList;
+                return Array.Empty<Token>();
             }

-            if (_cache.TryGet(sequence, out IReadOnlyList<Token>? hit))
+            if (_cache.TryGet(sequence, out List<Token>? hit))
             {
                 ArrayPool<char>.Shared.Return(token);
                 ArrayPool<int>.Shared.Return(indexMapping);
                 return ModifyTokenListOffsets(hit, indexMapping);
             }

-            IReadOnlyList<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
+            List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
             _cache.Set(sequence, result);
             ArrayPool<char>.Shared.Return(token);
             ArrayPool<int>.Shared.Return(indexMapping);
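Two allocation details in the hunk above: the empty result now comes from Array.Empty<Token>(), which returns the process-wide cached zero-length array instead of going through the Bpe.EmptyTokensList field, and the rented token and indexMapping buffers are handed back to ArrayPool<T>.Shared on every exit path. The same rent/return discipline, sketched with try/finally and illustrative names (the commit itself returns the buffers explicitly on each path):

using System;
using System.Buffers;

static class PoolDemo
{
    public static int[] Process(ReadOnlySpan<char> text)
    {
        // Rent may hand back a buffer larger than requested, and it can
        // contain stale data from a previous user.
        char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
        try
        {
            text.CopyTo(buffer);
            // ... real work would read buffer[0 .. text.Length) here ...
            return Array.Empty<int>(); // cached singleton, no allocation
        }
        finally
        {
            // Return on every path, or the pool's buffers drain away.
            ArrayPool<char>.Shared.Return(buffer);
        }
    }
}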
@@ -261,7 +258,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok

         private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
         {
-            if (_cache.TryGet(sequence, out IReadOnlyList<Token>? hit))
+            if (_cache.TryGet(sequence, out List<Token>? hit))
             {
                 if (accumulatedIds is not null)
                 {
@@ -299,7 +296,7 @@ private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
                 return 0;
             }

-            IReadOnlyList<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
+            List<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
             _cache.Set(sequence, result);
             return result.Count;
         }
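Both TokenizeToIds hunks keep the same cache-first shape: look the sequence up, and only on a miss run the BPE encode and publish the result. The library's Cache<TKey, TValue> type is internal, so this sketch substitutes ConcurrentDictionary for it (illustrative, not the commit's code):

using System.Collections.Concurrent;
using System.Collections.Generic;

class CacheFirstDemo
{
    private readonly ConcurrentDictionary<string, List<int>> _cache = new();

    public int CountTokens(string sequence)
    {
        // Hot path: reuse a previous encode of the same sequence.
        if (_cache.TryGetValue(sequence, out List<int>? hit))
        {
            return hit.Count;
        }

        // Miss: encode once, then store the list for later calls.
        List<int> result = Encode(sequence);
        _cache.TryAdd(sequence, result);
        return result.Count;
    }

    private static List<int> Encode(string sequence)
    {
        // Stand-in for the real BPE encoding step.
        var ids = new List<int>(sequence.Length);
        foreach (char c in sequence)
        {
            ids.Add(c);
        }
        return ids;
    }
}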
