|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license. |
3 | 3 | // See the LICENSE file in the project root for more information. |
4 | 4 |
|
5 | | -using Microsoft.DotNet.RemoteExecutor; |
6 | 5 | using System; |
7 | 6 | using System.Buffers; |
8 | 7 | using System.Collections.Generic; |
|
13 | 12 | using System.Text; |
14 | 13 | using System.Text.Json; |
15 | 14 | using System.Threading.Tasks; |
| 15 | +using Microsoft.DotNet.RemoteExecutor; |
16 | 16 | using Xunit; |
17 | 17 |
|
18 | 18 | namespace Microsoft.ML.Tokenizers.Tests |
@@ -848,6 +848,90 @@ public void TestOss() |
848 | 848 |
|
// Test helper: reads the non-public instance property named "Vocabulary" from the given
// tokenizer through reflection. Yields null when the property cannot be found or when its
// value is not an IReadOnlyDictionary<string, int>.
private static IReadOnlyDictionary<string, int>? GetVocabulary(TiktokenTokenizer tiktoken)
{
    PropertyInfo? vocabularyProperty = typeof(TiktokenTokenizer).GetProperty(
        "Vocabulary",
        BindingFlags.Instance | BindingFlags.NonPublic);

    object? rawValue = vocabularyProperty?.GetValue(tiktoken);
    return rawValue as IReadOnlyDictionary<string, int>;
}
| 851 | + |
[Fact]
public void TestLargeInputOptimization()
{
    // Round-trips (encode to ids, then decode) a set of inputs around the 128 threshold that,
    // per the notes accompanying this test, selects the heap-based large-input path.
    // Every input must decode back to exactly the original text.
    string[] roundTripInputs =
    {
        // Repeated characters: the adversarial case noted as causing O(n^2) behavior.
        new string('a', 1000),

        // A more realistic large input.
        string.Join(" ", Enumerable.Repeat("Hello World! This is a test.", 50)),

        // Exactly at the threshold (128), just below it (127), and just above it (129).
        new string('x', 128),
        new string('x', 127),
        new string('x', 129),
    };

    foreach (string text in roundTripInputs)
    {
        IReadOnlyList<int> encodedIds = GPT4.EncodeToIds(text);
        string roundTripped = GPT4.Decode(encodedIds);
        Assert.Equal(text, roundTripped);
    }
}
| 888 | + |
[Theory]
[InlineData(200)]
[InlineData(500)]
[InlineData(1000)]
[InlineData(2000)]
public void TestLargeInputConsistency(int length)
{
    // For a run of `length` identical characters, verifies that:
    //   1. EncodeToIds followed by Decode reproduces the input,
    //   2. EncodeToTokens reports no normalized text,
    //   3. EncodeToTokens yields the same ids as EncodeToIds, and
    //   4. concatenating the token values reproduces the input.
    // Note: this does not compare against a separate small-input code path; it checks that
    // the two encoding entry points agree with each other and round-trip correctly.

    string input = new string('z', length);
    IReadOnlyList<int> ids = GPT4.EncodeToIds(input);

    // Verify round-trip through ids.
    string decoded = GPT4.Decode(ids);
    Assert.Equal(input, decoded);

    // Verify with EncodeToTokens as well.
    IReadOnlyList<EncodedToken> tokens = GPT4.EncodeToTokens(input, out string? normalizedText);
    Assert.Null(normalizedText); // No normalization expected

    // Both encoding entry points must agree on the produced ids.
    Assert.Equal(ids, tokens.Select(t => t.Id).ToArray());

    // Reconstruct the text from the token values.
    string reconstructed = string.Concat(tokens.Select(t => t.Value));
    Assert.Equal(input, reconstructed);
}
| 914 | + |
[Fact]
public void TestLargeInputPerformance()
{
    // Coarse regression guard: encoding a 5000-character run of one repeated character
    // (the case noted as adversarial for the previous O(n^2) behavior) must finish within
    // a generous wall-clock budget. The budget is intentionally far above the expected
    // duration so the test stays stable on slow or heavily loaded machines; it only
    // catches pathological slowdowns, not small regressions.
    const int TimeBudgetMilliseconds = 5000;

    string veryLargeInput = new string('a', 5000);

    // Warm up with a different, short input so one-time costs (JIT, lazy initialization)
    // are not attributed to the measured call. A different text is used so the measured
    // call cannot be served from any result cached for the warm-up input.
    _ = GPT4.EncodeToIds("warm-up");

    var stopwatch = System.Diagnostics.Stopwatch.StartNew();
    IReadOnlyList<int> ids = GPT4.EncodeToIds(veryLargeInput);
    stopwatch.Stop();

    Assert.True(stopwatch.ElapsedMilliseconds < TimeBudgetMilliseconds,
        $"Large input encoding took {stopwatch.ElapsedMilliseconds}ms, expected < {TimeBudgetMilliseconds}ms");

    // Verify correctness: the ids must decode back to the original text.
    string decoded = GPT4.Decode(ids);
    Assert.Equal(veryLargeInput, decoded);
}
851 | 935 | } |
852 | 936 | } |
853 | 937 |
|
0 commit comments