diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Database/DatabaseQueryCommand.cs b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Database/DatabaseQueryCommand.cs index b40fe44dab..e529a8d844 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Database/DatabaseQueryCommand.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Database/DatabaseQueryCommand.cs @@ -12,8 +12,9 @@ namespace Azure.Mcp.Tools.Postgres.Commands.Database; -public sealed class DatabaseQueryCommand(ILogger logger) : BaseDatabaseCommand(logger) +public sealed class DatabaseQueryCommand(IPostgresService postgresService, ILogger logger) : BaseDatabaseCommand(logger) { + private readonly IPostgresService _postgresService = postgresService; private const string CommandTitle = "Query PostgreSQL Database"; public override string Id => "81a28bca-014c-4738-9e1a-654d77cb2dd8"; @@ -58,10 +59,9 @@ public override async Task ExecuteAsync(CommandContext context, try { - IPostgresService pgService = context.GetService() ?? throw new InvalidOperationException("PostgreSQL service is not available."); // Validate the query early to avoid sending unsafe SQL to the server. SqlQueryValidator.EnsureReadOnlySelect(options.Query); - List queryResult = await pgService.ExecuteQueryAsync( + List queryResult = await _postgresService.ExecuteQueryAsync( options.Subscription!, options.ResourceGroup!, options.AuthType!, diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Commands/PostgresListCommand.cs b/tools/Azure.Mcp.Tools.Postgres/src/Commands/PostgresListCommand.cs index 7f582443f1..670763bde4 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Commands/PostgresListCommand.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Commands/PostgresListCommand.cs @@ -13,8 +13,9 @@ namespace Azure.Mcp.Tools.Postgres.Commands; -public sealed class PostgresListCommand(ILogger logger) : BasePostgresCommand(logger) +public sealed class PostgresListCommand(IPostgresService postgresService, ILogger logger) : BasePostgresCommand(logger) { + private readonly IPostgresService _postgresService = postgresService; public override string Id => "8a12c3f4-2e5d-4b3a-9f2c-5e6d7f8a9b0c"; public override string Name => "list"; @@ -73,13 +74,11 @@ public override async Task ExecuteAsync(CommandContext context, var options = BindOptions(parseResult); - IPostgresService postgresService = context.GetService() ?? throw new InvalidOperationException("PostgreSQL service is not available."); - // Route based on provided parameters if (!string.IsNullOrEmpty(options.Database)) { // List tables in specified database - List tables = await postgresService.ListTablesAsync( + List tables = await _postgresService.ListTablesAsync( options.Subscription!, options.ResourceGroup!, options.AuthType!, @@ -96,7 +95,7 @@ public override async Task ExecuteAsync(CommandContext context, else if (!string.IsNullOrEmpty(options.Server)) { // List databases on specified server - List databases = await postgresService.ListDatabasesAsync( + List databases = await _postgresService.ListDatabasesAsync( options.Subscription!, options.ResourceGroup!, options.AuthType!, @@ -112,7 +111,7 @@ public override async Task ExecuteAsync(CommandContext context, else { // List all servers in the subscription (optionally scoped to a resource group) - List servers = await postgresService.ListServersAsync( + List servers = await _postgresService.ListServersAsync( options.Subscription!, options.ResourceGroup, cancellationToken); diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerConfigGetCommand.cs b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerConfigGetCommand.cs index bf918ba74a..a0b6de4110 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerConfigGetCommand.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerConfigGetCommand.cs @@ -9,8 +9,9 @@ namespace Azure.Mcp.Tools.Postgres.Commands.Server; -public sealed class ServerConfigGetCommand(ILogger logger) : BaseServerCommand(logger) +public sealed class ServerConfigGetCommand(IPostgresService postgresService, ILogger logger) : BaseServerCommand(logger) { + private readonly IPostgresService _postgresService = postgresService; private const string CommandTitle = "Get PostgreSQL Server Configuration"; public override string Id => "049a0d10-0a6e-4278-a0a3-15ce6b2e5ee1"; @@ -44,8 +45,7 @@ public override async Task ExecuteAsync(CommandContext context, try { - IPostgresService pgService = context.GetService() ?? throw new InvalidOperationException("PostgreSQL service is not available."); - var config = await pgService.GetServerConfigAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, cancellationToken); + var config = await _postgresService.GetServerConfigAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, cancellationToken); context.Response.Results = config?.Length > 0 ? ResponseResult.Create(new(config), PostgresJsonContext.Default.ServerConfigGetCommandResult) : null; diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamGetCommand.cs b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamGetCommand.cs index e78526da35..51df3dcff0 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamGetCommand.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamGetCommand.cs @@ -11,8 +11,9 @@ namespace Azure.Mcp.Tools.Postgres.Commands.Server; -public sealed class ServerParamGetCommand(ILogger logger) : BaseServerCommand(logger) +public sealed class ServerParamGetCommand(IPostgresService postgresService, ILogger logger) : BaseServerCommand(logger) { + private readonly IPostgresService _postgresService = postgresService; private const string CommandTitle = "Get PostgreSQL Server Parameter"; public override string Id => "af3a581d-ab64-4939-9765-974815d9c7be"; @@ -58,8 +59,7 @@ public override async Task ExecuteAsync(CommandContext context, try { - IPostgresService pgService = context.GetService() ?? throw new InvalidOperationException("PostgreSQL service is not available."); - var parameterValue = await pgService.GetServerParameterAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, options.Param!, cancellationToken); + var parameterValue = await _postgresService.GetServerParameterAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, options.Param!, cancellationToken); context.Response.Results = parameterValue?.Length > 0 ? ResponseResult.Create(new(parameterValue), PostgresJsonContext.Default.ServerParamGetCommandResult) : null; diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamSetCommand.cs b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamSetCommand.cs index 48f6201986..322f7f8b44 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamSetCommand.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Server/ServerParamSetCommand.cs @@ -11,8 +11,9 @@ namespace Azure.Mcp.Tools.Postgres.Commands.Server; -public sealed class ServerParamSetCommand(ILogger logger) : BaseServerCommand(logger) +public sealed class ServerParamSetCommand(IPostgresService postgresService, ILogger logger) : BaseServerCommand(logger) { + private readonly IPostgresService _postgresService = postgresService; private const string CommandTitle = "Set PostgreSQL Server Parameter"; public override string Id => "2134621b-518f-48ac-a66a-82c40fcb58bb"; @@ -60,8 +61,7 @@ public override async Task ExecuteAsync(CommandContext context, try { - IPostgresService pgService = context.GetService() ?? throw new InvalidOperationException("PostgreSQL service is not available."); - var result = await pgService.SetServerParameterAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, options.Param!, options.Value!, cancellationToken); + var result = await _postgresService.SetServerParameterAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, options.Param!, options.Value!, cancellationToken); context.Response.Results = !string.IsNullOrEmpty(result) ? ResponseResult.Create(new(result, options.Param!, options.Value!), PostgresJsonContext.Default.ServerParamSetCommandResult) : null; diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Table/TableSchemaGetCommand.cs b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Table/TableSchemaGetCommand.cs index df6785fea8..63af53a279 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Commands/Table/TableSchemaGetCommand.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Commands/Table/TableSchemaGetCommand.cs @@ -11,8 +11,9 @@ namespace Azure.Mcp.Tools.Postgres.Commands.Table; -public sealed class TableSchemaGetCommand(ILogger logger) : BaseDatabaseCommand(logger) +public sealed class TableSchemaGetCommand(IPostgresService postgresService, ILogger logger) : BaseDatabaseCommand(logger) { + private readonly IPostgresService _postgresService = postgresService; private const string CommandTitle = "Get PostgreSQL Table Schema"; public override string Id => "643a3497-44e1-4727-b3d6-c2e5dba6cab2"; @@ -56,8 +57,7 @@ public override async Task ExecuteAsync(CommandContext context, { - IPostgresService pgService = context.GetService() ?? throw new InvalidOperationException("PostgreSQL service is not available."); - List schema = await pgService.GetTableSchemaAsync( + List schema = await _postgresService.GetTableSchemaAsync( options.Subscription!, options.ResourceGroup!, options.AuthType!, diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Database/DatabaseQueryCommandTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Database/DatabaseQueryCommandTests.cs index 64dab121de..00baf93c4a 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Database/DatabaseQueryCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Database/DatabaseQueryCommandTests.cs @@ -45,7 +45,7 @@ public async Task ExecuteAsync_ReturnsQueryResults_WhenQueryIsValid() _postgresService.ExecuteQueryAsync("sub123", "rg1", AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "SELECT * FROM test;", Arg.Any()) .Returns(expectedResults); - var command = new DatabaseQueryCommand(_logger); + var command = new DatabaseQueryCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", $"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra, "--user", "user1", "--server", "server1", "--database", "db123", "--query", "SELECT * FROM test;"]); var context = new CommandContext(_serviceProvider); var response = await command.ExecuteAsync(context, args, TestContext.Current.CancellationToken); @@ -66,7 +66,7 @@ public async Task ExecuteAsync_ReturnsEmpty_WhenQueryFails() _postgresService.ExecuteQueryAsync("sub123", "rg1", AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "SELECT * FROM test;", Arg.Any()) .Returns([]); - var command = new DatabaseQueryCommand(_logger); + var command = new DatabaseQueryCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", $"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra, "--user", "user1", "--server", "server1", "--database", "db123", "--query", "SELECT * FROM test;"]); var context = new CommandContext(_serviceProvider); @@ -91,7 +91,7 @@ public async Task ExecuteAsync_ReturnsEmpty_WhenQueryFails() [InlineData("--query")] public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missingParameter) { - var command = new DatabaseQueryCommand(_logger); + var command = new DatabaseQueryCommand(_postgresService, _logger); var args = command.GetCommand().Parse(ArgBuilder.BuildArgs(missingParameter, ("--subscription", "sub123"), ("--resource-group", "rg1"), @@ -154,7 +154,7 @@ public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missin [InlineData("SELECT * FROM pg_user_mappings")] // FDW credential exposure public async Task ExecuteAsync_InvalidQuery_ValidationError(string badQuery) { - var command = new DatabaseQueryCommand(_logger); + var command = new DatabaseQueryCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", @@ -178,7 +178,7 @@ public async Task ExecuteAsync_InvalidQuery_ValidationError(string badQuery) public async Task ExecuteAsync_LongQuery_ValidationError() { var longSelect = "SELECT " + new string('a', 6000) + " FROM test"; // exceeds max length - var command = new DatabaseQueryCommand(_logger); + var command = new DatabaseQueryCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/PostgresListCommandTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/PostgresListCommandTests.cs index 522c69e437..53a1f5ec80 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/PostgresListCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/PostgresListCommandTests.cs @@ -39,7 +39,7 @@ public async Task ExecuteAsync_ListsServers_WhenNoServerOrDatabaseProvided() var expectedServers = new List { "postgres-server-1", "postgres-server-2", "postgres-server-3" }; _postgresService.ListServersAsync("sub123", "rg1", Arg.Any()).Returns(expectedServers); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1" @@ -67,7 +67,7 @@ public async Task ExecuteAsync_ListsAllServersInSubscription_WhenNoResourceGroup var expectedServers = new List { "postgres-server-1", "postgres-server-2" }; _postgresService.ListServersAsync("sub123", null, Arg.Any()).Returns(expectedServers); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123" ]); @@ -91,7 +91,7 @@ public async Task ExecuteAsync_ListsAllServersInSubscription_WhenNoResourceGroup [Fact] public async Task ExecuteAsync_ReturnsError_WhenServerProvidedWithoutUser() { - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--server", "server1" @@ -118,7 +118,7 @@ public async Task ExecuteAsync_ListsDatabases_WhenServerProvided() "server1", Arg.Any()).Returns(expectedDatabases); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", @@ -157,7 +157,7 @@ public async Task ExecuteAsync_ListsTables_WhenServerAndDatabaseProvided() "db1", Arg.Any()).Returns(expectedTables); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", @@ -188,7 +188,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenNoServersExist() { _postgresService.ListServersAsync("sub123", "rg1", Arg.Any()).Returns([]); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1" @@ -222,7 +222,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenNoDatabasesExist() "server1", Arg.Any()).Returns([]); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", @@ -260,7 +260,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenNoTablesExist() "db1", Arg.Any()).Returns([]); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", @@ -293,7 +293,7 @@ public async Task ExecuteAsync_ReturnsError_WhenListServersThrows() _postgresService.ListServersAsync("sub123", "rg1", Arg.Any()) .ThrowsAsync(new Exception("Test error")); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1" @@ -321,7 +321,7 @@ public async Task ExecuteAsync_ReturnsError_WhenListDatabasesThrows() Arg.Any()) .ThrowsAsync(new Exception("Test error")); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", @@ -353,7 +353,7 @@ public async Task ExecuteAsync_ReturnsError_WhenListTablesThrows() Arg.Any()) .ThrowsAsync(new Exception("Test error")); - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse([ "--subscription", "sub123", "--resource-group", "rg1", @@ -375,7 +375,7 @@ public async Task ExecuteAsync_ReturnsError_WhenListTablesThrows() [InlineData("--subscription")] public async Task ExecuteAsync_ReturnsError_WhenRequiredParameterIsMissing(string missingParameter) { - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); var args = command.GetCommand().Parse(ArgBuilder.BuildArgs(missingParameter, ("--subscription", "sub123") )); @@ -391,7 +391,7 @@ public async Task ExecuteAsync_ReturnsError_WhenRequiredParameterIsMissing(strin [Fact] public void Metadata_IsConfiguredCorrectly() { - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); Assert.False(command.Metadata.Destructive); Assert.True(command.Metadata.ReadOnly); @@ -400,14 +400,14 @@ public void Metadata_IsConfiguredCorrectly() [Fact] public void Name_IsCorrect() { - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); Assert.Equal("list", command.Name); } [Fact] public void Description_IsCorrect() { - var command = new PostgresListCommand(_logger); + var command = new PostgresListCommand(_postgresService, _logger); Assert.Contains("List PostgreSQL servers", command.Description); Assert.Contains("databases, or tables", command.Description); } diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerConfigGetCommandTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerConfigGetCommandTests.cs index f5d5890e1c..c73665f1b8 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerConfigGetCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerConfigGetCommandTests.cs @@ -38,7 +38,7 @@ public async Task ExecuteAsync_ReturnsConfig_WhenConfigExists() var expectedConfig = "config123"; _postgresService.GetServerConfigAsync("sub123", "rg1", "user1", "server123", Arg.Any()).Returns(expectedConfig); - var command = new ServerConfigGetCommand(_logger); + var command = new ServerConfigGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", "--user", "user1", "--server", "server123"]); var context = new CommandContext(_serviceProvider); @@ -59,7 +59,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenConfigDoesNotExist() { _postgresService.GetServerConfigAsync("sub123", "rg1", "user1", "server123", Arg.Any()).Returns(""); - var command = new ServerConfigGetCommand(_logger); + var command = new ServerConfigGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", "--user", "user1", "--server", "server123"]); var context = new CommandContext(_serviceProvider); var response = await command.ExecuteAsync(context, args, TestContext.Current.CancellationToken); @@ -77,7 +77,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenConfigDoesNotExist() [InlineData("--server")] public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missingParameter) { - var command = new ServerConfigGetCommand(_logger); + var command = new ServerConfigGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(ArgBuilder.BuildArgs(missingParameter, ("--subscription", "sub123"), ("--resource-group", "rg1"), diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamGetCommandTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamGetCommandTests.cs index 905ddd1196..3921a7e0ae 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamGetCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamGetCommandTests.cs @@ -38,7 +38,7 @@ public async Task ExecuteAsync_ReturnsParamValue_WhenParamExists() var expectedValue = "value123"; _postgresService.GetServerParameterAsync("sub123", "rg1", "user1", "server123", "param123", Arg.Any()).Returns(expectedValue); - var command = new ServerParamGetCommand(_logger); + var command = new ServerParamGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", "--user", "user1", "--server", "server123", "--param", "param123"]); var context = new CommandContext(_serviceProvider); var response = await command.ExecuteAsync(context, args, TestContext.Current.CancellationToken); @@ -59,7 +59,7 @@ public async Task ExecuteAsync_ReturnsParamValue_WhenParamExists() public async Task ExecuteAsync_ReturnsNull_WhenParamDoesNotExist() { _postgresService.GetServerParameterAsync("sub123", "rg1", "user1", "server123", "param123", Arg.Any()).Returns(""); - var command = new ServerParamGetCommand(_logger); + var command = new ServerParamGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", "--user", "user1", "--server", "server123", "--param", "param123"]); var context = new CommandContext(_serviceProvider); var response = await command.ExecuteAsync(context, args, TestContext.Current.CancellationToken); @@ -78,7 +78,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenParamDoesNotExist() [InlineData("--param")] public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missingParameter) { - var command = new ServerParamGetCommand(_logger); + var command = new ServerParamGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(ArgBuilder.BuildArgs(missingParameter, ("--subscription", "sub123"), ("--resource-group", "rg1"), diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamSetCommandTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamSetCommandTests.cs index 4872e4111f..d9a4b1c089 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamSetCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Server/ServerParamSetCommandTests.cs @@ -38,7 +38,7 @@ public async Task ExecuteAsync_ReturnsSuccessMessage_WhenParamIsSet() var expectedMessage = "Parameter 'param123' updated successfully to 'value123'."; _postgresService.SetServerParameterAsync("sub123", "rg1", "user1", "server123", "param123", "value123", Arg.Any()).Returns(expectedMessage); - var command = new ServerParamSetCommand(_logger); + var command = new ServerParamSetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", "--user", "user1", "--server", "server123", "--param", "param123", "--value", "value123"]); var context = new CommandContext(_serviceProvider); var response = await command.ExecuteAsync(context, args, TestContext.Current.CancellationToken); @@ -61,7 +61,7 @@ public async Task ExecuteAsync_ReturnsSuccessMessage_WhenParamIsSet() public async Task ExecuteAsync_ReturnsNull_WhenParamDoesNotExist() { _postgresService.SetServerParameterAsync("sub123", "rg1", "user1", "server123", "param123", "value123", Arg.Any()).Returns(""); - var command = new ServerParamSetCommand(_logger); + var command = new ServerParamSetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", "--user", "user1", "--server", "server123", "--param", "param123", "--value", "value123"]); var context = new CommandContext(_serviceProvider); var response = await command.ExecuteAsync(context, args, TestContext.Current.CancellationToken); @@ -81,7 +81,7 @@ public async Task ExecuteAsync_ReturnsNull_WhenParamDoesNotExist() [InlineData("--value")] public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missingParameter) { - var command = new ServerParamSetCommand(_logger); + var command = new ServerParamSetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(ArgBuilder.BuildArgs(missingParameter, ("--subscription", "sub123"), ("--resource-group", "rg1"), @@ -105,7 +105,7 @@ public async Task ExecuteAsync_CallsServiceWithCorrectParameters() var expectedMessage = "Parameter updated successfully."; _postgresService.SetServerParameterAsync("sub123", "rg1", "user1", "server123", "max_connections", "200", Arg.Any()).Returns(expectedMessage); - var command = new ServerParamSetCommand(_logger); + var command = new ServerParamSetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", "--user", "user1", "--server", "server123", "--param", "max_connections", "--value", "200"]); var context = new CommandContext(_serviceProvider); diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Table/TableSchemaGetCommandTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Table/TableSchemaGetCommandTests.cs index f2cceec007..6d2db44383 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Table/TableSchemaGetCommandTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Table/TableSchemaGetCommandTests.cs @@ -39,7 +39,7 @@ public async Task ExecuteAsync_ReturnsSchema_WhenSchemaExists() var expectedSchema = new List(["CREATE TABLE test (id INT);"]); _postgresService.GetTableSchemaAsync("sub123", "rg1", AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "table123", Arg.Any()).Returns(expectedSchema); - var command = new TableSchemaGetCommand(_logger); + var command = new TableSchemaGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", $"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra, "--user", "user1", "--server", "server1", "--database", "db123", "--table", "table123"]); var context = new CommandContext(_serviceProvider); @@ -58,7 +58,7 @@ public async Task ExecuteAsync_ReturnsEmpty_WhenSchemaDoesNotExist() { _postgresService.GetTableSchemaAsync("sub123", "rg1", AuthTypes.MicrosoftEntra, "user1", null, "server1", "db123", "table123", Arg.Any()).Returns([]); - var command = new TableSchemaGetCommand(_logger); + var command = new TableSchemaGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(["--subscription", "sub123", "--resource-group", "rg1", $"--{PostgresOptionDefinitions.AuthTypeText}", AuthTypes.MicrosoftEntra, "--user", "user1", "--server", "server1", "--database", "db123", "--table", "table123"]); var context = new CommandContext(_serviceProvider); @@ -83,7 +83,7 @@ public async Task ExecuteAsync_ReturnsEmpty_WhenSchemaDoesNotExist() [InlineData("--table")] public async Task ExecuteAsync_ReturnsError_WhenParameterIsMissing(string missingParameter) { - var command = new TableSchemaGetCommand(_logger); + var command = new TableSchemaGetCommand(_postgresService, _logger); var args = command.GetCommand().Parse(ArgBuilder.BuildArgs(missingParameter, ("--subscription", "sub123"), ("--resource-group", "rg1"),