Skip to content

Commit 5a8ace8

Browse files
committed
Cleanup agent based on initial PR feedback
1 parent b97e810 commit 5a8ace8

34 files changed

Lines changed: 422 additions & 1003 deletions

cli/azd/cmd/init.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,6 @@ func (i *initAction) initAppWithAgent(ctx context.Context) error {
384384
}
385385

386386
defer azdAgent.Stop()
387-
agentThoughts := azdAgent.Thoughts()
388387

389388
type initStep struct {
390389
Name string

cli/azd/internal/agent/agent.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,18 @@ type agentBase struct {
2525
callbacksHandler callbacks.Handler
2626
thoughtChan chan logging.Thought
2727
cleanupFunc AgentCleanup
28+
maxIterations int
2829
}
2930

31+
// AgentCleanup is a function that performs cleanup tasks for an agent.
3032
type AgentCleanup func() error
3133

34+
// Agent represents an AI agent that can execute tools and interact with language models.
3235
type Agent interface {
36+
// SendMessage sends a message to the agent and returns the response
3337
SendMessage(ctx context.Context, args ...string) (string, error)
38+
39+
// Stop terminates the agent and performs any necessary cleanup
3440
Stop() error
3541
}
3642

@@ -53,6 +59,13 @@ func WithDebug(debug bool) AgentCreateOption {
5359
}
5460
}
5561

62+
// WithMaxIterations returns an option that sets the maximum number of iterations for the agent
63+
func WithMaxIterations(maxIterations int) AgentCreateOption {
64+
return func(agent *agentBase) {
65+
agent.maxIterations = maxIterations
66+
}
67+
}
68+
5669
// WithDefaultModel returns an option that sets the default language model for the agent
5770
func WithDefaultModel(model llms.Model) AgentCreateOption {
5871
return func(agent *agentBase) {
@@ -74,12 +87,20 @@ func WithCallbacksHandler(handler callbacks.Handler) AgentCreateOption {
7487
}
7588
}
7689

90+
// WithThoughtChannel returns an option that sets the thought channel for the agent
7791
func WithThoughtChannel(thoughtChan chan logging.Thought) AgentCreateOption {
7892
return func(agent *agentBase) {
7993
agent.thoughtChan = thoughtChan
8094
}
8195
}
8296

97+
// WithCleanup returns an option that sets the cleanup function for the agent
98+
func WithCleanup(cleanupFunc AgentCleanup) AgentCreateOption {
99+
return func(agent *agentBase) {
100+
agent.cleanupFunc = cleanupFunc
101+
}
102+
}
103+
83104
// toolNames returns a comma-separated string of all tool names in the provided slice
84105
func toolNames(tools []common.AnnotatedTool) string {
85106
var tn strings.Builder

cli/azd/internal/agent/agent_factory.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ import (
1313
"github.com/azure/azure-dev/cli/azd/pkg/llm"
1414
)
1515

16+
// AgentFactory is responsible for creating agent instances
1617
type AgentFactory struct {
1718
consentManager consent.ConsentManager
1819
llmManager *llm.Manager
1920
console input.Console
2021
}
2122

23+
// NewAgentFactory creates a new instance of AgentFactory
2224
func NewAgentFactory(
2325
consentManager consent.ConsentManager,
2426
console input.Console,
@@ -31,13 +33,16 @@ func NewAgentFactory(
3133
}
3234
}
3335

36+
// CreateAgent creates a new agent instance
3437
func (f *AgentFactory) Create(opts ...AgentCreateOption) (Agent, error) {
38+
// Create a daily log file for all agent activity
3539
fileLogger, loggerCleanup, err := logging.NewFileLoggerDefault()
3640
if err != nil {
3741
defer loggerCleanup()
3842
return nil, err
3943
}
4044

45+
// Create a channel for logging thoughts & actions
4146
thoughtChan := make(chan logging.Thought)
4247
thoughtHandler := logging.NewThoughtLogger(thoughtChan)
4348
chainedHandler := logging.NewChainedHandler(fileLogger, thoughtHandler)
@@ -47,13 +52,16 @@ func (f *AgentFactory) Create(opts ...AgentCreateOption) (Agent, error) {
4752
return loggerCleanup()
4853
}
4954

55+
// Default model gets the chained handler to expose the UX experience for the agent
5056
defaultModelContainer, err := f.llmManager.GetDefaultModel(llm.WithLogger(chainedHandler))
5157
if err != nil {
5258
defer cleanup()
5359
return nil, err
5460
}
5561

56-
samplingModelContainer, err := f.llmManager.GetDefaultModel(llm.WithLogger(chainedHandler))
62+
// Sampling model only gets the file logger to output sampling actions
63+
// We don't need UX for sampling requests right now
64+
samplingModelContainer, err := f.llmManager.GetDefaultModel(llm.WithLogger(fileLogger))
5765
if err != nil {
5866
defer cleanup()
5967
return nil, err
@@ -66,7 +74,8 @@ func (f *AgentFactory) Create(opts ...AgentCreateOption) (Agent, error) {
6674
samplingModelContainer,
6775
)
6876

69-
toolLoaders := []localtools.ToolLoader{
77+
// Loads build-in tools & any referenced MCP servers
78+
toolLoaders := []common.ToolLoader{
7079
localtools.NewLocalToolsLoader(),
7180
mcptools.NewMcpToolsLoader(samplingHandler),
7281
}
@@ -95,22 +104,23 @@ func (f *AgentFactory) Create(opts ...AgentCreateOption) (Agent, error) {
95104
}
96105
}
97106

107+
// Wraps all tools in consent workflow
98108
protectedTools := f.consentManager.WrapTools(allTools)
99109

110+
// Finalize agent creation options
100111
allOptions := []AgentCreateOption{}
101112
allOptions = append(allOptions, opts...)
102113
allOptions = append(allOptions,
103114
WithCallbacksHandler(chainedHandler),
104115
WithThoughtChannel(thoughtChan),
105116
WithTools(protectedTools...),
117+
WithCleanup(cleanup),
106118
)
107119

108120
azdAgent, err := NewConversationalAzdAiAgent(defaultModelContainer.Model, allOptions...)
109121
if err != nil {
110122
return nil, err
111123
}
112124

113-
azdAgent.cleanupFunc = cleanup
114-
115125
return azdAgent, nil
116126
}

cli/azd/internal/agent/consent/checker.go

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ import (
1313
"github.com/mark3labs/mcp-go/mcp"
1414
)
1515

16+
var ErrToolExecutionDenied = fmt.Errorf("tool execution denied by user")
17+
var ErrSamplingDenied = fmt.Errorf("sampling denied by user")
18+
1619
// ConsentChecker provides shared consent checking logic for different tool types
1720
type ConsentChecker struct {
1821
consentMgr ConsentManager
@@ -66,6 +69,49 @@ func (cc *ConsentChecker) CheckSamplingConsent(
6669
return cc.consentMgr.CheckConsent(ctx, consentRequest)
6770
}
6871

72+
// PromptAndGrantConsent shows consent prompt and grants permission based on user choice
73+
func (cc *ConsentChecker) PromptAndGrantConsent(
74+
ctx context.Context,
75+
toolName, toolDesc string,
76+
annotations mcp.ToolAnnotation,
77+
) error {
78+
toolId := fmt.Sprintf("%s/%s", cc.serverName, toolName)
79+
80+
choice, err := cc.promptForToolConsent(ctx, toolName, toolDesc, annotations)
81+
if err != nil {
82+
return err
83+
}
84+
85+
if choice == "deny" {
86+
return ErrToolExecutionDenied
87+
}
88+
89+
// Grant consent based on user choice
90+
return cc.grantConsentFromChoice(ctx, toolId, choice, OperationTypeTool)
91+
}
92+
93+
// PromptAndGrantSamplingConsent shows sampling consent prompt and grants permission based on user choice
94+
func (cc *ConsentChecker) PromptAndGrantSamplingConsent(
95+
ctx context.Context,
96+
toolName, toolDesc string,
97+
) error {
98+
toolId := fmt.Sprintf("%s/%s", cc.serverName, toolName)
99+
100+
choice, err := cc.promptForSamplingConsent(ctx, toolName, toolDesc)
101+
if err != nil {
102+
return fmt.Errorf("sampling consent prompt failed: %w", err)
103+
}
104+
105+
if choice == "deny" {
106+
return ErrSamplingDenied
107+
}
108+
109+
// Grant sampling consent based on user choice
110+
return cc.grantConsentFromChoice(ctx, toolId, choice, OperationTypeSampling)
111+
}
112+
113+
// Private Struct Methods
114+
69115
// formatToolDescriptionWithAnnotations creates a formatted description with tool annotations as bullet points
70116
func (cc *ConsentChecker) formatToolDescriptionWithAnnotations(
71117
toolDesc string,
@@ -128,27 +174,6 @@ func (cc *ConsentChecker) formatToolDescriptionWithAnnotations(
128174
return description
129175
}
130176

131-
// PromptAndGrantConsent shows consent prompt and grants permission based on user choice
132-
func (cc *ConsentChecker) PromptAndGrantConsent(
133-
ctx context.Context,
134-
toolName, toolDesc string,
135-
annotations mcp.ToolAnnotation,
136-
) error {
137-
toolId := fmt.Sprintf("%s/%s", cc.serverName, toolName)
138-
139-
choice, err := cc.promptForToolConsent(ctx, toolName, toolDesc, annotations)
140-
if err != nil {
141-
return fmt.Errorf("consent prompt failed: %w", err)
142-
}
143-
144-
if choice == "deny" {
145-
return fmt.Errorf("tool execution denied by user")
146-
}
147-
148-
// Grant consent based on user choice
149-
return cc.grantConsentFromChoice(ctx, toolId, choice, OperationTypeTool)
150-
}
151-
152177
// promptForToolConsent shows an interactive consent prompt and returns the user's choice
153178
func (cc *ConsentChecker) promptForToolConsent(
154179
ctx context.Context,
@@ -361,26 +386,6 @@ func (cc *ConsentChecker) grantConsentFromChoice(
361386
return cc.consentMgr.GrantConsent(ctx, rule)
362387
}
363388

364-
// PromptAndGrantSamplingConsent shows sampling consent prompt and grants permission based on user choice
365-
func (cc *ConsentChecker) PromptAndGrantSamplingConsent(
366-
ctx context.Context,
367-
toolName, toolDesc string,
368-
) error {
369-
toolId := fmt.Sprintf("%s/%s", cc.serverName, toolName)
370-
371-
choice, err := cc.promptForSamplingConsent(ctx, toolName, toolDesc)
372-
if err != nil {
373-
return fmt.Errorf("sampling consent prompt failed: %w", err)
374-
}
375-
376-
if choice == "deny" {
377-
return fmt.Errorf("sampling denied by user")
378-
}
379-
380-
// Grant sampling consent based on user choice
381-
return cc.grantConsentFromChoice(ctx, toolId, choice, OperationTypeSampling)
382-
}
383-
384389
// promptForSamplingConsent shows an interactive sampling consent prompt and returns the user's choice
385390
func (cc *ConsentChecker) promptForSamplingConsent(
386391
ctx context.Context,

cli/azd/internal/agent/conversational_agent.go

Lines changed: 8 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
package agent
55

66
import (
7-
"bufio"
87
"context"
98
_ "embed"
109
"fmt"
11-
"os"
1210
"strings"
1311
"time"
1412

@@ -34,7 +32,7 @@ type ConversationalAzdAiAgent struct {
3432
// NewConversationalAzdAiAgent creates a new conversational agent with memory, tool loading,
3533
// and MCP sampling capabilities. It filters out excluded tools and configures the agent
3634
// for interactive conversations with a high iteration limit for complex tasks.
37-
func NewConversationalAzdAiAgent(llm llms.Model, opts ...AgentCreateOption) (*ConversationalAzdAiAgent, error) {
35+
func NewConversationalAzdAiAgent(llm llms.Model, opts ...AgentCreateOption) (Agent, error) {
3836
azdAgent := &ConversationalAzdAiAgent{
3937
agentBase: &agentBase{
4038
defaultModel: llm,
@@ -46,6 +44,11 @@ func NewConversationalAzdAiAgent(llm llms.Model, opts ...AgentCreateOption) (*Co
4644
opt(azdAgent.agentBase)
4745
}
4846

47+
// Default max iterations
48+
if azdAgent.maxIterations <= 0 {
49+
azdAgent.maxIterations = 100
50+
}
51+
4952
smartMemory := memory.NewConversationBuffer(
5053
memory.WithInputKey("input"),
5154
memory.WithOutputKey("output"),
@@ -74,7 +77,7 @@ func NewConversationalAzdAiAgent(llm llms.Model, opts ...AgentCreateOption) (*Co
7477

7578
// 5. Create executor without separate memory configuration since agent already has it
7679
executor := agents.NewExecutor(conversationAgent,
77-
agents.WithMaxIterations(100),
80+
agents.WithMaxIterations(azdAgent.maxIterations),
7881
agents.WithMemory(smartMemory),
7982
agents.WithCallbacksHandler(azdAgent.callbacksHandler),
8083
agents.WithReturnIntermediateSteps(),
@@ -86,62 +89,6 @@ func NewConversationalAzdAiAgent(llm llms.Model, opts ...AgentCreateOption) (*Co
8689

8790
// SendMessage processes a single message through the agent and returns the response
8891
func (aai *ConversationalAzdAiAgent) SendMessage(ctx context.Context, args ...string) (string, error) {
89-
return aai.runChain(ctx, strings.Join(args, "\n"))
90-
}
91-
92-
// StartConversation runs an interactive conversation loop with the agent.
93-
// It accepts an optional initial query and handles user input/output with proper formatting.
94-
// The conversation continues until the user types "exit" or "quit".
95-
func (aai *ConversationalAzdAiAgent) StartConversation(ctx context.Context, args ...string) (string, error) {
96-
// Handle initial query if provided
97-
var initialQuery string
98-
if len(args) > 0 {
99-
initialQuery = strings.Join(args, " ")
100-
}
101-
102-
scanner := bufio.NewScanner(os.Stdin)
103-
104-
for {
105-
var userInput string
106-
107-
if initialQuery != "" {
108-
userInput = initialQuery
109-
initialQuery = "" // Clear after first use
110-
color.Cyan("💬 You: %s\n", userInput)
111-
} else {
112-
fmt.Print(color.CyanString("\n💬 You: "))
113-
color.Set(color.FgCyan) // Set blue color for user input
114-
if !scanner.Scan() {
115-
color.Unset() // Reset color
116-
break // EOF or error
117-
}
118-
userInput = strings.TrimSpace(scanner.Text())
119-
color.Unset() // Reset color after input
120-
}
121-
122-
// Check for exit commands
123-
if userInput == "" {
124-
continue
125-
}
126-
127-
if strings.ToLower(userInput) == "exit" || strings.ToLower(userInput) == "quit" {
128-
fmt.Println("👋 Goodbye! Thanks for using azd Agent!")
129-
break
130-
}
131-
132-
// Process the query with the enhanced agent
133-
return aai.runChain(ctx, userInput)
134-
}
135-
136-
if err := scanner.Err(); err != nil {
137-
return "", fmt.Errorf("error reading input: %w", err)
138-
}
139-
140-
return "", nil
141-
}
142-
143-
// runChain executes a user query through the agent's chain with memory and returns the response
144-
func (aai *ConversationalAzdAiAgent) runChain(ctx context.Context, userInput string) (string, error) {
14592
thoughtsCtx, cancelCtx := context.WithCancel(ctx)
14693
cleanup, err := aai.renderThoughts(thoughtsCtx)
14794
if err != nil {
@@ -154,8 +101,7 @@ func (aai *ConversationalAzdAiAgent) runChain(ctx context.Context, userInput str
154101
cancelCtx()
155102
}()
156103

157-
// Execute with enhanced input - agent should automatically handle memory
158-
output, err := chains.Run(ctx, aai.executor, userInput)
104+
output, err := chains.Run(ctx, aai.executor, strings.Join(args, "\n"))
159105
if err != nil {
160106
return "", err
161107
}

0 commit comments

Comments
 (0)