Skip to content

Commit 5558e3e

Browse files
committed
feat: add o200k_harmony encoding
1 parent 55da1b1 commit 5558e3e

6 files changed

Lines changed: 173 additions & 15 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ the need to have similar capacities in the JVM ecosystem as the library
3030
## 🤖 Features
3131

3232
✅ Implements encoding and decoding via `r50k_base`, `p50k_base`, `p50k_edit`,
33-
`cl100k_base` and `o200k_base`
33+
`cl100k_base`, `o200k_base` and `o200k_harmony`
3434

3535
✅ Easy-to-use API
3636

docs/docs/getting-started/intro.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ JTokkit is a fast and efficient tokenizer designed for use in natural language p
88

99
## Features
1010

11-
✅ Implements encoding and decoding via `r50k_base`, `p50k_base`, `p50k_edit` and `cl100k_base`
11+
✅ Implements encoding and decoding via `r50k_base`, `p50k_base`, `p50k_edit`,
12+
`cl100k_base`, `o200k_base` and `o200k_harmony`
1213

1314
✅ Easy-to-use API
1415

lib/src/main/java/com/knuddels/jtokkit/AbstractEncodingRegistry.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ protected final void addEncoding(final EncodingType encodingType) {
9696
case O200K_BASE:
9797
encodings.computeIfAbsent(encodingType.getName(), k -> EncodingFactory.o200kBase());
9898
break;
99+
case O200K_HARMONY:
100+
encodings.computeIfAbsent(encodingType.getName(), k -> EncodingFactory.o200kHarmony());
101+
break;
99102
default:
100103
throw new IllegalStateException("Unknown encoding type " + encodingType.getName());
101104
}

lib/src/main/java/com/knuddels/jtokkit/EncodingFactory.java

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
import java.util.*;
1414
import java.util.regex.Pattern;
1515
import java.util.stream.Collectors;
16+
import java.util.stream.IntStream;
1617

1718
class EncodingFactory {
19+
private static final Map<String, Integer> SPECIAL_TOKENS_O200K_HARMONY;
1820
private static final Map<String, Integer> SPECIAL_TOKENS_O200K_BASE;
1921
private static final Map<String, Integer> SPECIAL_TOKENS_CL100K_BASE;
2022
private static final Map<String, Integer> SPECIAL_TOKENS_X50K_BASE;
@@ -58,6 +60,30 @@ class EncodingFactory {
5860
SPECIAL_TOKENS_O200K_BASE = Collections.unmodifiableMap(map);
5961
}
6062

63+
static {
64+
Map<String, Integer> map = new HashMap<>(SPECIAL_TOKENS_O200K_BASE);
65+
map.put("<|startoftext|>", 199998);
66+
map.put(ENDOFTEXT, 199999);
67+
map.put("<|reserved_200000|>", 200000);
68+
map.put("<|reserved_200001|>", 200001);
69+
map.put("<|return|>", 200002);
70+
map.put("<|constrain|>", 200003);
71+
map.put("<|reserved_200004|>", 200004);
72+
map.put("<|channel|>", 200005);
73+
map.put("<|start|>", 200006);
74+
map.put("<|end|>", 200007);
75+
map.put("<|message|>", 200008);
76+
map.put("<|reserved_200009|>", 200009);
77+
map.put("<|reserved_200010|>", 200010);
78+
map.put("<|reserved_200011|>", 200011);
79+
map.put("<|call|>", 200012);
80+
81+
IntStream.range(200013, 201088)
82+
.forEach( i -> map.put("<|reserved_" + i + "|>", i));
83+
84+
SPECIAL_TOKENS_O200K_HARMONY = Collections.unmodifiableMap(map);
85+
}
86+
6187
private EncodingFactory() {
6288
}
6389

@@ -119,18 +145,17 @@ static Encoding cl100kBase() {
119145
*/
120146

121147
static Encoding o200kBase() {
122-
Map<byte[], Integer> mergeableRanks = loadMergeableRanks("/com/knuddels/jtokkit/o200k_base.tiktoken");
123-
List<String> patStrList = new ArrayList<>();
124-
patStrList.add("[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?");
125-
patStrList.add("[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?");
126-
patStrList.add("\\p{N}{1,3}");
127-
patStrList.add(" ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*");
128-
patStrList.add("\\s*[\\r\\n]+");
129-
patStrList.add("\\s+(?!\\S)");
130-
patStrList.add("\\s+");
131-
Pattern regex = compileRegex(patStrList.stream().map(String::valueOf).collect(Collectors.joining("|")), false);
132-
GptBytePairEncodingParams params = new GptBytePairEncodingParams("o200k_base", regex, mergeableRanks, SPECIAL_TOKENS_O200K_BASE);
133-
return fromParameters(params);
148+
return from200kParameters("o200k_base", SPECIAL_TOKENS_O200K_BASE);
149+
}
150+
151+
/**
152+
* Returns an {@link Encoding} instance for the o200k_harmony encoding.
153+
*
154+
* @return an {@link Encoding} instance for the o200k_harmony encoding
155+
*/
156+
157+
static Encoding o200kHarmony() {
158+
return from200kParameters("o200k_harmony", SPECIAL_TOKENS_O200K_HARMONY);
134159
}
135160

136161
/**
@@ -154,6 +179,24 @@ private static Encoding from50kParameters(
154179
return fromParameters(params);
155180
}
156181

182+
private static Encoding from200kParameters(
183+
String name,
184+
Map<String, Integer> specialTokens
185+
) {
186+
Map<byte[], Integer> mergeableRanks = loadMergeableRanks("/com/knuddels/jtokkit/o200k_base.tiktoken");
187+
List<String> patStrList = new ArrayList<>();
188+
patStrList.add("[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?");
189+
patStrList.add("[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?");
190+
patStrList.add("\\p{N}{1,3}");
191+
patStrList.add(" ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*");
192+
patStrList.add("\\s*[\\r\\n]+");
193+
patStrList.add("\\s+(?!\\S)");
194+
patStrList.add("\\s+");
195+
Pattern regex = compileRegex(patStrList.stream().map(String::valueOf).collect(Collectors.joining("|")), false);
196+
GptBytePairEncodingParams params = new GptBytePairEncodingParams(name, regex, mergeableRanks, specialTokens);
197+
return fromParameters(params);
198+
}
199+
157200
static Pattern compileRegex(String patternString, boolean caseInsensitive) {
158201
try {
159202
int flags = Pattern.UNICODE_CHARACTER_CLASS;

lib/src/main/java/com/knuddels/jtokkit/api/EncodingType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ public enum EncodingType {
1111
P50K_BASE("p50k_base"),
1212
P50K_EDIT("p50k_edit"),
1313
CL100K_BASE("cl100k_base"),
14-
O200K_BASE("o200k_base");
14+
O200K_BASE("o200k_base"),
15+
O200K_HARMONY("o200k_harmony");
1516

1617
private static final Map<String, EncodingType> nameToEncodingType = Arrays.stream(values())
1718
.collect(Collectors.toMap(EncodingType::getName, Function.identity()));
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package com.knuddels.jtokkit.reference;
2+
3+
import com.knuddels.jtokkit.Encodings;
4+
import com.knuddels.jtokkit.api.Encoding;
5+
import com.knuddels.jtokkit.api.EncodingType;
6+
import org.junit.jupiter.api.Test;
7+
import org.junit.jupiter.params.ParameterizedTest;
8+
import org.junit.jupiter.params.provider.CsvFileSource;
9+
10+
import static org.junit.jupiter.api.Assertions.assertEquals;
11+
import static org.junit.jupiter.api.Assertions.assertTrue;
12+
13+
class O200kHarmonyTest {
14+
15+
private static final Encoding ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.O200K_HARMONY);
16+
17+
@ParameterizedTest
18+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
19+
void o200kBaseEncodesCorrectly(
20+
String input,
21+
String output
22+
) {
23+
var expected = TestUtils.parseEncodingString(output);
24+
var actual = ENCODING.encode(input);
25+
26+
assertEquals(expected, actual);
27+
}
28+
29+
@ParameterizedTest
30+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
31+
void o200kHarmonyEncodesStable(String input) {
32+
var actual = ENCODING.decode(ENCODING.encode(input));
33+
34+
assertEquals(input, actual);
35+
}
36+
37+
@ParameterizedTest
38+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
39+
void o200kHarmonyEncodesCorrectlyWithMaxTokensSet(
40+
String input,
41+
String output,
42+
String outputMaxTokens10
43+
) {
44+
var expected = TestUtils.parseEncodingString(output);
45+
var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10);
46+
var encodingResult = ENCODING.encode(input, 10);
47+
48+
assertEquals(expectedWithMaxTokens, encodingResult.getTokens());
49+
assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated());
50+
}
51+
52+
@ParameterizedTest
53+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
54+
void o200kHarmonyEncodesStableWithMaxTokensSet(String input) {
55+
var actual = ENCODING.decode(ENCODING.encode(input, 10).getTokens());
56+
57+
assertTrue(input.startsWith(actual));
58+
}
59+
60+
@ParameterizedTest
61+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
62+
void o200kHarmonyEncodeOrdinaryEncodesCorrectly(
63+
String input,
64+
String output
65+
) {
66+
var expected = TestUtils.parseEncodingString(output);
67+
var actual = ENCODING.encodeOrdinary(input);
68+
69+
assertEquals(expected, actual);
70+
}
71+
72+
@ParameterizedTest
73+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
74+
void o200kHarmonyEncodeOrdinaryEncodesCorrectly(
75+
String input,
76+
String output,
77+
String outputMaxTokens10
78+
) {
79+
var expected = TestUtils.parseEncodingString(output);
80+
var expectedWithMaxTokens = TestUtils.parseEncodingString(outputMaxTokens10);
81+
var encodingResult = ENCODING.encodeOrdinary(input, 10);
82+
83+
assertEquals(expectedWithMaxTokens, encodingResult.getTokens());
84+
assertEquals(expected.size() > expectedWithMaxTokens.size(), encodingResult.isTruncated());
85+
}
86+
87+
@ParameterizedTest
88+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
89+
void o200kHarmonyEncodeOrdinaryEncodesStable(String input) {
90+
var actual = ENCODING.decode(ENCODING.encodeOrdinary(input));
91+
92+
assertEquals(input, actual);
93+
}
94+
95+
@ParameterizedTest
96+
@CsvFileSource(resources = "/o200k_base_encodings.csv", numLinesToSkip = 1, maxCharsPerColumn = 1_000_000)
97+
void o200kHarmonyEncodeOrdinaryEncodesStableWithMaxTokensSet(String input) {
98+
var actual = ENCODING.decode(ENCODING.encodeOrdinary(input, 10).getTokens());
99+
100+
assertTrue(input.startsWith(actual));
101+
}
102+
103+
@Test
104+
void o200kHarmonyEncodeOrdinaryEncodesSpecialTokensCorrectly() {
105+
var input = "<|startoftext|>Hello<|endoftext|>, <|start|> <|end|> world <|reserved_201088|> ! <|endofprompt|>";
106+
var actual = ENCODING.decode(ENCODING.encodeOrdinary(input));
107+
108+
assertEquals(input, actual);
109+
}
110+
}

0 commit comments

Comments
 (0)