Skip to content

Commit 082b05d

Browse files
committed
fix
1 parent 86f7f79 commit 082b05d

File tree

1 file changed

+36
-19
lines changed

1 file changed

+36
-19
lines changed

tests/tool.go

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,6 @@ func RunToolInvokeParametersTest(t *testing.T, name string, params []byte, simpl
230230
// RunToolInvoke runs the tool invoke endpoint
231231
func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOption) {
232232
// Resolve options
233-
// Default values for InvokeTestConfig
234233
configs := &InvokeTestConfig{
235234
myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
236235
myToolById4Want: "[{\"id\":4,\"name\":null}]",
@@ -419,19 +418,17 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
419418
if !tc.enabled {
420419
return
421420
}
421+
reqBytes, _ := io.ReadAll(tc.requestBody)
422422
var got string
423423
var actualStatusCode int
424424

425425
if configs.IsMCP {
426426
parts := strings.Split(tc.api, "/")
427427
toolName := parts[len(parts)-2]
428428

429-
reqBytes, _ := io.ReadAll(tc.requestBody)
430429
var args map[string]any
431430
if len(reqBytes) > 0 {
432-
if err := json.Unmarshal(reqBytes, &args); err != nil {
433-
t.Fatalf("failed to unmarshal request body for MCP args: %v", err)
434-
}
431+
_ = json.Unmarshal(reqBytes, &args)
435432
}
436433
if args == nil {
437434
args = make(map[string]any)
@@ -461,8 +458,8 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
461458
}
462459
}
463460
} else {
464-
// Native API execution
465-
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
461+
// Reconstruct a fresh buffer for the legacy API request using the saved bytes
462+
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqBytes), tc.requestHeader)
466463
actualStatusCode = resp.StatusCode
467464

468465
if tc.wantBody != "" && actualStatusCode == tc.wantStatusCode {
@@ -495,6 +492,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
495492
}
496493
}
497494

495+
// RunToolInvokeWithTemplateParameters runs tool invoke test cases with template parameters.
498496
func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options ...TemplateParamOption) {
499497
configs := &TemplateParameterTestConfig{
500498
ddlWant: "null",
@@ -508,8 +506,9 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
508506
nameColFilter: "name",
509507
createColArray: `["id INT","name VARCHAR(20)","age INT"]`,
510508

511-
supportDdl: true,
512-
supportInsert: true,
509+
supportDdl: true,
510+
supportInsert: true,
511+
supportSelectFields: true,
513512
}
514513

515514
for _, option := range options {
@@ -609,20 +608,20 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
609608
}
610609
for _, tc := range invokeTcs {
611610
t.Run(tc.name, func(t *testing.T) {
612-
if !tc.enabled && tc.name == "invoke select-fields-templateParams-tool" {
611+
if !tc.enabled {
613612
return
614613
}
615614
ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl)
616615
insertAllow := !tc.insert || (tc.insert && configs.supportInsert)
617616

618617
if ddlAllow && insertAllow {
618+
reqBytes, _ := io.ReadAll(tc.requestBody)
619619
var got string
620620

621621
if configs.IsMCP {
622622
parts := strings.Split(tc.api, "/")
623623
toolName := parts[len(parts)-2]
624624

625-
reqBytes, _ := io.ReadAll(tc.requestBody)
626625
var args map[string]any
627626
if len(reqBytes) > 0 {
628627
_ = json.Unmarshal(reqBytes, &args)
@@ -640,18 +639,23 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
640639
}
641640

642641
var blocks []string
643-
for _, content := range mcpResp.Result.Content {
644-
if content.Type == "text" {
645-
blocks = append(blocks, strings.TrimSpace(content.Text))
642+
if mcpResp != nil && !mcpResp.Result.IsError {
643+
for _, content := range mcpResp.Result.Content {
644+
if content.Type == "text" {
645+
blocks = append(blocks, strings.TrimSpace(content.Text))
646+
}
646647
}
647648
}
648-
if len(blocks) == 0 {
649+
650+
if mcpResp != nil && mcpResp.Error != nil {
651+
got = fmt.Sprintf(`{"error":"%s"}`, mcpResp.Error.Message)
652+
} else if len(blocks) == 0 {
649653
got = "null"
650654
} else {
651655
got = strings.Join(blocks, "")
652656
}
653657
} else {
654-
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
658+
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqBytes), tc.requestHeader)
655659
if resp.StatusCode != http.StatusOK {
656660
if tc.isErr {
657661
return
@@ -673,7 +677,17 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
673677
}
674678

675679
if got != tc.want {
676-
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
680+
var gotJSON, wantJSON any
681+
errGot := json.Unmarshal([]byte(got), &gotJSON)
682+
errWant := json.Unmarshal([]byte(tc.want), &wantJSON)
683+
684+
if errGot == nil && errWant == nil {
685+
if diff := cmp.Diff(wantJSON, gotJSON); diff != "" {
686+
t.Fatalf("unexpected JSON value mismatch (-want +got):\n%s\nRaw got: %s\nRaw want: %s", diff, got, tc.want)
687+
}
688+
} else {
689+
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
690+
}
677691
}
678692
}
679693
})
@@ -784,13 +798,15 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want
784798
}
785799
for _, tc := range invokeTcs {
786800
t.Run(tc.name, func(t *testing.T) {
801+
802+
// Read buffer ONCE to prevent draining
803+
reqBytes, _ := io.ReadAll(tc.requestBody)
787804
var got string
788805

789806
if configs.IsMCP {
790807
parts := strings.Split(tc.api, "/")
791808
toolName := parts[len(parts)-2]
792809

793-
reqBytes, _ := io.ReadAll(tc.requestBody)
794810
var args map[string]any
795811
if len(reqBytes) > 0 {
796812
_ = json.Unmarshal(reqBytes, &args)
@@ -824,7 +840,8 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want
824840
got = strings.Join(blocks, "")
825841
}
826842
} else {
827-
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
843+
// Reconstruct a fresh buffer for the legacy API request using the saved bytes
844+
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqBytes), tc.requestHeader)
828845
if resp.StatusCode != http.StatusOK {
829846
if tc.isErr {
830847
return

0 commit comments

Comments
 (0)