Skip to content

Commit cdd6cdc

Browse files
nickmisasi and claude authored
Restrict native tools (web_search) to Direct Message contexts only (#432)
* Add WithToolsDisabled to LanguageModelConfig, add checks in relevant providers

* Fix TestThreadsAnalyze to expect WithToolsDisabled option

The test was failing because ChatCompletion now receives an additional
variadic options parameter (WithToolsDisabled). Updated the mock
expectation to accept both arguments.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 3f97fc4 commit cdd6cdc

File tree

10 files changed

+46
-23
lines changed

10 files changed

+46
-23
lines changed

anthropic/anthropic.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,11 @@ func (a *Anthropic) streamChatWithTools(state messageState) {
202202
Model: anthropicSDK.Model(state.config.Model),
203203
MaxTokens: int64(state.config.MaxGeneratedTokens),
204204
Messages: state.messages,
205-
Tools: convertTools(state.tools),
205+
}
206+
207+
// Only add tools if not explicitly disabled
208+
if !state.config.ToolsDisabled {
209+
params.Tools = convertTools(state.tools)
206210
}
207211

208212
// Only include system message if it's non-empty
@@ -213,8 +217,8 @@ func (a *Anthropic) streamChatWithTools(state messageState) {
213217
}}
214218
}
215219

216-
// Add native tools if enabled
217-
if a.isNativeToolEnabled("web_search") {
220+
// Add native tools if not explicitly disabled
221+
if !state.config.ToolsDisabled && a.isNativeToolEnabled("web_search") {
218222
// Add web search as a native tool
219223
webSearchTool := anthropicSDK.WebSearchTool20250305Param{
220224
Name: "web_search",

api/api_llm_bridge.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ func (a *API) convertRequestToLLMOptions(req bridgeclient.CompletionRequest) ([]
104104
})
105105
}
106106

107+
// Plugin bridge requests do not allow tools to be enabled
108+
options = append(options, llm.WithToolsDisabled())
107109
return options, nil
108110
}
109111

channels/channels.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func (c *Channels) Interval(
9191
Context: context,
9292
}
9393

94-
resultStream, err := c.llm.ChatCompletion(completionRequest)
94+
resultStream, err := c.llm.ChatCompletion(completionRequest, llm.WithToolsDisabled())
9595
if err != nil {
9696
return nil, err
9797
}

conversations/conversations.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ func (c *Conversations) ProcessUserRequestWithContext(bot *bots.Bot, postingUser
121121
Posts: posts,
122122
Context: context,
123123
}
124-
result, err := bot.LLM().ChatCompletion(completionRequest)
124+
isDM := mmapi.IsDMWith(bot.GetMMBot().UserId, channel)
125+
var opts []llm.LanguageModelOption
126+
if !isDM {
127+
opts = append(opts, llm.WithToolsDisabled())
128+
}
129+
result, err := bot.LLM().ChatCompletion(completionRequest, opts...)
125130
if err != nil {
126131
return nil, err
127132
}

llm/language_model.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type LanguageModelConfig struct {
3737
MaxGeneratedTokens int
3838
EnableVision bool
3939
JSONOutputFormat *jsonschema.Schema
40+
ToolsDisabled bool
4041
}
4142

4243
type LanguageModelOption func(*LanguageModelConfig)
@@ -57,4 +58,10 @@ func WithJSONOutput[T any]() LanguageModelOption {
5758
}
5859
}
5960

61+
func WithToolsDisabled() LanguageModelOption {
62+
return func(cfg *LanguageModelConfig) {
63+
cfg.ToolsDisabled = true
64+
}
65+
}
66+
6067
type LanguageModelWrapper func(LanguageModel) LanguageModel

meetings/meeting_summarization.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func (s *Service) SummarizeTranscription(bot *bots.Bot, transcription *subtitles
347347
Context: context,
348348
}
349349

350-
summaryStream, err := bot.LLM().ChatCompletion(completionRequest)
350+
summaryStream, err := bot.LLM().ChatCompletion(completionRequest, llm.WithToolsDisabled())
351351
if err != nil {
352352
return nil, fmt.Errorf("unable to get meeting summary: %w", err)
353353
}

openai/openai.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,10 @@ func NewCompatibleEmbeddings(config Config, httpClient *http.Client) *OpenAI {
145145
}
146146
}
147147

148-
func modifyCompletionRequestWithRequest(params openai.ChatCompletionNewParams, internalRequest llm.CompletionRequest) openai.ChatCompletionNewParams {
148+
func modifyCompletionRequestWithRequest(params openai.ChatCompletionNewParams, internalRequest llm.CompletionRequest, cfg llm.LanguageModelConfig) openai.ChatCompletionNewParams {
149149
params.Messages = postsToChatCompletionMessages(internalRequest.Posts)
150-
if internalRequest.Context.Tools != nil {
150+
// Only add tools if not explicitly disabled
151+
if !cfg.ToolsDisabled && internalRequest.Context.Tools != nil {
151152
params.Tools = toolsToOpenAITools(internalRequest.Context.Tools.GetTools())
152153
}
153154
return params
@@ -318,10 +319,10 @@ type ToolBufferElement struct {
318319
args strings.Builder
319320
}
320321

321-
func (s *OpenAI) streamResultToChannels(params openai.ChatCompletionNewParams, llmContext *llm.Context, output chan<- llm.TextStreamEvent) {
322+
func (s *OpenAI) streamResultToChannels(params openai.ChatCompletionNewParams, llmContext *llm.Context, cfg llm.LanguageModelConfig, output chan<- llm.TextStreamEvent) {
322323
// Route to Responses API or Completions API based on configuration
323324
if s.config.UseResponsesAPI {
324-
s.streamResponsesAPIToChannels(params, llmContext, output)
325+
s.streamResponsesAPIToChannels(params, llmContext, cfg, output)
325326
} else {
326327
s.streamCompletionsAPIToChannels(params, llmContext, output)
327328
}
@@ -484,7 +485,7 @@ func (s *OpenAI) streamCompletionsAPIToChannels(params openai.ChatCompletionNewP
484485
}
485486

486487
// streamResponsesAPIToChannels uses the new Responses API for streaming
487-
func (s *OpenAI) streamResponsesAPIToChannels(params openai.ChatCompletionNewParams, llmContext *llm.Context, output chan<- llm.TextStreamEvent) {
488+
func (s *OpenAI) streamResponsesAPIToChannels(params openai.ChatCompletionNewParams, llmContext *llm.Context, cfg llm.LanguageModelConfig, output chan<- llm.TextStreamEvent) {
488489
ctx, cancel := context.WithCancelCause(context.Background())
489490
defer cancel(nil)
490491

@@ -510,7 +511,7 @@ func (s *OpenAI) streamResponsesAPIToChannels(params openai.ChatCompletionNewPar
510511
}()
511512

512513
// Convert ChatCompletionNewParams to ResponseNewParams
513-
responseParams := s.convertToResponseParams(params, llmContext)
514+
responseParams := s.convertToResponseParams(params, llmContext, cfg)
514515

515516
// Create a streaming request
516517
stream := s.client.Responses.NewStreaming(ctx, responseParams)
@@ -843,7 +844,7 @@ func (s *OpenAI) streamResponsesAPIToChannels(params openai.ChatCompletionNewPar
843844

844845
// convertToResponseParams converts ChatCompletionNewParams to ResponseNewParams
845846
// This is a simplified conversion that handles the basic use cases
846-
func (s *OpenAI) convertToResponseParams(params openai.ChatCompletionNewParams, llmContext *llm.Context) responses.ResponseNewParams {
847+
func (s *OpenAI) convertToResponseParams(params openai.ChatCompletionNewParams, llmContext *llm.Context, cfg llm.LanguageModelConfig) responses.ResponseNewParams {
847848
result := responses.ResponseNewParams{}
848849

849850
// Convert model - directly assign as it's the same type
@@ -988,8 +989,8 @@ func (s *OpenAI) convertToResponseParams(params openai.ChatCompletionNewParams,
988989
}
989990
}
990991

991-
// Add native tools if enabled
992-
if len(s.config.EnabledNativeTools) > 0 {
992+
// Add native tools if not explicitly disabled
993+
if !cfg.ToolsDisabled && len(s.config.EnabledNativeTools) > 0 {
993994
for _, nativeTool := range s.config.EnabledNativeTools {
994995
if nativeTool == "web_search" {
995996
// Add web search as a built-in tool
@@ -1026,11 +1027,11 @@ func (s *OpenAI) convertToResponseParams(params openai.ChatCompletionNewParams,
10261027
return result
10271028
}
10281029

1029-
func (s *OpenAI) streamResult(params openai.ChatCompletionNewParams, llmContext *llm.Context) (*llm.TextStreamResult, error) {
1030+
func (s *OpenAI) streamResult(params openai.ChatCompletionNewParams, llmContext *llm.Context, cfg llm.LanguageModelConfig) (*llm.TextStreamResult, error) {
10301031
eventStream := make(chan llm.TextStreamEvent)
10311032
go func() {
10321033
defer close(eventStream)
1033-
s.streamResultToChannels(params, llmContext, eventStream)
1034+
s.streamResultToChannels(params, llmContext, cfg, eventStream)
10341035
}()
10351036

10361037
return &llm.TextStreamResult{Stream: eventStream}, nil
@@ -1101,16 +1102,17 @@ func getModelConstant(model string) shared.ChatModel {
11011102
}
11021103

11031104
func (s *OpenAI) ChatCompletion(request llm.CompletionRequest, opts ...llm.LanguageModelOption) (*llm.TextStreamResult, error) {
1104-
params := s.completionRequestFromConfig(s.createConfig(opts))
1105-
params = modifyCompletionRequestWithRequest(params, request)
1105+
cfg := s.createConfig(opts)
1106+
params := s.completionRequestFromConfig(cfg)
1107+
params = modifyCompletionRequestWithRequest(params, request, cfg)
11061108
params.StreamOptions.IncludeUsage = openai.Bool(true)
11071109

11081110
if s.config.SendUserID {
11091111
if request.Context.RequestingUser != nil {
11101112
params.User = openai.String(request.Context.RequestingUser.Id)
11111113
}
11121114
}
1113-
return s.streamResult(params, request.Context)
1115+
return s.streamResult(params, request.Context, cfg)
11141116
}
11151117

11161118
func (s *OpenAI) ChatCompletionNoStream(request llm.CompletionRequest, opts ...llm.LanguageModelOption) (string, error) {

openai/openai_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,10 @@ func TestReasoningEffortConfiguration(t *testing.T) {
811811
}
812812

813813
// Call the actual function that handles reasoning configuration
814-
result := oai.convertToResponseParams(chatParams, &llm.Context{})
814+
result := oai.convertToResponseParams(chatParams, &llm.Context{}, llm.LanguageModelConfig{
815+
Model: "gpt-4o",
816+
MaxGeneratedTokens: 8192,
817+
})
815818

816819
if !tt.shouldSetReasoning {
817820
// When reasoning is disabled, Reasoning should be empty

threads/threads.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (t *Threads) Analyze(postIDToAnalyze string, context *llm.Context, promptNa
5252
Posts: posts,
5353
Context: context,
5454
}
55-
analysisStream, err := t.llm.ChatCompletion(completionReqest)
55+
analysisStream, err := t.llm.ChatCompletion(completionReqest, llm.WithToolsDisabled())
5656
if err != nil {
5757
return nil, err
5858
}

threads/threads_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func TestThreadsAnalyze(t *testing.T) {
102102
}
103103

104104
if tc.expectedLLMCalls > 0 {
105-
mockLLM.EXPECT().ChatCompletion(mock.Anything).Return(&llm.TextStreamResult{}, tc.llmError)
105+
mockLLM.EXPECT().ChatCompletion(mock.Anything, mock.Anything).Return(&llm.TextStreamResult{}, tc.llmError)
106106
}
107107

108108
threadService := threads.New(mockLLM, prompts, mockClient)

0 commit comments

Comments (0)