
Commit b6e12b6

Update encode method and add overloads (#3)
* changed the encode method and added overloads
* updated readme
1 parent e64e1b0 commit b6e12b6

File tree

8 files changed: +66, -63 lines


README.md

Lines changed: 3 additions & 3 deletions

````diff
@@ -44,9 +44,9 @@ const tokenizerConfig = await fetch(`https://huggingface.co/${modelId}/resolve/m
 const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
 // Tokenize text
-const tokens = tokenizer.tokenize('Hello World'); // ['Hello', 'ĠWorld']
-const encoded = tokenizer.encode('Hello World'); // [9906, 4435]
-const decoded = tokenizer.decode(encoded); // 'Hello World'
+const tokens = tokenizer.tokenize('Hello World'); // ['Hello', 'ĠWorld']
+const encoded = tokenizer.encode('Hello World'); // { ids: [9906, 4435], tokens: ['Hello', 'ĠWorld'], attention_mask: [1, 1] }
+const decoded = tokenizer.decode(encoded.ids); // 'Hello World'
 ```
 
 ## Requirements
````

src/core/Tokenizer.ts

Lines changed: 36 additions & 41 deletions

```diff
@@ -128,19 +128,48 @@ class Tokenizer {
       this.config.do_lowercase_and_remove_accent ?? false;
   }
 
+  /**
+   * Encodes a single text or a pair of texts using the model's tokenizer.
+   *
+   * @param text The text to encode.
+   * @param options An optional object containing the following properties:
+   * @returns An object containing the encoded text.
+   */
+
+  // Overload: when return_token_type_ids is explicitly true
+  public encode(
+    text: string,
+    options: EncodeOptions & { return_token_type_ids: true },
+  ): EncodingSingle & { token_type_ids: number[] };
+
+  // Overload: when return_token_type_ids is false/null or not provided
+  public encode(text: string, options?: EncodeOptions): EncodingSingle;
+
+  // Implementation
   public encode(
     text: string,
     {
-      text_pair,
-      add_special_tokens,
-      return_token_type_ids,
+      text_pair = null,
+      add_special_tokens = true,
+      return_token_type_ids = null,
     }: EncodeOptions = {},
-  ): Array<number> {
-    return this.encode_plus(text, {
+  ): EncodingSingle {
+    const { tokens, token_type_ids } = this.tokenize_helper(text, {
       text_pair,
       add_special_tokens,
-      return_token_type_ids,
-    }).input_ids;
+    });
+
+    const input_ids = this.model.convert_tokens_to_ids(tokens);
+    const result: EncodingSingle = {
+      ids: input_ids,
+      tokens,
+      attention_mask: new Array(input_ids.length).fill(1),
+    };
+
+    if (return_token_type_ids && token_type_ids) {
+      result.token_type_ids = token_type_ids;
+    }
+    return result;
   }
 
   public decode(
@@ -198,40 +227,6 @@ class Tokenizer {
     return this.tokenize_helper(text, { text_pair, add_special_tokens }).tokens;
   }
 
-  /**
-   * Encodes a single text or a pair of texts using the model's tokenizer.
-   *
-   * @param text The text to encode.
-   * @param options An optional object containing the following properties:
-   * @returns An object containing the encoded text.
-   * @private
-   */
-
-  private encode_plus(
-    text: string,
-    {
-      text_pair = null,
-      add_special_tokens = true,
-      return_token_type_ids = null,
-    }: EncodeOptions,
-  ): EncodingSingle {
-    const { tokens, token_type_ids } = this.tokenize_helper(text, {
-      text_pair,
-      add_special_tokens,
-    });
-
-    const input_ids = this.model.convert_tokens_to_ids(tokens);
-    const result: EncodingSingle = {
-      input_ids,
-      attention_mask: new Array(input_ids.length).fill(1),
-    };
-
-    if (return_token_type_ids && token_type_ids) {
-      result.token_type_ids = token_type_ids;
-    }
-    return result;
-  }
-
   private encode_text(text: string | null): string[] | null {
     if (text === null) {
       return null;
```
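The overload pair is what lets TypeScript narrow the return type at call sites: requesting token type IDs makes the field required on the result, while plain calls keep it optional. A minimal sketch of the resulting call-site behavior, illustrative rather than part of the commit, assuming a `tokenizer` constructed as in the README snippet above:

```ts
// Illustrative only (not part of the diff). Assumes `tokenizer` was
// constructed as in the README: new Tokenizer(tokenizerJson, tokenizerConfig).

// First overload: `return_token_type_ids: true` selects the signature whose
// return type includes a required `token_type_ids: number[]`.
const pair = tokenizer.encode("Hello", {
  text_pair: "World",
  return_token_type_ids: true,
});
pair.token_type_ids.length; // typed as number[], no `?.` needed
// (at runtime the field is still only set when tokenize_helper produces it)

// Second overload: without the flag, a plain EncodingSingle comes back and
// `token_type_ids` stays optional.
const single = tokenizer.encode("Hello World");
console.log(single.ids, single.tokens, single.attention_mask);
console.log(single.token_type_ids?.length); // possibly undefined
```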

src/static/types.ts

Lines changed: 2 additions & 1 deletion

```diff
@@ -33,7 +33,8 @@ export type DataType =
   | "int4";
 
 export interface EncodingSingle {
-  input_ids: number[];
+  ids: number[];
+  tokens: string[];
   attention_mask: number[];
   token_type_ids?: number[];
 }
```
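With `input_ids` renamed to `ids` and the new required `tokens` field, a conforming value looks like the following sketch; the ids and tokens mirror the Llama expectations in tests/bundle.test.ts below, and the import path is assumed for illustration:

```ts
import type { EncodingSingle } from "./types"; // path assumed for illustration

// Values mirror the Llama fixture in tests/bundle.test.ts.
const example: EncodingSingle = {
  ids: [1, 15043, 2787],
  tokens: ["<s>", "▁Hello", "▁World"],
  attention_mask: [1, 1, 1],
  // token_type_ids?: number[] — still optional, set only on request
};
```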

tests/bundle.test.ts

Lines changed: 9 additions & 2 deletions

```diff
@@ -14,14 +14,21 @@ const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 // Tokenize text
 const tokens = tokenizer.tokenize('Hello World');
 const encoded = tokenizer.encode('Hello World');
-const decoded = tokenizer.decode(encoded);
+const decoded = tokenizer.decode(encoded.ids);
 
 console.log(tokens);
 console.log(encoded);
 console.log(decoded);
 `;
 
-const TARGET_OUTPUT = "[ '▁Hello', '▁World' ]\n[ 1, 15043, 2787 ]\n<s> Hello World\n";
+const TARGET_OUTPUT = `[ '▁Hello', '▁World' ]
+{
+  ids: [ 1, 15043, 2787 ],
+  tokens: [ '<s>', '▁Hello', '▁World' ],
+  attention_mask: [ 1, 1, 1 ]
+}
+<s> Hello World
+`;
 
 const wrap_async_iife = (code: string) => `(async function() { ${code} })();`;
```

tests/edgeCases.test.ts

Lines changed: 8 additions & 8 deletions

```diff
@@ -8,21 +8,21 @@ describe("Edge cases", () => {
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     let text = String.prototype.repeat.call("a", 50000);
-    let token_ids = tokenizer.encode(text);
-    expect(token_ids).toEqual([101, 100, 102]);
+    let { ids } = tokenizer.encode(text);
+    expect(ids).toEqual([101, 100, 102]);
   }, 5000); // NOTE: 5 seconds
 
   it("Special/added tokens with earlier partial matches", async () => {
     const modelId = "Xenova/gemini-nano";
     const { tokenizerJson, tokenizerConfig } = await fetchConfigById(modelId);
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
     {
-      let token_ids = tokenizer.encode("\n", { add_special_tokens: false });
-      expect(token_ids).toEqual([108]);
+      let { ids } = tokenizer.encode("\n", { add_special_tokens: false });
+      expect(ids).toEqual([108]);
     }
     {
-      let token_ids = tokenizer.encode("\n\n", { add_special_tokens: false });
-      expect(token_ids).toEqual([109]); // Should not be [108, 108]
+      let { ids } = tokenizer.encode("\n\n", { add_special_tokens: false });
+      expect(ids).toEqual([109]); // Should not be [108, 108]
     }
   }, 60_000);
 
@@ -32,7 +32,7 @@ describe("Edge cases", () => {
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     let text = "hello world!";
-    let token_ids = tokenizer.encode(text);
-    expect(token_ids).toEqual([128000, 15339, 1917, 0]);
+    let { ids } = tokenizer.encode(text);
+    expect(ids).toEqual([128000, 15339, 1917, 0]);
   }, 5000); // NOTE: 5 seconds
 });
```

tests/models/llama/llama.test.ts

Lines changed: 3 additions & 3 deletions

```diff
@@ -49,14 +49,14 @@ describe("hard-coded", () => {
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     for (const [text, expected] of Object.entries(data)) {
-      const token_ids = tokenizer.encode(text, {
+      const encoded = tokenizer.encode(text, {
         add_special_tokens: false,
       });
-      expect(token_ids).toEqual(expected);
+      expect(encoded.ids).toEqual(expected);
 
       // If reversible, test that decoding produces the original text
       if (reversible) {
-        const decoded = tokenizer.decode(token_ids);
+        const decoded = tokenizer.decode(encoded.ids);
         expect(decoded).toEqual(text);
       }
     }
```

tests/models/t5/t5.test.ts

Lines changed: 3 additions & 3 deletions

```diff
@@ -38,14 +38,14 @@ describe("hard-coded", () => {
     const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
 
     for (const [text, expected] of Object.entries(data)) {
-      const token_ids = tokenizer.encode(text, {
+      const encoded = tokenizer.encode(text, {
         add_special_tokens: false,
       });
-      expect(token_ids).toEqual(expected);
+      expect(encoded.ids).toEqual(expected);
 
       // If reversible, test that decoding produces the original text
       if (reversible) {
-        const decoded = tokenizer.decode(token_ids);
+        const decoded = tokenizer.decode(encoded.ids);
         expect(decoded).toEqual(text);
       }
     }
```

tests/tokenizers.test.ts

Lines changed: 2 additions & 2 deletions

```diff
@@ -20,10 +20,10 @@ describe("Tokenizers (model-specific)", () => {
     for (const [testName, testCase] of Object.entries(config.default[modelId])) {
       test(testName, () => {
         if (testCase.ids) {
-          const ids = tokenizer.encode(testCase.text, {
+          const encoded = tokenizer.encode(testCase.text, {
             text_pair: testCase.text_pair,
           });
-          expect(ids).toEqual(testCase.ids);
+          expect(encoded.ids).toEqual(testCase.ids);
 
           if (testCase.decoded) {
             const decoded = tokenizer.decode(testCase.ids);
```
