Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tools/azsdk-cli/Azure.Sdk.Tools.Cli.Contract/MCPTool.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.CommandLine;
using System.CommandLine.Invocation;

Expand All @@ -23,7 +25,7 @@ public void SetFailure(int exitCode = 1)
ExitCode = exitCode;
}

public CommandGroup[] CommandHierarchy { get; set; } = Array.Empty<CommandGroup>();
public CommandGroup[] CommandHierarchy { get; set; } = [];

public abstract Command GetCommand();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
<PackageReference Include="Azure.Identity" Version="1.14.2" />
<PackageReference Include="LibGit2Sharp" Version="0.31.0" />
<PackageReference Include="Octokit" Version="14.0.0" />
<PackageReference Include="ModelContextProtocol" Version="0.1.0-preview.11" />
<PackageReference Include="ModelContextProtocol" Version="0.3.0-preview.4" />
<PackageReference Include="ModelContextProtocol.AspNetCore" Version="0.1.0-preview.11" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="9.0.3" />
<PackageReference Include="Microsoft.Extensions.Hosting" Version="9.0.3" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ public static class SharedOptions
IsRequired = false,
};

public static Option<string> PackagePath = new(["--package-path", "-p"], "Path to the package directory to check")
{
IsRequired = true
public static Option<string> PackagePath = new(["--package-path", "-p"], "Path to the package directory to check")
{
IsRequired = true
};

public static (string, bool) GetGlobalOptionValues(string[] args)
Expand Down
10 changes: 5 additions & 5 deletions tools/azsdk-cli/Azure.Sdk.Tools.Cli/Program.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.CommandLine;
using System.CommandLine.Builder;
using System.CommandLine.Parsing;
Expand Down Expand Up @@ -78,19 +80,17 @@ public static WebApplicationBuilder CreateAppBuilder(string[] args)

// register common services
ServiceRegistrations.RegisterCommonServices(builder.Services);

// register MCP tools
ServiceRegistrations.RegisterInstrumentedMcpTools(builder.Services, args);

builder.WebHost.ConfigureKestrel(options =>
{
options.Listen(System.Net.IPAddress.Loopback, 0); // 0 = dynamic port
});

var toolTypes = SharedOptions.GetFilteredToolTypes(args);

builder.Services
.AddMcpServer()
.WithStdioServerTransport()
.WithTools(toolTypes);
.WithStdioServerTransport();

return builder;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.Reflection;
using System.Text.Json;
using Microsoft.Extensions.Azure;
using ModelContextProtocol.Server;
using Azure.AI.OpenAI;
using Azure.Sdk.Tools.Cli.Helpers;
using Azure.Sdk.Tools.Cli.Commands;
using Azure.Sdk.Tools.Cli.Microagents;
using Microsoft.Extensions.Azure;
using Azure.Sdk.Tools.Cli.Helpers;
using Azure.Sdk.Tools.Cli.Tools;

namespace Azure.Sdk.Tools.Cli.Services
{
public static class ServiceRegistrations
{ /// <summary>
/// This is the function that defines all of the services available to any of the MCPTool instantiations. This
/// same collection modification is run within the HostServerTool::CreateAppBuilder.
/// </summary>
/// <param name="services"></param>
/// todo: make this use reflection to populate itself with all of our services and helpers
{
/// <summary>
/// This is the function that defines all of the services available to any of the MCPTool instantiations. This
/// same collection modification is run within the HostServerTool::CreateAppBuilder.
/// </summary>
/// <param name="services"></param>
/// todo: make this use reflection to populate itself with all of our services and helpers
public static void RegisterCommonServices(IServiceCollection services)
{
// Services
Expand Down Expand Up @@ -70,5 +76,37 @@ public static void RegisterCommonServices(IServiceCollection services)
});
});
}

public static void RegisterInstrumentedMcpTools(IServiceCollection services, string[] args)
{
JsonSerializerOptions? serializerOptions = null;
var toolTypes = SharedOptions.GetFilteredToolTypes(args);

foreach (var toolType in toolTypes)
{
if (toolType is null)
{
continue;
}

foreach (var toolMethod in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance))
{
if (toolMethod.GetCustomAttribute<McpServerToolAttribute>() is not null)
{
services.AddSingleton((Func<IServiceProvider, McpServerTool>)(services =>
{
var options = new McpServerToolCreateOptions { Services = services, SerializerOptions = serializerOptions };
var innerTool = toolMethod.IsStatic
? McpServerTool.Create(toolMethod, options: options)
: McpServerTool.Create(toolMethod, r => ActivatorUtilities.CreateInstance(r.Services, toolType), options);

var loggerFactory = services.GetRequiredService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger(toolType);
return new InstrumentedTool(logger, innerTool, toolMethod.Name);
}));
}
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.Text.Json;

namespace Azure.Sdk.Tools.Cli.Services;

public static class TelemetryService
{
private const int MaxInstrumentationUploadTime = 5;

public static void InstrumentationBefore(ILogger logger, string toolName, object? args, CancellationToken ct)
{
using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(MaxInstrumentationUploadTime));
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, timeoutCts.Token);
Task.Run(() => _instrumentationBefore(logger, toolName, args), linkedCts.Token);
}

private static void _instrumentationBefore(ILogger logger, string toolName, object? args)
{
// TODO: replace with app insights
logger.LogDebug("[tool req] {toolName} [args] {args}", toolName, args);
}

public static void InstrumentationAfter(ILogger logger, string toolName, object? result, CancellationToken ct)
{
using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(MaxInstrumentationUploadTime));
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, timeoutCts.Token);
Task.Run(() => _instrumentationAfter(logger, toolName, result), linkedCts.Token);
}

private static void _instrumentationAfter(ILogger logger, string toolName, object? result)
{
var serialized = "SERIALIZER ERROR";
try
{
serialized = JsonSerializer.Serialize(result);
}
catch (Exception ex)
{
logger.LogError(ex, "Error serializing tool response for instrumentation");
}

// TODO: replace with app insights
logger.LogDebug("[tool resp] {toolName} [result] {result}", toolName, serialized);
}

public static void InstrumentationError(ILogger logger, string toolName, Exception ex, CancellationToken ct)
{
using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(MaxInstrumentationUploadTime));
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, timeoutCts.Token);
// TODO: replace with app insights
Task.Run(() => _instrumentationError(logger, toolName, ex), linkedCts.Token);
}

private static void _instrumentationError(ILogger logger, string toolName, Exception ex)
{
logger.LogError(ex, "[tool error] {toolName} [error] {error}", toolName, ex.Message);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using Azure.Sdk.Tools.Cli.Services;

namespace Azure.Sdk.Tools.Cli.Tools;

public class InstrumentedTool(ILogger logger, McpServerTool innerTool, string toolName) : DelegatingMcpServerTool(innerTool)
Comment thread
benbp marked this conversation as resolved.
{
public override async ValueTask<CallToolResult> InvokeAsync(RequestContext<CallToolRequestParams> request, CancellationToken ct = default)
{
try
{
TelemetryService.InstrumentationBefore(logger, toolName, request.Params?.Arguments, ct);
var result = await base.InvokeAsync(request, ct);
TelemetryService.InstrumentationAfter(logger, toolName, result, ct);
return result;
}
catch (Exception ex)
{
TelemetryService.InstrumentationError(logger, toolName, ex, ct);
throw;
}
}
}
Loading