forked from Azure/azure-sdk-tools
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSessionExecutor.cs
More file actions
183 lines (165 loc) · 6.9 KB
/
SessionExecutor.cs
File metadata and controls
183 lines (165 loc) · 6.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System.Diagnostics;
using GitHub.Copilot.SDK;
using Azure.Sdk.Tools.Cli.Benchmarks.Models;
namespace Azure.Sdk.Tools.Cli.Benchmarks.Infrastructure;
/// <summary>
/// Executes benchmark scenarios using the GitHub Copilot SDK.
/// </summary>
/// <remarks>
/// An instance is single-use: it creates one <see cref="CopilotClient"/> on the first call to
/// <see cref="ExecuteAsync"/> and keeps it until <see cref="Dispose"/> is called.
/// </remarks>
public class SessionExecutor : IDisposable
{
    private CopilotClient? _client;

    /// <summary>
    /// Executes a benchmark scenario with the provided configuration.
    /// </summary>
    /// <remarks>
    /// Any exception raised while running (including the single-use guard below, because it is
    /// thrown inside the <c>try</c> block) is caught and reported through the returned
    /// <see cref="ExecutionResult"/> with <c>Completed = false</c> and <c>Error</c> set to the
    /// exception message. Tool calls recorded before the failure are still included.
    /// </remarks>
    /// <param name="config">The execution configuration.</param>
    /// <returns>The result of the execution including timing and tool call information.</returns>
    public async Task<ExecutionResult> ExecuteAsync(ExecutionConfig config)
    {
        var stopwatch = Stopwatch.StartNew();
        var toolCalls = new List<ToolCallRecord>();

        // Start timestamps captured in the pre-tool hook, keyed by tool name.
        // NOTE(review): because the key is the tool name, overlapping invocations of the same
        // tool would share one slot - confirm whether the SDK can run a tool concurrently.
        var pendingTimestamps = new Dictionary<string, double>();

        try
        {
            if (_client != null)
            {
                throw new InvalidOperationException("ExecuteAsync can only be called once per SessionExecutor instance. Create a new SessionExecutor for each execution.");
            }

            _client = new CopilotClient();

            // Build MCP server config - try explicit path first, then load from workspace
            var mcpServers = BuildMcpServers(config.AzsdkMcpPath)
                ?? await McpConfigLoader.LoadFromWorkspaceAsync(config.WorkingDirectory);

            var sessionConfig = new SessionConfig
            {
                WorkingDirectory = config.WorkingDirectory,
                McpServers = mcpServers,
                Model = config.Model,
                Streaming = true,

                // Auto-approve all permission requests (file edits, creates, etc.)
                OnPermissionRequest = (request, invocation) =>
                {
                    return Task.FromResult(new PermissionRequestResult
                    {
                        Kind = "approved"
                    });
                },

                Hooks = new SessionHooks
                {
                    // Announce the tool call and remember when it started.
                    OnPreToolUse = (input, invocation) =>
                    {
                        Console.WriteLine($"Model is calling tool: {input.ToolName}");
                        config.OnActivity?.Invoke($"Calling tool: {input.ToolName}");
                        pendingTimestamps[input.ToolName] = input.Timestamp;
                        return Task.FromResult<PreToolUseHookOutput?>(null);
                    },

                    // Record exactly one structured entry per completed tool call.
                    OnPostToolUse = (input, invocation) =>
                    {
                        // Duration is only known when the pre-tool hook captured a start time.
                        // When it did not, startTs is left at its default (0).
                        double? durationMs = pendingTimestamps.TryGetValue(input.ToolName, out var startTs)
                            ? input.Timestamp - startTs
                            : null;

                        // Tool names containing "__" are split once; the part before the
                        // separator is stored as the MCP server name.
                        var mcpServerName = input.ToolName.Contains("__")
                            ? input.ToolName.Split("__", 2)[0]
                            : null;

                        // ToolArgs is kept on the record, so no separate string entry is
                        // needed for any tool (including "skill").
                        toolCalls.Add(new ToolCallRecord
                        {
                            ToolName = input.ToolName,
                            ToolArgs = input.ToolArgs,
                            ToolResult = input.ToolResult,
                            DurationMs = durationMs,
                            McpServerName = mcpServerName,
                            Timestamp = startTs,
                        });

                        return Task.FromResult<PostToolUseHookOutput?>(null);
                    }
                },

                // Auto-respond to ask_user with a simple response
                OnUserInputRequest = (request, invocation) =>
                {
                    Console.WriteLine($"Model requested user input with prompt: {request.Question}");
                    return Task.FromResult(new UserInputResponse
                    {
                        Answer = "Please proceed with your best judgment.",
                        WasFreeform = true
                    });
                }
            };

            await using var session = await _client.CreateSessionAsync(sessionConfig);

            if (config.Verbose)
            {
                SessionConfigHelper.ConfigureAgentActivityLogging(session);
            }

            // Send prompt and wait for completion
            var messageOptions = new MessageOptions { Prompt = config.Prompt };
            await session.SendAndWaitAsync(messageOptions, config.Timeout);

            // Get messages for debugging
            var messages = await session.GetMessagesAsync();

            stopwatch.Stop();

            return new ExecutionResult
            {
                Completed = true,
                Duration = stopwatch.Elapsed,
                Messages = messages.Cast<object>().ToList(),
                ToolCalls = toolCalls
            };
        }
        catch (Exception ex)
        {
            // Failures are reported to the caller as a result rather than rethrown.
            stopwatch.Stop();

            return new ExecutionResult
            {
                Completed = false,
                Error = ex.Message,
                Duration = stopwatch.Elapsed,
                ToolCalls = toolCalls
            };
        }
    }

    /// <summary>
    /// Builds MCP server configuration for the azsdk MCP server.
    /// </summary>
    /// <param name="azsdkPath">Optional path to the azsdk MCP server executable.</param>
    /// <returns>MCP server configuration dictionary, or null if no path is available.</returns>
    private static Dictionary<string, object>? BuildMcpServers(string? azsdkPath)
    {
        // Priority: config param > env var > null (let SDK use repo config)
        var path = azsdkPath ?? Environment.GetEnvironmentVariable("AZSDK_MCP_PATH");
        if (path == null)
        {
            return null;
        }

        return new Dictionary<string, object>
        {
            ["azsdk"] = new McpLocalServerConfig
            {
                Type = "local",
                Command = path,
                Args = ["mcp"],
                Tools = ["*"],
                Env = new Dictionary<string, string>
                {
                    // Environment variables passed to the MCP server process.
                    // For example: ["AZURE_SDK_KB_ENDPOINT"] = "http://localhost:8088"
                    ["AZURE_SDK_KB_ENDPOINT"] = BenchmarkDefaults.DefaultAzureKnowledgeBaseEndpoint
                }
            }
        };
    }

    /// <summary>
    /// Disposes of the Copilot client and releases resources.
    /// </summary>
    public void Dispose()
    {
        // _client is null when ExecuteAsync was never called (or failed before creating it).
        _client?.Dispose();
        GC.SuppressFinalize(this);
    }
}