-
-
Notifications
You must be signed in to change notification settings - Fork 153
Expand file tree
/
Copy pathinit.go
More file actions
321 lines (277 loc) · 11.1 KB
/
init.go
File metadata and controls
321 lines (277 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
package ai
import (
"context"
"fmt"
"sort"
"strings"
errUtils "github.com/cloudposse/atmos/errors"
"github.com/cloudposse/atmos/pkg/ai"
"github.com/cloudposse/atmos/pkg/ai/tools"
atmosTools "github.com/cloudposse/atmos/pkg/ai/tools/atmos"
"github.com/cloudposse/atmos/pkg/ai/tools/permission"
"github.com/cloudposse/atmos/pkg/auth"
cfg "github.com/cloudposse/atmos/pkg/config"
"github.com/cloudposse/atmos/pkg/dependencies"
log "github.com/cloudposse/atmos/pkg/logger"
mcpclient "github.com/cloudposse/atmos/pkg/mcp/client"
"github.com/cloudposse/atmos/pkg/mcp/router"
"github.com/cloudposse/atmos/pkg/schema"
"github.com/cloudposse/atmos/pkg/ui"
)
// aiToolsResult bundles what initializeAIToolsAndExecutor produces for an AI command.
type aiToolsResult struct {
	// Registry contains the built-in Atmos tools plus any tools registered from
	// external MCP servers.
	Registry *tools.Registry
	// Executor runs tools from Registry behind the configured permission checker.
	Executor *tools.Executor
	// MCPMgr is the manager returned by MCP tool registration. It is nil when MCP
	// registration was skipped (CLI provider) or no servers were selected.
	// NOTE(review): confirm which caller is responsible for stopping it.
	MCPMgr *mcpclient.Manager
}
// initializeAIToolsAndExecutor assembles everything an AI command needs to call tools.
//
// The registry receives the built-in Atmos tools and, for non-CLI providers, the
// tools of the selected external MCP servers. The executor wraps that registry
// behind a permission checker built from ai.tools settings.
//
// mcpServerNames restricts which MCP servers are started. When it is empty or nil,
// question is used for automatic routing (or all servers start).
//
// It returns errUtils.ErrAIToolsDisabled when ai.tools.enabled is false. A failure
// to register the Atmos tools or to open the permission cache is logged as a
// warning and does not abort initialization.
func initializeAIToolsAndExecutor(atmosConfig *schema.AtmosConfiguration, mcpServerNames []string, question string) (*aiToolsResult, error) {
	if !atmosConfig.AI.Tools.Enabled {
		return nil, errUtils.ErrAIToolsDisabled
	}

	log.Debug("Initializing AI tools")
	result := &aiToolsResult{Registry: tools.NewRegistry()}

	// No LSP manager exists in the command layer, so nil is passed.
	if regErr := atmosTools.RegisterTools(result.Registry, atmosConfig, nil); regErr != nil {
		log.Warnf("Failed to register Atmos tools: %v", regErr)
	}

	// CLI providers hand MCP configuration to their own binary; only the other
	// providers need Atmos to start MCP servers and register their tools.
	if !isCLIProvider(atmosConfig.AI.DefaultProvider) {
		result.MCPMgr = registerMCPServerTools(result.Registry, atmosConfig, mcpServerNames, question)
	}

	ui.Info(fmt.Sprintf("AI tools initialized: %d total", result.Registry.Count()))
	ui.Info(fmt.Sprintf("AI provider: %s", atmosConfig.AI.DefaultProvider))

	// The cache persists permission decisions. Without it, decisions are not
	// remembered and the user is prompted every time.
	permCache, cacheErr := permission.NewPermissionCache(atmosConfig.BasePath)
	if cacheErr != nil {
		log.Warnf("Failed to initialize permission cache: %v", cacheErr)
		permCache = nil
	}

	rules := &permission.Config{
		Mode:            getPermissionMode(atmosConfig),
		AllowedTools:    atmosConfig.AI.Tools.AllowedTools,
		RestrictedTools: atmosConfig.AI.Tools.RestrictedTools,
		BlockedTools:    atmosConfig.AI.Tools.BlockedTools,
		YOLOMode:        atmosConfig.AI.Tools.YOLOMode,
	}

	var prompter permission.Prompter
	if permCache == nil {
		prompter = permission.NewCLIPrompter()
	} else {
		prompter = permission.NewCLIPrompterWithCache(permCache)
	}

	result.Executor = tools.NewExecutor(result.Registry, permission.NewChecker(rules, prompter), tools.DefaultTimeout)
	log.Debug("Tool executor initialized")

	return result, nil
}
// registerMCPServerTools starts the selected external MCP servers and registers
// their tools in registry.
//
// Servers are picked by selectMCPServers (the --mcp override, AI routing, or all).
// A toolchain resolver from resolveToolchain and an auth provider from
// resolveAuthProvider are passed to mcpclient.RegisterMCPTools.
//
// It returns nil when no servers are configured or none were selected. If
// RegisterMCPTools reports an error, the error is shown to the user and the
// manager it returned (possibly nil) is still passed back to the caller.
func registerMCPServerTools(registry *tools.Registry, atmosConfig *schema.AtmosConfiguration, mcpServerNames []string, question string) *mcpclient.Manager {
	if len(atmosConfig.MCP.Servers) == 0 {
		return nil
	}

	chosen := selectMCPServers(atmosConfig, mcpServerNames, question)
	if len(chosen) == 0 {
		return nil
	}

	// Expose only the chosen servers through a copy of the configuration.
	// NOTE(review): this is a shallow copy; it leaves atmosConfig's server map
	// alone only if MCP is a value (not pointer) field — confirm in schema.
	scoped := *atmosConfig
	scoped.MCP.Servers = chosen

	resolver := resolveToolchain(atmosConfig)
	authEnv := resolveAuthProvider(&scoped)

	manager, regErr := mcpclient.RegisterMCPTools(registry, &scoped, authEnv, resolver)
	if regErr != nil {
		ui.Error(fmt.Sprintf("Failed to initialize MCP servers: %v", regErr))
	}
	return manager
}
// selectMCPServers decides which configured MCP servers should be started.
//
// Precedence:
//  1. An explicit --mcp list (mcpServerNames) always wins.
//  2. With more than one server, routing enabled, and a non-empty question, the
//     user's configured AI provider picks the relevant servers.
//  3. Otherwise every configured server is returned.
func selectMCPServers(atmosConfig *schema.AtmosConfiguration, mcpServerNames []string, question string) map[string]schema.MCPServerConfig {
	all := atmosConfig.MCP.Servers

	switch {
	case len(mcpServerNames) > 0:
		return selectManualServers(all, mcpServerNames)
	case len(all) <= 1:
		// Nothing to choose between.
		return all
	case !atmosConfig.MCP.Routing.IsEnabled():
		return all
	case question == "":
		// No prompt to route on (e.g. chat mode).
		return all
	default:
		return selectRoutedServers(atmosConfig, all, question)
	}
}
// selectManualServers returns the subset of servers named via the --mcp flag.
//
// Each requested name that is not configured produces a warning listing the
// available servers. When at least one server matched, the final selection is
// announced.
func selectManualServers(servers map[string]schema.MCPServerConfig, mcpServerNames []string) map[string]schema.MCPServerConfig {
	picked := filterServersByName(servers, mcpServerNames)

	for _, requested := range mcpServerNames {
		if _, known := servers[requested]; known {
			continue
		}
		available := strings.Join(sortedServerNames(servers), ", ")
		ui.Warning(fmt.Sprintf("MCP server `%s` not found in configuration (available: %s)", requested, available))
	}

	if len(picked) == 0 {
		return picked
	}
	ui.Info(fmt.Sprintf("MCP servers selected via --mcp flag: %s", strings.Join(sortedServerNames(picked), ", ")))
	return picked
}
// selectRoutedServers asks the configured AI provider which of servers are relevant
// to question and returns that subset.
//
// It falls back to all servers in two cases:
//   - routing produced no names (for example, the routing client could not be created);
//   - none of the returned names match a configured server.
//
// Returned names that are not configured are ignored, and a warning reports how
// many there were.
func selectRoutedServers(atmosConfig *schema.AtmosConfiguration, servers map[string]schema.MCPServerConfig, question string) map[string]schema.MCPServerConfig {
	selected := routeWithAI(atmosConfig, question)
	if len(selected) == 0 {
		return servers
	}

	filtered := filterServersByName(servers, selected)
	if len(filtered) == 0 {
		ui.Warning("MCP routing returned no valid server names, starting all servers")
		return servers
	}

	// Count unknown names directly. Comparing len(selected) with len(filtered)
	// would misreport a repeated valid name (e.g. ["aws", "aws"]) as unknown,
	// because duplicates collapse into a single map entry.
	unknown := 0
	for _, name := range selected {
		if _, ok := servers[name]; !ok {
			unknown++
		}
	}
	if unknown > 0 {
		ui.Warning(fmt.Sprintf("MCP routing returned %d unknown server name(s), using %d valid",
			unknown, len(filtered)))
	}

	ui.Info(fmt.Sprintf("MCP routing selected %d of %d servers: %s",
		len(filtered), len(servers), strings.Join(sortedServerNames(filtered), ", ")))
	return filtered
}
// routeWithAI asks the AI client from createRoutingClient (the user's configured
// provider and model, with a reduced max_tokens) which MCP servers are relevant
// to question. It returns the names reported by router.Route.
//
// It returns nil when the routing client cannot be created. The caller treats an
// empty result as "start all servers". The routing request is bounded by
// router.DefaultTimeout.
func routeWithAI(atmosConfig *schema.AtmosConfiguration, question string) []string {
	client, err := createRoutingClient(atmosConfig)
	if err != nil {
		log.Debug("Failed to create routing client, starting all servers", "error", err)
		return nil
	}

	// Describe servers in sorted order so the routing prompt is deterministic.
	// (The loop variable is not named cfg, to avoid shadowing the config package alias.)
	names := sortedServerNames(atmosConfig.MCP.Servers)
	serverInfos := make([]router.ServerInfo, 0, len(names))
	for _, name := range names {
		serverCfg := atmosConfig.MCP.Servers[name]
		serverInfos = append(serverInfos, router.ServerInfo{
			Name:        name,
			Description: serverCfg.Description,
		})
	}

	ctx, cancel := context.WithTimeout(context.Background(), router.DefaultTimeout)
	defer cancel()

	return router.Route(ctx, client, question, serverInfos)
}
// createRoutingClient creates the AI client used for the MCP routing step.
//
// It reuses the user's configured default provider and model; no separate routing
// model is configured. The only change is that the provider's max_tokens is set to
// router.DefaultMaxTokens(), because the routing response is just a JSON array of
// server names.
//
// The override is applied to copies, so atmosConfig, its provider map, and its
// provider structs are never mutated. If the provider has no (non-nil) entry in
// ai.providers, no override is applied.
func createRoutingClient(atmosConfig *schema.AtmosConfiguration) (router.MessageSender, error) {
	provider := atmosConfig.AI.DefaultProvider
	if provider == "" {
		// NOTE(review): assumed to match the default provider ai.NewClient uses — confirm.
		provider = "anthropic"
	}

	routingConfig := *atmosConfig

	// Reading a nil map is safe and yields nil.
	if existing := atmosConfig.AI.Providers[provider]; existing != nil {
		// Clone the map shallowly so every other entry, including nil ones, matches
		// what the main client sees. Then swap in a modified copy of the routing
		// provider's config with max_tokens overridden.
		providers := make(map[string]*schema.AIProviderConfig, len(atmosConfig.AI.Providers))
		for name, providerCfg := range atmosConfig.AI.Providers {
			providers[name] = providerCfg
		}
		override := *existing
		override.MaxTokens = router.DefaultMaxTokens()
		providers[provider] = &override
		routingConfig.AI.Providers = providers
	}

	return ai.NewClient(&routingConfig)
}
// sortedServerNames lists the keys of servers in ascending lexical order. This gives
// callers a stable iteration order over the unordered map. The result is never nil;
// it is empty when servers is empty.
func sortedServerNames(servers map[string]schema.MCPServerConfig) []string {
	keys := make([]string, len(servers))
	i := 0
	for key := range servers {
		keys[i] = key
		i++
	}
	sort.Sort(sort.StringSlice(keys))
	return keys
}
// filterServersByName returns the entries of servers whose keys appear in names.
//
// Names that are not present in servers are skipped, and duplicate names collapse
// into a single entry. The returned map is always non-nil, and the input map is
// not modified.
func filterServersByName(servers map[string]schema.MCPServerConfig, names []string) map[string]schema.MCPServerConfig {
	filtered := make(map[string]schema.MCPServerConfig, len(names))
	for _, name := range names {
		// Named serverCfg (not cfg) to avoid shadowing the config package alias.
		if serverCfg, ok := servers[name]; ok {
			filtered[name] = serverCfg
		}
	}
	return filtered
}
// resolveToolchain returns a resolver that lets MCP server commands (such as uvx or
// npx) come from the Atmos toolchain.
//
// Resolution order:
//  1. Dependencies listed in .tool-versions (dependencies.LoadToolVersionsDependencies).
//  2. dependencies.ForComponent for the "terraform" component.
//
// It returns nil when neither source yields an environment, in which case MCP
// servers use the system PATH. Every failure is logged only at debug level, because
// falling back to PATH is acceptable.
func resolveToolchain(atmosConfig *schema.AtmosConfiguration) mcpclient.ToolchainResolver {
	deps, depsErr := dependencies.LoadToolVersionsDependencies(atmosConfig)
	switch {
	case depsErr != nil:
		// Previously discarded silently; keep a trace for troubleshooting.
		log.Debug("Failed to load .tool-versions dependencies", "error", depsErr)
	case len(deps) > 0:
		tenv, tenvErr := dependencies.NewEnvironmentFromDeps(atmosConfig, deps)
		if tenvErr == nil && tenv != nil {
			return tenv
		}
		log.Debug("Failed to create environment from .tool-versions deps", "error", tenvErr)
	}

	// Fall back to component-based resolution.
	tenv, tenvErr := dependencies.ForComponent(atmosConfig, "terraform", nil, nil)
	if tenvErr == nil && tenv != nil {
		return tenv
	}

	log.Debug("Toolchain resolution failed, MCP servers will use system PATH", "error", tenvErr)
	return nil
}
// resolveAuthProvider returns the auth manager that MCP registration uses for
// credential injection, or nil.
//
// A manager is created and authenticated only when at least one server in
// atmosConfig.MCP.Servers sets an identity. The call passes an empty identity name
// together with cfg.IdentityFlagSelectValue. If creation fails, the error is shown
// to the user and nil is returned.
func resolveAuthProvider(atmosConfig *schema.AtmosConfiguration) mcpclient.AuthEnvProvider {
	if !serversNeedAuth(atmosConfig.MCP.Servers) {
		return nil
	}

	authCfg := &atmosConfig.Auth
	manager, createErr := auth.CreateAndAuthenticateManagerWithAtmosConfig("", authCfg, cfg.IdentityFlagSelectValue, atmosConfig)
	if createErr == nil {
		return manager
	}

	ui.Error(fmt.Sprintf("Failed to create auth manager for MCP servers: %v", createErr))
	return nil
}
// cliProviders is the set of provider names that run a local CLI binary as a
// subprocess. Those providers pass MCP configuration through to the CLI
// themselves, instead of using the Atmos tool registry.
var cliProviders = map[string]bool{
	"claude-code": true,
	"codex-cli":   true,
	"gemini-cli":  true,
}

// isCLIProvider reports whether providerName is marked true in cliProviders.
// The match is exact and case-sensitive; unknown or empty names report false.
func isCLIProvider(providerName string) bool {
	isCLI, known := cliProviders[providerName]
	return known && isCLI
}
// serversNeedAuth reports whether any server in servers declares a non-empty
// identity, meaning credentials must be resolved before the servers start.
// An empty or nil map reports false.
func serversNeedAuth(servers map[string]schema.MCPServerConfig) bool {
	needed := false
	for _, server := range servers {
		if server.Identity != "" {
			needed = true
			break
		}
	}
	return needed
}