Skip to content

Commit 4cd4fc9

Browse files
committed
refactor(dataproc): fix tool Invoke method return value
Accounts for the new ToolboxError type in #2403
1 parent ddb3767 commit 4cd4fc9

File tree

5 files changed

+54
-30
lines changed

5 files changed

+54
-30
lines changed

internal/tools/dataproc/dataprocgetcluster/dataprocgetcluster.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ package dataprocgetcluster
1717
import (
1818
"context"
1919
"fmt"
20+
"net/http"
2021
"strings"
2122

2223
"github.com/goccy/go-yaml"
2324
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
2425
"github.com/googleapis/genai-toolbox/internal/sources"
2526
"github.com/googleapis/genai-toolbox/internal/sources/dataproc"
2627
"github.com/googleapis/genai-toolbox/internal/tools"
28+
"github.com/googleapis/genai-toolbox/internal/util"
2729
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2830
)
2931

@@ -112,22 +114,26 @@ type compatibleSource interface {
112114
}
113115

114116
// Invoke executes the tool's operation.
115-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
117+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
116118
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Config.Source, t.Name, kind)
117119
if err != nil {
118-
return nil, err
120+
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
119121
}
120122

121123
paramMap := params.AsMap()
122124
name, ok := paramMap["clusterName"].(string)
123125
if !ok {
124-
return nil, fmt.Errorf("missing required parameter: clusterName")
126+
return nil, util.NewAgentError("missing required parameter: clusterName", nil)
125127
}
126128
if strings.Contains(name, "/") {
127-
return nil, fmt.Errorf("clusterName must be a short name without '/': %s", name)
129+
return nil, util.NewAgentError(fmt.Sprintf("clusterName must be a short name without '/': %s", name), nil)
128130
}
129131

130-
return source.GetCluster(ctx, name)
132+
res, err := source.GetCluster(ctx, name)
133+
if err != nil {
134+
return nil, util.ProcessGcpError(err)
135+
}
136+
return res, nil
131137
}
132138

133139
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {

internal/tools/dataproc/dataprocgetjob/dataprocgetjob.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ package dataprocgetjob
1717
import (
1818
"context"
1919
"fmt"
20+
"net/http"
2021
"strings"
2122

2223
"github.com/goccy/go-yaml"
2324
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
2425
"github.com/googleapis/genai-toolbox/internal/sources"
2526
"github.com/googleapis/genai-toolbox/internal/sources/dataproc"
2627
"github.com/googleapis/genai-toolbox/internal/tools"
28+
"github.com/googleapis/genai-toolbox/internal/util"
2729
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2830
)
2931

@@ -112,22 +114,26 @@ type compatibleSource interface {
112114
}
113115

114116
// Invoke executes the tool's operation.
115-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
117+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
116118
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Config.Source, t.Name, kind)
117119
if err != nil {
118-
return nil, err
120+
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
119121
}
120122

121123
paramMap := params.AsMap()
122124
jobId, ok := paramMap["jobId"].(string)
123125
if !ok {
124-
return nil, fmt.Errorf("missing required parameter: jobId")
126+
return nil, util.NewAgentError("missing required parameter: jobId", nil)
125127
}
126128
if strings.Contains(jobId, "/") {
127-
return nil, fmt.Errorf("jobId must be a short name without '/': %s", jobId)
129+
return nil, util.NewAgentError(fmt.Sprintf("jobId must be a short name without '/': %s", jobId), nil)
128130
}
129131

130-
return source.GetJob(ctx, jobId)
132+
res, err := source.GetJob(ctx, jobId)
133+
if err != nil {
134+
return nil, util.ProcessGcpError(err)
135+
}
136+
return res, nil
131137
}
132138

133139
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {

internal/tools/dataproc/dataproclistclusters/dataproclistclusters.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ package dataproclistclusters
1717
import (
1818
"context"
1919
"fmt"
20+
"net/http"
2021

2122
"github.com/goccy/go-yaml"
2223
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
2324
"github.com/googleapis/genai-toolbox/internal/sources"
2425
"github.com/googleapis/genai-toolbox/internal/sources/dataproc"
2526
"github.com/googleapis/genai-toolbox/internal/tools"
27+
"github.com/googleapis/genai-toolbox/internal/util"
2628
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2729
)
2830

@@ -113,25 +115,29 @@ type compatibleSource interface {
113115
}
114116

115117
// Invoke executes the tool's operation.
116-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
118+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
117119
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Config.Source, t.Name, kind)
118120
if err != nil {
119-
return nil, err
121+
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
120122
}
121123

122124
paramMap := params.AsMap()
123125
var pageSize *int
124126
if ps, ok := paramMap["pageSize"]; ok && ps != nil {
125127
pageSizeV := ps.(int)
126128
if pageSizeV <= 0 {
127-
return nil, fmt.Errorf("pageSize must be positive: %d", pageSizeV)
129+
return nil, util.NewAgentError(fmt.Sprintf("pageSize must be positive: %d", pageSizeV), nil)
128130
}
129131
pageSize = &pageSizeV
130132
}
131133
pt, _ := paramMap["pageToken"].(string)
132134
filter, _ := paramMap["filter"].(string)
133135

134-
return source.ListClusters(ctx, pageSize, pt, filter)
136+
res, err := source.ListClusters(ctx, pageSize, pt, filter)
137+
if err != nil {
138+
return nil, util.ProcessGcpError(err)
139+
}
140+
return res, nil
135141
}
136142

137143
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {

internal/tools/dataproc/dataproclistjobs/dataproclistjobs.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ package dataproclistjobs
1717
import (
1818
"context"
1919
"fmt"
20+
"net/http"
2021

2122
"github.com/goccy/go-yaml"
2223
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
2324
"github.com/googleapis/genai-toolbox/internal/sources"
2425
"github.com/googleapis/genai-toolbox/internal/sources/dataproc"
2526
"github.com/googleapis/genai-toolbox/internal/tools"
27+
"github.com/googleapis/genai-toolbox/internal/util"
2628
"github.com/googleapis/genai-toolbox/internal/util/parameters"
2729
)
2830

@@ -114,26 +116,30 @@ type compatibleSource interface {
114116
}
115117

116118
// Invoke executes the tool's operation.
117-
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
119+
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
118120
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Config.Source, t.Name, kind)
119121
if err != nil {
120-
return nil, err
122+
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
121123
}
122124

123125
paramMap := params.AsMap()
124126
var pageSize *int
125127
if ps, ok := paramMap["pageSize"]; ok && ps != nil {
126128
pageSizeV := ps.(int)
127129
if pageSizeV <= 0 {
128-
return nil, fmt.Errorf("pageSize must be positive: %d", pageSizeV)
130+
return nil, util.NewAgentError(fmt.Sprintf("pageSize must be positive: %d", pageSizeV), nil)
129131
}
130132
pageSize = &pageSizeV
131133
}
132134
pt, _ := paramMap["pageToken"].(string)
133135
filter, _ := paramMap["filter"].(string)
134136
matcher, _ := paramMap["jobStateMatcher"].(string)
135137

136-
return source.ListJobs(ctx, pageSize, pt, filter, matcher)
138+
res, err := source.ListJobs(ctx, pageSize, pt, filter, matcher)
139+
if err != nil {
140+
return nil, util.ProcessGcpError(err)
141+
}
142+
return res, nil
137143
}
138144

139145
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {

tests/dataproc/dataproc_integration_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,14 @@ func TestDataprocClustersToolEndpoints(t *testing.T) {
176176
name: "missing cluster",
177177
toolName: "get-cluster",
178178
request: map[string]any{"clusterName": "INVALID_CLUSTER"},
179-
wantCode: http.StatusBadRequest,
179+
wantCode: http.StatusOK,
180180
wantMsg: fmt.Sprintf("Not found: Cluster projects/%s/regions/%s/clusters/INVALID_CLUSTER", dataprocProject, dataprocRegion),
181181
},
182182
{
183183
name: "full cluster name",
184184
toolName: "get-cluster",
185185
request: map[string]any{"clusterName": missingClusterFullName},
186-
wantCode: http.StatusBadRequest,
186+
wantCode: http.StatusOK,
187187
wantMsg: fmt.Sprintf("clusterName must be a short name without '/': %s", missingClusterFullName),
188188
},
189189
}
@@ -220,14 +220,14 @@ func TestDataprocClustersToolEndpoints(t *testing.T) {
220220
name: "missing job",
221221
toolName: "get-job",
222222
request: map[string]any{"jobId": "INVALID_JOB"},
223-
wantCode: http.StatusBadRequest,
223+
wantCode: http.StatusOK,
224224
wantMsg: fmt.Sprintf("Not found: Job projects/%s/regions/%s/jobs/INVALID_JOB", dataprocProject, dataprocRegion),
225225
},
226226
{
227227
name: "full job name",
228228
toolName: "get-job",
229229
request: map[string]any{"jobId": missingJobFullName},
230-
wantCode: http.StatusBadRequest,
230+
wantCode: http.StatusOK,
231231
wantMsg: fmt.Sprintf("jobId must be a short name without '/': %s", missingJobFullName),
232232
},
233233
}
@@ -260,14 +260,14 @@ func TestDataprocClustersToolEndpoints(t *testing.T) {
260260
name: "zero page size",
261261
toolName: "list-clusters",
262262
request: map[string]any{"pageSize": 0},
263-
wantCode: http.StatusBadRequest,
263+
wantCode: http.StatusOK,
264264
wantMsg: "pageSize must be positive: 0",
265265
},
266266
{
267267
name: "negative page size",
268268
toolName: "list-clusters",
269269
request: map[string]any{"pageSize": -1},
270-
wantCode: http.StatusBadRequest,
270+
wantCode: http.StatusOK,
271271
wantMsg: "pageSize must be positive: -1",
272272
},
273273
}
@@ -301,14 +301,14 @@ func TestDataprocClustersToolEndpoints(t *testing.T) {
301301
name: "zero page size",
302302
toolName: "list-jobs",
303303
request: map[string]any{"pageSize": 0},
304-
wantCode: http.StatusBadRequest,
304+
wantCode: http.StatusOK,
305305
wantMsg: "pageSize must be positive: 0",
306306
},
307307
{
308308
name: "negative page size",
309309
toolName: "list-jobs",
310310
request: map[string]any{"pageSize": -1},
311-
wantCode: http.StatusBadRequest,
311+
wantCode: http.StatusOK,
312312
wantMsg: "pageSize must be positive: -1",
313313
},
314314
}
@@ -547,13 +547,13 @@ func runGetClusterTest(t *testing.T, client *dataproc.ClusterControllerClient, c
547547
{
548548
name: "missing clusterName",
549549
request: map[string]any{},
550-
wantCode: http.StatusBadRequest,
550+
wantCode: http.StatusOK,
551551
wantMsg: "missing required parameter: clusterName",
552552
},
553553
{
554554
name: "invalid name with slash",
555555
request: map[string]any{"clusterName": "projects/foo/regions/bar/clusters/baz"}, // Full name requires matching project/region
556-
wantCode: http.StatusBadRequest,
556+
wantCode: http.StatusOK,
557557
wantMsg: "clusterName must be a short name without '/'",
558558
},
559559
}
@@ -855,13 +855,13 @@ func runGetJobTest(t *testing.T, client *dataproc.JobControllerClient, ctx conte
855855
{
856856
name: "missing jobId",
857857
request: map[string]any{},
858-
wantCode: http.StatusBadRequest,
858+
wantCode: http.StatusOK,
859859
wantMsg: "missing required parameter: jobId",
860860
},
861861
{
862862
name: "invalid name with slash",
863863
request: map[string]any{"jobId": "projects/foo/regions/bar/jobs/baz"},
864-
wantCode: http.StatusBadRequest,
864+
wantCode: http.StatusOK,
865865
wantMsg: "jobId must be a short name without '/'",
866866
},
867867
}

0 commit comments

Comments
 (0)