Skip to content

Commit 0e2f31b

Browse files
committed
update test assertions
Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
1 parent 4450321 commit 0e2f31b

File tree

1 file changed

+46
-30
lines changed

1 file changed

+46
-30
lines changed

pkg/tokenization/uds_tokenizer_test.go

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -104,36 +104,59 @@ func (m *mockTokenizationServer) Tokenize(
104104
}, nil
105105
}
106106

107-
func (m *mockTokenizationServer) RenderChatTemplate(
108-
ctx context.Context,
109-
req *tokenizerpb.ChatTemplateRequest,
110-
) (*tokenizerpb.ChatTemplateResponse, error) {
107+
func (m *mockTokenizationServer) RenderChatCompletion(
108+
_ context.Context,
109+
req *tokenizerpb.RenderChatCompletionRequest,
110+
) (*tokenizerpb.RenderChatCompletionResponse, error) {
111111
if m.chatError {
112-
return &tokenizerpb.ChatTemplateResponse{
112+
return &tokenizerpb.RenderChatCompletionResponse{
113113
Success: false,
114-
ErrorMessage: "mock chat template error",
114+
ErrorMessage: "mock render chat completion error",
115115
}, nil
116116
}
117117

118-
// Check if model was initialized (matches real service behavior)
119-
if !m.initialized[req.ModelName] {
120-
return &tokenizerpb.ChatTemplateResponse{
118+
// Produce fake token IDs from native proto message content.
119+
var tokens []uint32
120+
for _, msg := range req.Messages {
121+
for _, r := range msg.GetText() {
122+
tokens = append(tokens, uint32(r))
123+
}
124+
}
125+
126+
return &tokenizerpb.RenderChatCompletionResponse{
127+
RequestId: "mock-request-id",
128+
TokenIds: tokens,
129+
Success: true,
130+
}, nil
131+
}
132+
133+
func (m *mockTokenizationServer) RenderCompletion(
134+
_ context.Context,
135+
req *tokenizerpb.RenderCompletionRequest,
136+
) (*tokenizerpb.RenderCompletionResponse, error) {
137+
if m.tokenizeError {
138+
return &tokenizerpb.RenderCompletionResponse{
121139
Success: false,
122-
ErrorMessage: fmt.Sprintf("model %s not initialized", req.ModelName),
140+
ErrorMessage: "mock render completion error",
123141
}, nil
124142
}
125143

126-
// Mock chat template rendering by concatenating messages
127-
rendered := ""
128-
for _, turn := range req.ConversationTurns {
129-
for _, msg := range turn.Messages {
130-
rendered += fmt.Sprintf("%s: %s\n", msg.Role, msg.GetText())
144+
items := make([]*tokenizerpb.RenderChatCompletionResponse, 0, len(req.Prompts))
145+
for _, prompt := range req.Prompts {
146+
tokens := make([]uint32, 0, len(prompt))
147+
for _, r := range prompt {
148+
tokens = append(tokens, uint32(r))
131149
}
150+
items = append(items, &tokenizerpb.RenderChatCompletionResponse{
151+
RequestId: "mock-request-id",
152+
TokenIds: tokens,
153+
Success: true,
154+
})
132155
}
133156

134-
return &tokenizerpb.ChatTemplateResponse{
135-
RenderedPrompt: rendered,
136-
Success: true,
157+
return &tokenizerpb.RenderCompletionResponse{
158+
Items: items,
159+
Success: true,
137160
}, nil
138161
}
139162

@@ -269,23 +292,16 @@ func (s *UdsTokenizerTestSuite) TestUdsTokenizer_ModelNotInMap() {
269292
}
270293

271294
func (s *UdsTokenizerTestSuite) TestUdsTokenizer_Render() {
272-
// Test Render - character-based tokenization
273295
input := "hello world"
274296
tokens, offsets, err := s.tokenizer.Render(input)
275297
s.Require().NoError(err)
276-
277-
// Each character becomes a token
278298
s.Assert().Equal(len([]rune(input)), len(tokens))
279-
s.Assert().Equal(len([]rune(input)), len(offsets))
280-
281-
// Verify specific characters
282-
s.Assert().Equal(uint32('h'), tokens[0]) // 'h' = 104
283-
s.Assert().Equal(uint32(' '), tokens[5]) // space at position 5 = 32
284-
s.Assert().Equal(uint32('d'), tokens[10]) // 'd' at end = 100
299+
s.Assert().Nil(offsets, "RenderCompletion does not return character offsets")
285300

286-
// Verify offsets
287-
s.Assert().Equal(types.Offset{0, 1}, offsets[0]) // 'h'
288-
s.Assert().Equal(types.Offset{5, 6}, offsets[5]) // space
301+
// Verify specific characters (mock converts runes to token IDs)
302+
s.Assert().Equal(uint32('h'), tokens[0])
303+
s.Assert().Equal(uint32(' '), tokens[5])
304+
s.Assert().Equal(uint32('d'), tokens[10])
289305
}
290306

291307
func (s *UdsTokenizerTestSuite) TestUdsTokenizer_RenderChat() {

0 commit comments

Comments
 (0)