-
Notifications
You must be signed in to change notification settings - Fork 436
Expand file tree
/
Copy pathDatabaseQueryCommand.cs
More file actions
86 lines (71 loc) · 3.07 KB
/
DatabaseQueryCommand.cs
File metadata and controls
86 lines (71 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using Azure.Mcp.Core.Extensions;
using Azure.Mcp.Tools.Postgres.Options;
using Azure.Mcp.Tools.Postgres.Options.Database;
using Azure.Mcp.Tools.Postgres.Services;
using Azure.Mcp.Tools.Postgres.Validation;
using Microsoft.Extensions.Logging;
using Microsoft.Mcp.Core.Commands;
using Microsoft.Mcp.Core.Models.Command;
namespace Azure.Mcp.Tools.Postgres.Commands.Database;
/// <summary>
/// MCP command named "query" that runs a caller-supplied SQL statement against an
/// Azure Database for PostgreSQL database through <see cref="IPostgresService"/> and
/// returns the service's result as a list of strings.
/// </summary>
/// <remarks>
/// The statement is passed to <c>SqlQueryValidator.EnsureReadOnlySelect</c> before anything
/// is sent to the server. Any exception raised by validation or execution is logged and
/// then reported through the base command's <c>HandleException</c>.
/// </remarks>
public sealed class DatabaseQueryCommand(IPostgresService postgresService, ILogger<DatabaseQueryCommand> logger) : BaseDatabaseCommand<DatabaseQueryOptions>(logger)
{
    private const string CommandTitle = "Query PostgreSQL Database";

    private readonly IPostgresService _postgresService = postgresService;

    /// <inheritdoc/>
    public override string Id => "81a28bca-014c-4738-9e1a-654d77cb2dd8";

    /// <inheritdoc/>
    public override string Name => "query";

    /// <inheritdoc/>
    public override string Description => "Executes a SQL query on an Azure Database for PostgreSQL server to search for specific terms, retrieve records, or perform SELECT operations.";

    /// <inheritdoc/>
    public override string Title => CommandTitle;

    /// <summary>
    /// Tool hints advertised for this command. It is flagged read-only, idempotent and
    /// non-destructive, and it needs no local resources or secrets.
    /// </summary>
    public override ToolMetadata Metadata => new ToolMetadata
    {
        Destructive = false,
        Idempotent = true,
        OpenWorld = false,
        ReadOnly = true,
        LocalRequired = false,
        Secret = false,
    };

    /// <summary>
    /// Registers the base database options plus the <c>Query</c> option.
    /// </summary>
    /// <param name="command">The command being configured.</param>
    protected override void RegisterOptions(Command command)
    {
        base.RegisterOptions(command);
        command.Options.Add(PostgresOptionDefinitions.Query);
    }

    /// <summary>
    /// Binds the base database options, then fills in <see cref="DatabaseQueryOptions.Query"/>
    /// from the parsed command line (null when the option was not supplied).
    /// </summary>
    /// <param name="parseResult">The parsed command-line input.</param>
    /// <returns>The populated options object.</returns>
    protected override DatabaseQueryOptions BindOptions(ParseResult parseResult)
    {
        DatabaseQueryOptions bound = base.BindOptions(parseResult);
        bound.Query = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Query.Name);
        return bound;
    }

    /// <summary>
    /// Validates the input, checks the SQL text, runs it through the PostgreSQL service
    /// and stores the result on <c>context.Response</c>.
    /// </summary>
    /// <param name="context">Command context whose <c>Response</c> is populated and returned.</param>
    /// <param name="parseResult">The parsed command-line input.</param>
    /// <param name="cancellationToken">Token forwarded to the service call.</param>
    /// <returns>
    /// <c>context.Response</c>. It holds the query result on success, the validation
    /// outcome when option validation fails, or whatever <c>HandleException</c> recorded
    /// when an exception occurred.
    /// </returns>
    public override async Task<CommandResponse> ExecuteAsync(CommandContext context, ParseResult parseResult, CancellationToken cancellationToken)
    {
        var validation = Validate(parseResult.CommandResult, context.Response);
        if (!validation.IsValid)
        {
            return context.Response;
        }

        DatabaseQueryOptions options = BindOptions(parseResult);

        try
        {
            // Check the SQL text up front so unsafe statements are never sent to the server.
            SqlQueryValidator.EnsureReadOnlySelect(options.Query);

            List<string> results = await _postgresService.ExecuteQueryAsync(
                options.Subscription!,
                options.ResourceGroup!,
                options.AuthType!,
                options.User!,
                options.Password,
                options.Server!,
                options.Database!,
                options.Query!,
                cancellationToken);

            // Defensive: serialize an empty list rather than null if the service returns nothing.
            var payload = new DatabaseQueryCommandResult(results ?? []);
            context.Response.Results = ResponseResult.Create(payload, PostgresJsonContext.Default.DatabaseQueryCommandResult);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "An exception occurred while executing the query.");
            HandleException(context, ex);
        }

        return context.Response;
    }

    /// <summary>
    /// Result payload for this command, serialized through
    /// <c>PostgresJsonContext.Default.DatabaseQueryCommandResult</c>.
    /// </summary>
    /// <param name="QueryResult">The strings returned by <c>IPostgresService.ExecuteQueryAsync</c>.</param>
    internal record DatabaseQueryCommandResult(List<string> QueryResult);
}