Skip to content

Commit 40c2211

Browse files
DePasqualeOrg, Mattt, and smdesai
authored
Enable lstripBlocks and trimBlocks options for chat templates (#12)
Co-authored-by: Mattt <mattt@me.com>
Co-authored-by: Sachin Desai <smdesai@gmail.com>
1 parent e786f06 commit 40c2211

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -559,7 +559,7 @@ public class PreTrainedTokenizer: @unchecked Sendable, Tokenizer {
559559
_templateCacheLock.unlock()
560560

561561
// Compile template outside of lock to avoid holding lock during expensive operation
562-
let compiled = try Template(templateString)
562+
let compiled = try Template(templateString, with: .init(lstripBlocks: true, trimBlocks: true))
563563

564564
// Insert into cache (double-checked in case another thread compiled the same template)
565565
_templateCacheLock.lock()

Tests/TokenizersTests/ChatTemplateTests.swift

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,33 @@ struct ChatTemplateTests {
182182
#expect(decoded == decodedTarget)
183183
}
184184

185+
/// https://github.com/huggingface/swift-transformers/issues/322
186+
@Test("Jinja block whitespace control")
187+
func jinjaBlockWhitespaceControl() async throws {
188+
let tokenizer = try await Self.sharedPhiTokenizer()
189+
let whitespaceSensitiveTemplate = """
190+
{% for message in messages %}
191+
{% if message['role'] == 'user' %}
192+
{{ message['content'] }}
193+
{% endif %}
194+
{% endfor %}
195+
{% if add_generation_prompt %}
196+
assistant
197+
{% endif %}
198+
"""
199+
let encoded = try tokenizer.applyChatTemplate(
200+
messages: messages, chatTemplate: whitespaceSensitiveTemplate
201+
)
202+
let decoded = tokenizer.decode(tokens: encoded)
203+
let expected = """
204+
Describe the Swift programming language.
205+
assistant
206+
"""
207+
#expect(decoded == expected)
208+
#expect(!decoded.hasPrefix("\n"))
209+
#expect(!decoded.contains("\n\n"))
210+
}
211+
185212
@Test("Qwen 2.5 with tools functionality")
186213
func qwen2_5WithTools() async throws {
187214
let tokenizer = try await makeTokenizer(model: "mlx-community/Qwen2.5-7B-Instruct-4bit")
@@ -284,7 +311,12 @@ struct ChatTemplateTests {
284311
decoded.hasPrefix(expectedPromptStart),
285312
"Prompt should start with expected system message"
286313
)
287-
#expect(decoded.hasSuffix(expectedPromptEnd), "Prompt should end with expected format")
314+
#expect(
315+
decoded.trimmingCharacters(in: .newlines).hasSuffix(
316+
expectedPromptEnd.trimmingCharacters(in: .newlines)
317+
),
318+
"Prompt should end with expected format"
319+
)
288320
}
289321

290322
/// Test for vision models with a vision chat template in chat_template.json

0 commit comments

Comments
 (0)