1313import java .util .*;
1414import java .util .regex .Pattern ;
1515import java .util .stream .Collectors ;
16+ import java .util .stream .IntStream ;
1617
1718class EncodingFactory {
19+ private static final Map <String , Integer > SPECIAL_TOKENS_O200K_HARMONY ;
1820 private static final Map <String , Integer > SPECIAL_TOKENS_O200K_BASE ;
1921 private static final Map <String , Integer > SPECIAL_TOKENS_CL100K_BASE ;
2022 private static final Map <String , Integer > SPECIAL_TOKENS_X50K_BASE ;
@@ -58,6 +60,30 @@ class EncodingFactory {
5860 SPECIAL_TOKENS_O200K_BASE = Collections .unmodifiableMap (map );
5961 }
6062
static {
    // Seed the table with a copy of the o200k_base special tokens, then add
    // (or overwrite) the explicitly named entries below.
    Map<String, Integer> harmonyTokens = new HashMap<>(SPECIAL_TOKENS_O200K_BASE);

    harmonyTokens.put("<|startoftext|>", 199998);
    harmonyTokens.put(ENDOFTEXT, 199999);
    harmonyTokens.put("<|reserved_200000|>", 200000);
    harmonyTokens.put("<|reserved_200001|>", 200001);
    harmonyTokens.put("<|return|>", 200002);
    harmonyTokens.put("<|constrain|>", 200003);
    harmonyTokens.put("<|reserved_200004|>", 200004);
    harmonyTokens.put("<|channel|>", 200005);
    harmonyTokens.put("<|start|>", 200006);
    harmonyTokens.put("<|end|>", 200007);
    harmonyTokens.put("<|message|>", 200008);
    harmonyTokens.put("<|reserved_200009|>", 200009);
    harmonyTokens.put("<|reserved_200010|>", 200010);
    harmonyTokens.put("<|reserved_200011|>", 200011);
    harmonyTokens.put("<|call|>", 200012);

    // Placeholder tokens for ids 200013 (inclusive) up to 201088 (exclusive),
    // each named after its own id.
    for (int tokenId = 200013; tokenId < 201088; tokenId++) {
        harmonyTokens.put("<|reserved_" + tokenId + "|>", tokenId);
    }

    SPECIAL_TOKENS_O200K_HARMONY = Collections.unmodifiableMap(harmonyTokens);
}
86+
/**
 * Private constructor: code outside this class cannot create instances.
 */
private EncodingFactory() {
    // intentionally empty
}
6389
@@ -119,18 +145,17 @@ static Encoding cl100kBase() {
119145 */
120146
121147 static Encoding o200kBase () {
122- Map <byte [], Integer > mergeableRanks = loadMergeableRanks ("/com/knuddels/jtokkit/o200k_base.tiktoken" );
123- List <String > patStrList = new ArrayList <>();
124- patStrList .add ("[^\\ r\\ n\\ p{L}\\ p{N}]?[\\ p{Lu}\\ p{Lt}\\ p{Lm}\\ p{Lo}\\ p{M}]*[\\ p{Ll}\\ p{Lm}\\ p{Lo}\\ p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?" );
125- patStrList .add ("[^\\ r\\ n\\ p{L}\\ p{N}]?[\\ p{Lu}\\ p{Lt}\\ p{Lm}\\ p{Lo}\\ p{M}]+[\\ p{Ll}\\ p{Lm}\\ p{Lo}\\ p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?" );
126- patStrList .add ("\\ p{N}{1,3}" );
127- patStrList .add (" ?[^\\ s\\ p{L}\\ p{N}]+[\\ r\\ n/]*" );
128- patStrList .add ("\\ s*[\\ r\\ n]+" );
129- patStrList .add ("\\ s+(?!\\ S)" );
130- patStrList .add ("\\ s+" );
131- Pattern regex = compileRegex (patStrList .stream ().map (String ::valueOf ).collect (Collectors .joining ("|" )), false );
132- GptBytePairEncodingParams params = new GptBytePairEncodingParams ("o200k_base" , regex , mergeableRanks , SPECIAL_TOKENS_O200K_BASE );
133- return fromParameters (params );
148+ return from200kParameters ("o200k_base" , SPECIAL_TOKENS_O200K_BASE );
149+ }
150+
151+ /**
152+ * Returns an {@link Encoding} instance for the o200k_harmony encoding.
153+ *
154+ * @return an {@link Encoding} instance for the o200k_harmony encoding
155+ */
156+
157+ static Encoding o200kHarmony () {
158+ return from200kParameters ("o200k_harmony" , SPECIAL_TOKENS_O200K_HARMONY );
134159 }
135160
136161 /**
@@ -154,6 +179,24 @@ private static Encoding from50kParameters(
154179 return fromParameters (params );
155180 }
156181
182+ private static Encoding from200kParameters (
183+ String name ,
184+ Map <String , Integer > specialTokens
185+ ) {
186+ Map <byte [], Integer > mergeableRanks = loadMergeableRanks ("/com/knuddels/jtokkit/o200k_base.tiktoken" );
187+ List <String > patStrList = new ArrayList <>();
188+ patStrList .add ("[^\\ r\\ n\\ p{L}\\ p{N}]?[\\ p{Lu}\\ p{Lt}\\ p{Lm}\\ p{Lo}\\ p{M}]*[\\ p{Ll}\\ p{Lm}\\ p{Lo}\\ p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?" );
189+ patStrList .add ("[^\\ r\\ n\\ p{L}\\ p{N}]?[\\ p{Lu}\\ p{Lt}\\ p{Lm}\\ p{Lo}\\ p{M}]+[\\ p{Ll}\\ p{Lm}\\ p{Lo}\\ p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?" );
190+ patStrList .add ("\\ p{N}{1,3}" );
191+ patStrList .add (" ?[^\\ s\\ p{L}\\ p{N}]+[\\ r\\ n/]*" );
192+ patStrList .add ("\\ s*[\\ r\\ n]+" );
193+ patStrList .add ("\\ s+(?!\\ S)" );
194+ patStrList .add ("\\ s+" );
195+ Pattern regex = compileRegex (patStrList .stream ().map (String ::valueOf ).collect (Collectors .joining ("|" )), false );
196+ GptBytePairEncodingParams params = new GptBytePairEncodingParams (name , regex , mergeableRanks , specialTokens );
197+ return fromParameters (params );
198+ }
199+
157200 static Pattern compileRegex (String patternString , boolean caseInsensitive ) {
158201 try {
159202 int flags = Pattern .UNICODE_CHARACTER_CLASS ;
0 commit comments