Skip to content

Commit 9a3ebe7

Browse files
committed
fix: Fix first item and nil issues
1 parent f61b8ad commit 9a3ebe7

File tree

1 file changed

+53
-53
lines changed

1 file changed

+53
-53
lines changed

tests/mcp_tool.go

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,24 @@ func InvokeMCPTool(t *testing.T, toolName string, arguments map[string]any, requ
143143
return resp.StatusCode, &mcpResp, nil
144144
}
145145

146-
// getMCPResultText safely extracts the text from the first content block.
146+
// getMCPResultText safely extracts the text from content blocks, reconstructing an array if there are multiple items.
147147
func getMCPResultText(resp *MCPCallToolResponse) string {
148148
if len(resp.Result.Content) == 0 {
149149
return "[]"
150150
}
151-
return resp.Result.Content[0].Text
151+
if len(resp.Result.Content) == 1 {
152+
return resp.Result.Content[0].Text
153+
}
154+
var builder strings.Builder
155+
builder.WriteString("[")
156+
for i, content := range resp.Result.Content {
157+
if i > 0 {
158+
builder.WriteString(",")
159+
}
160+
builder.WriteString(content.Text)
161+
}
162+
builder.WriteString("]")
163+
return builder.String()
152164
}
153165

154166
// unmarshalMCPResult unmarshals a JSON string into a slice of Ts.
@@ -275,10 +287,7 @@ func RunMCPCustomToolCallMethod(t *testing.T, toolName string, arguments map[str
275287
if mcpResp.Result.IsError {
276288
t.Fatalf("%s returned error result: %v", toolName, mcpResp.Result)
277289
}
278-
if len(mcpResp.Result.Content) == 0 {
279-
t.Fatalf("%s returned empty content field", toolName)
280-
}
281-
got := mcpResp.Result.Content[0].Text
290+
got := getMCPResultText(mcpResp)
282291
if !strings.Contains(got, want) {
283292
t.Fatalf(`expected %q to contain %q`, got, want)
284293
}
@@ -376,10 +385,7 @@ func RunMCPToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTes
376385
if mcpResp.Result.IsError {
377386
t.Fatalf("%s returned error result: %v", tc.toolName, mcpResp.Result)
378387
}
379-
if len(mcpResp.Result.Content) == 0 {
380-
t.Fatalf("%s returned empty content field", tc.toolName)
381-
}
382-
got := mcpResp.Result.Content[0].Text
388+
got := getMCPResultText(mcpResp)
383389
if !strings.Contains(got, tc.wantResult) {
384390
t.Fatalf(`expected %q to contain %q`, got, tc.wantResult)
385391
}
@@ -531,7 +537,7 @@ func RunMCPPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxp
531537
name: "invoke list_schemas with non-existent schema",
532538
args: map[string]any{"schema_name": "non_existent_schema"},
533539
wantStatusCode: http.StatusOK,
534-
want: nil,
540+
want: []map[string]any{},
535541
},
536542
}
537543
for _, tc := range invokeTcs {
@@ -549,14 +555,12 @@ func RunMCPPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxp
549555
if mcpResp.Result.IsError {
550556
t.Fatalf("list_schemas returned error result: %v", mcpResp.Result)
551557
}
552-
var gotObj []map[string]any
553-
if len(mcpResp.Result.Content) > 0 {
554-
got := mcpResp.Result.Content[0].Text
555-
if got != "null" {
556-
gotObj, err = unmarshalMCPResult[map[string]any](got)
557-
if err != nil {
558-
t.Fatalf("failed to unmarshal nested result string: %v", err)
559-
}
558+
gotObj := []map[string]any{}
559+
got := getMCPResultText(mcpResp)
560+
if got != "null" {
561+
gotObj, err = unmarshalMCPResult[map[string]any](got)
562+
if err != nil {
563+
t.Fatalf("failed to unmarshal nested result string: %v", err)
560564
}
561565
}
562566

@@ -666,13 +670,11 @@ func RunMCPPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool
666670
t.Fatalf("list_active_queries returned error result: %v", mcpResp.Result)
667671
}
668672
var details []queryListDetails
669-
if len(mcpResp.Result.Content) > 0 {
670-
got := mcpResp.Result.Content[0].Text
671-
if got != "null" {
672-
details, err = unmarshalMCPResult[queryListDetails](got)
673-
if err != nil {
674-
t.Fatalf("failed to unmarshal nested result string: %v", err)
675-
}
673+
got := getMCPResultText(mcpResp)
674+
if got != "null" {
675+
details, err = unmarshalMCPResult[queryListDetails](got)
676+
if err != nil {
677+
t.Fatalf("failed to unmarshal nested result string: %v", err)
676678
}
677679
}
678680

@@ -842,7 +844,7 @@ func RunMCPPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgx
842844
name: "filter by non-existent trigger_name",
843845
args: map[string]any{"trigger_name": "non_existent_trigger"},
844846
wantStatusCode: http.StatusOK,
845-
want: nil,
847+
want: []map[string]any{},
846848
},
847849
}
848850
for _, tc := range invokeTcs {
@@ -861,14 +863,16 @@ func RunMCPPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgx
861863
t.Fatalf("list_triggers returned error result: %v", mcpResp.Result)
862864
}
863865
var gotObj []map[string]any
864-
if len(mcpResp.Result.Content) > 0 {
865-
got := mcpResp.Result.Content[0].Text
866-
if got != "null" {
867-
gotObj, err = unmarshalMCPResult[map[string]any](got)
868-
if err != nil {
869-
t.Fatalf("failed to unmarshal nested result string: %v", err)
870-
}
866+
for _, content := range mcpResp.Result.Content {
867+
got := content.Text
868+
if got == "null" {
869+
continue
870+
}
871+
items, err := unmarshalMCPResult[map[string]any](got)
872+
if err != nil {
873+
t.Fatalf("failed to unmarshal nested result string: %v", err)
871874
}
875+
gotObj = append(gotObj, items...)
872876
}
873877

874878
if tc.compareSubset {
@@ -885,7 +889,7 @@ func RunMCPPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgx
885889
}
886890
}
887891
if !found {
888-
t.Errorf("Expected trigger '%+v' not found in the list.", wantTrigger)
892+
t.Errorf("Expected trigger '%+v' not found in the list. Got: %+v", wantTrigger, gotObj)
889893
}
890894
} else {
891895
if !reflect.DeepEqual(tc.want, gotObj) {
@@ -946,7 +950,7 @@ func RunMCPPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pg
946950
name: "invoke list_sequences with non-existent sequence",
947951
args: map[string]any{"sequence_name": "non_existent_sequence"},
948952
wantStatusCode: http.StatusOK,
949-
want: nil,
953+
want: []map[string]any{},
950954
},
951955
}
952956
for _, tc := range invokeTcs {
@@ -964,14 +968,12 @@ func RunMCPPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pg
964968
if mcpResp.Result.IsError {
965969
t.Fatalf("list_sequences returned error result: %v", mcpResp.Result)
966970
}
967-
var gotObj []map[string]any
968-
if len(mcpResp.Result.Content) > 0 {
969-
got := mcpResp.Result.Content[0].Text
970-
if got != "null" {
971-
gotObj, err = unmarshalMCPResult[map[string]any](got)
972-
if err != nil {
973-
t.Fatalf("failed to unmarshal nested result string: %v", err)
974-
}
971+
gotObj := []map[string]any{}
972+
got := getMCPResultText(mcpResp)
973+
if got != "null" {
974+
gotObj, err = unmarshalMCPResult[map[string]any](got)
975+
if err != nil {
976+
t.Fatalf("failed to unmarshal nested result string: %v", err)
975977
}
976978
}
977979

@@ -1084,7 +1086,7 @@ func RunMCPPostgresListIndexesTest(t *testing.T, ctx context.Context, pool *pgxp
10841086
name: "invoke list_indexes with non-existent table",
10851087
args: map[string]any{"table_name": "non_existent_table"},
10861088
wantStatusCode: http.StatusOK,
1087-
want: nil,
1089+
want: []map[string]any{},
10881090
},
10891091
}
10901092
for _, tc := range invokeTcs {
@@ -1103,7 +1105,7 @@ func RunMCPPostgresListIndexesTest(t *testing.T, ctx context.Context, pool *pgxp
11031105
if mcpResp.Result.IsError {
11041106
t.Fatalf("list_indexes returned error result: %v", mcpResp.Result)
11051107
}
1106-
var gotObj []map[string]any
1108+
gotObj := []map[string]any{}
11071109
for _, content := range mcpResp.Result.Content {
11081110
got := content.Text
11091111
t.Logf("list_indexes got: %s", got)
@@ -1295,13 +1297,11 @@ func RunMCPPostgresListStoredProcedureTest(t *testing.T, ctx context.Context, po
12951297
t.Fatalf("list_stored_procedure returned error result: %v", mcpResp.Result)
12961298
}
12971299
var gotObj []storedProcedureDetails
1298-
if len(mcpResp.Result.Content) > 0 {
1299-
got := mcpResp.Result.Content[0].Text
1300-
if got != "null" {
1301-
gotObj, err = unmarshalMCPResult[storedProcedureDetails](got)
1302-
if err != nil {
1303-
t.Fatalf("failed to unmarshal nested result string: %v", err)
1304-
}
1300+
got := getMCPResultText(mcpResp)
1301+
if got != "null" {
1302+
gotObj, err = unmarshalMCPResult[storedProcedureDetails](got)
1303+
if err != nil {
1304+
t.Fatalf("failed to unmarshal nested result string: %v", err)
13051305
}
13061306
}
13071307

0 commit comments

Comments
 (0)