diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 76069d3411..7d42c1e6ed 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -166,6 +166,12 @@ # ServiceLabel: %tools-Docker # ServiceOwners: @conniey @microsoft/azure-mcp +# PRLabel: %tools-DocumentDb +/tools/Azure.Mcp.Tools.DocumentDb/ @xingfan-git @microsoft/azure-mcp + +# ServiceLabel: %tools-DocumentDb +# ServiceOwners: @xingfan-git + # ServiceLabel: %tools-Eclipse # ServiceOwners: @srnagar @microsoft/azure-mcp diff --git a/Directory.Packages.props b/Directory.Packages.props index 9ac3e74524..89e339ba97 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -83,6 +83,7 @@ + diff --git a/eng/scripts/New-BuildInfo.ps1 b/eng/scripts/New-BuildInfo.ps1 index 1b1bce7546..187acdf52f 100644 --- a/eng/scripts/New-BuildInfo.ps1 +++ b/eng/scripts/New-BuildInfo.ps1 @@ -170,6 +170,53 @@ function CheckVariable($name) { return $value } +function Test-ProjectUsesMongoDbDriver { + param( + [string] $ProjectPath + ) + + if (!(Test-Path $ProjectPath)) { + return $false + } + + $projectContent = Get-Content $ProjectPath -Raw + return $projectContent -match ' + + + + + + + + diff --git a/servers/Azure.Mcp.Server/README.md b/servers/Azure.Mcp.Server/README.md index 3377dcafee..0d9d36bb8c 100644 --- a/servers/Azure.Mcp.Server/README.md +++ b/servers/Azure.Mcp.Server/README.md @@ -963,6 +963,22 @@ Example prompts that generate Azure CLI commands: * "Get Azure Data Explorer databases in cluster 'mycluster'" * "Sample 10 rows from table 'StormEvents' in Azure Data Explorer database 'db1'" +### 🗄️ Azure DocumentDB (with MongoDB compatibility) + +* "List indexes for collection 'items' in DocumentDB database 'test'" +* "Create an index on field 'category' for collection 'items' in DocumentDB database 'test'" +* "Drop index 'category_1' from collection 'items' in DocumentDB database 'test'" +* "Show index statistics for collection 'items' in DocumentDB database 'test'" +* "Show current DocumentDB operations" +* "List all databases in DocumentDB" +* "Get statistics for database 'mydb'" +* "Get details for database 'analytics' in DocumentDB" +* "Drop database 'testdb'" +* "Show me collections in database 'mydb'" +* "Get statistics for collection 'users'" +* "Rename collection 'old-name' to 'new-name'" +* "Sample documents from collection 'products'" + ### 📣 Azure Event Grid * "List all Event Grid topics in subscription 'my-subscription'" diff --git a/servers/Azure.Mcp.Server/changelog-entries/1773124572664.yaml b/servers/Azure.Mcp.Server/changelog-entries/1773124572664.yaml new file mode 100644 index 0000000000..7fc9dfd344 --- /dev/null +++ b/servers/Azure.Mcp.Server/changelog-entries/1773124572664.yaml @@ -0,0 +1,4 @@ +pr: 1968 +changes: + - section: "Features Added" + description: "Added mcp tools for managing Azure DocumentDB (with MongoDB compatibility) index" \ No newline at end of file diff --git a/servers/Azure.Mcp.Server/changelog-entries/1773579786664.yaml b/servers/Azure.Mcp.Server/changelog-entries/1773579786664.yaml new file mode 100644 index 0000000000..f7cd615a31 --- /dev/null +++ b/servers/Azure.Mcp.Server/changelog-entries/1773579786664.yaml @@ -0,0 +1,3 @@ +changes: + - section: "Features Added" + description: "Added mcp tools for managing Azure DocumentDB (with MongoDB compatibility) database" \ No newline at end of file diff --git a/servers/Azure.Mcp.Server/changelog-entries/1773584979801.yaml b/servers/Azure.Mcp.Server/changelog-entries/1773584979801.yaml new file mode 100644 index 0000000000..2ee231df4f --- /dev/null +++ b/servers/Azure.Mcp.Server/changelog-entries/1773584979801.yaml @@ -0,0 +1,3 @@ +changes: + - section: "Features Added" + description: "Added DocumentDB tools for managing Azure DocumentDB (with MongoDB compatibility) collection" \ No newline at end of file diff --git a/servers/Azure.Mcp.Server/docs/azmcp-commands.md b/servers/Azure.Mcp.Server/docs/azmcp-commands.md index dbb2165bba..394dcda263 100644 --- a/servers/Azure.Mcp.Server/docs/azmcp-commands.md +++ b/servers/Azure.Mcp.Server/docs/azmcp-commands.md @@ -1694,6 +1694,86 @@ azmcp deviceregistry namespace list --subscription \ [--resource-group ] ``` +### Azure DocumentDB (with MongoDB compatibility) Operations + +```bash +# List all indexes on a collection +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb index list indexes --connection-string \ + --db-name \ + --collection-name + +# Create an index on a collection +# ✅ Destructive | ❌ Idempotent | ❌ OpenWorld | ❌ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb index create index --connection-string \ + --db-name \ + --collection-name \ + --keys \ + [--options ] + +# Drop an index from a collection +# ✅ Destructive | ❌ Idempotent | ❌ OpenWorld | ❌ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb index drop index --connection-string \ + --db-name \ + --collection-name \ + --index-name + +# List all databases or inspect a single database +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb database list databases --connection-string \ + [--db-name ] + +# Rename a collection +# ✅ Destructive | ❌ Idempotent | ❌ OpenWorld | ❌ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb collection rename collection --connection-string \ + --db-name \ + --collection-name \ + --new-collection-name + +# Drop a collection +# ✅ Destructive | ❌ Idempotent | ❌ OpenWorld | ❌ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb collection drop collection --connection-string \ + --db-name \ + --collection-name + +# Sample documents from a collection +# ❌ Destructive | ❌ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb collection sample documents --connection-string \ + --db-name \ + --collection-name \ + [--sample-size ] + +# Get DocumentDB statistics for a database +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb others get stats --connection-string \ + --resource-type database \ + --db-name + +# Get DocumentDB statistics for a collection +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb others get stats --connection-string \ + --resource-type collection \ + --db-name \ + --collection-name + +# Get DocumentDB statistics for indexes on a collection +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb others get stats --connection-string \ + --resource-type index \ + --db-name \ + --collection-name + +# Get current DocumentDB operations +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb others current ops --connection-string \ + [--ops ] + +# Drop a database +# ✅ Destructive | ❌ Idempotent | ❌ OpenWorld | ❌ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp documentdb database drop database --connection-string \ + --db-name +``` + ### Azure Event Grid Operations ```bash diff --git a/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md b/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md index a75f04dac0..48399fbcdf 100644 --- a/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md +++ b/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md @@ -356,6 +356,28 @@ This file contains prompts used for end-to-end testing to ensure each tool is in | deviceregistry_namespace_list | List Device Registry namespaces in resource group | | deviceregistry_namespace_list | What Device Registry namespaces do I have in my Azure subscription? | +## Azure DocumentDB (with MongoDB compatibility) + +| Tool Name | Test Prompt | +|:----------|:----------| +| documentdb_index_list_indexes | List indexes for collection in DocumentDB database | +| documentdb_index_list_indexes | Show me all indexes on collection in database | +| documentdb_index_create_index | Create an index on collection in DocumentDB database using keys | +| documentdb_index_create_index | Add a DocumentDB index for collection in database with keys and options | +| documentdb_index_drop_index | Drop index from collection in DocumentDB database | +| documentdb_index_drop_index | Remove the index from DocumentDB collection in database | +| documentdb_others_get_stats | Show index statistics for collection in DocumentDB database | +| documentdb_others_get_stats | Get DocumentDB index stats for collection in database | +| documentdb_others_current_ops | Show current DocumentDB operations | +| documentdb_others_current_ops | Get current DocumentDB operations filtered by | +| documentdb_others_get_stats | Get statistics for database | +| documentdb_others_get_stats | Show me stats for DocumentDB database | +| documentdb_database_drop_database | Drop database | +| documentdb_database_drop_database | Delete the database from DocumentDB | +| documentdb_database_list_databases | List all databases in DocumentDB | +| documentdb_database_list_databases | Show me all DocumentDB databases | +| documentdb_database_list_databases | Get details for database in DocumentDB | + ## Azure Event Grid | Tool Name | Test Prompt | diff --git a/servers/Azure.Mcp.Server/src/Program.cs b/servers/Azure.Mcp.Server/src/Program.cs index 53c3b752ce..7fd21e71f7 100644 --- a/servers/Azure.Mcp.Server/src/Program.cs +++ b/servers/Azure.Mcp.Server/src/Program.cs @@ -104,6 +104,7 @@ private static IAreaSetup[] RegisterAreas() new Azure.Mcp.Tools.AzureTerraformBestPractices.AzureTerraformBestPracticesSetup(), new Azure.Mcp.Tools.Deploy.DeploySetup(), new Azure.Mcp.Tools.DeviceRegistry.DeviceRegistrySetup(), + new Azure.Mcp.Tools.DocumentDb.DocumentDbSetup(), new Azure.Mcp.Tools.EventGrid.EventGridSetup(), new Azure.Mcp.Tools.Acr.AcrSetup(), new Azure.Mcp.Tools.Advisor.AdvisorSetup(), diff --git a/servers/Azure.Mcp.Server/src/Resources/consolidated-tools.json b/servers/Azure.Mcp.Server/src/Resources/consolidated-tools.json index c3deaf18a9..aaf8ebbd37 100644 --- a/servers/Azure.Mcp.Server/src/Resources/consolidated-tools.json +++ b/servers/Azure.Mcp.Server/src/Resources/consolidated-tools.json @@ -137,7 +137,7 @@ }, { "name": "get_azure_databases_details", - "description": "Comprehensive Azure database management tool for MySQL, PostgreSQL, SQL Database, SQL Server, and Cosmos DB. List and query databases, retrieve server configurations and parameters, explore table schemas, execute database queries, manage Cosmos DB containers and items, list and view detailed information about SQL servers, and view database server details across all Azure database services.", + "description": "Comprehensive Azure database management tool for MySQL, PostgreSQL, SQL Database, SQL Server, Cosmos DB, and Azure DocumentDB (MongoDB-compatible). List and query databases, retrieve server configurations and parameters, explore table schemas, execute database queries, manage Cosmos DB containers and items, inspect DocumentDB databases, indexes, statistics, and current operations, list and view detailed information about SQL servers, and view database server details across Azure database services.", "toolMetadata": { "destructive": { "value": false, @@ -178,7 +178,81 @@ "sql_db_get", "sql_server_get", "cosmos_list", - "cosmos_database_container_item_query" + "cosmos_database_container_item_query", + "documentdb_database_list_databases", + "documentdb_index_list_indexes", + "documentdb_others_current_ops", + "documentdb_others_get_stats" + ] + }, + { + "name": "sample_azure_documentdb_collection_documents", + "description": "Retrieve sample documents from Azure DocumentDB collections to inspect document structure and infer schema patterns for query authoring.", + "toolMetadata": { + "destructive": { + "value": false, + "description": "This tool performs only additive updates without deleting or modifying existing resources." + }, + "idempotent": { + "value": false, + "description": "Running this operation multiple times with the same arguments may have additional effects or produce different results." + }, + "openWorld": { + "value": false, + "description": "This tool's domain of interaction is closed and well-defined, limited to a specific set of entities." + }, + "readOnly": { + "value": true, + "description": "This tool only performs read operations without modifying any state or data." + }, + "secret": { + "value": false, + "description": "This tool does not handle sensitive or secret information." + }, + "localRequired": { + "value": false, + "description": "This tool is available in both local and remote server modes." + } + }, + "mappedToolList": [ + "documentdb_collection_sample_documents" + ] + }, + { + "name": "edit_azure_documentdb_databases_and_collections", + "description": "Manage Azure DocumentDB databases, collections, and indexes. Create and drop indexes, rename or drop collections, and drop databases.", + "toolMetadata": { + "destructive": { + "value": true, + "description": "This tool may delete or modify existing resources in its environment." + }, + "idempotent": { + "value": false, + "description": "Running this operation multiple times with the same arguments may have additional effects or produce different results." + }, + "openWorld": { + "value": false, + "description": "This tool's domain of interaction is closed and well-defined, limited to a specific set of entities." + }, + "readOnly": { + "value": false, + "description": "This tool may modify its environment by creating, updating, or deleting data." + }, + "secret": { + "value": false, + "description": "This tool does not handle sensitive or secret information." + }, + "localRequired": { + "value": false, + "description": "This tool is available in both local and remote server modes." + } + }, + "mappedToolList": [ + "documentdb_collection_drop_collection", + "documentdb_collection_rename_collection", + "documentdb_database_drop_database", + "documentdb_index_create_index", + "documentdb_index_drop_index" ] }, { diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/AssemblyInfo.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/AssemblyInfo.cs new file mode 100644 index 0000000000..9f6cdfab4b --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Azure.Mcp.Tools.DocumentDb.UnitTests")] diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Azure.Mcp.Tools.DocumentDb.csproj b/tools/Azure.Mcp.Tools.DocumentDb/src/Azure.Mcp.Tools.DocumentDb.csproj new file mode 100644 index 0000000000..0131c8a708 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Azure.Mcp.Tools.DocumentDb.csproj @@ -0,0 +1,22 @@ + + + true + + false + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/BaseDocumentDbCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/BaseDocumentDbCommand.cs new file mode 100644 index 0000000000..2540569aec --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/BaseDocumentDbCommand.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Diagnostics.CodeAnalysis; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Options; + +namespace Azure.Mcp.Tools.DocumentDb.Commands; + +public abstract class BaseDocumentDbCommand< + [DynamicallyAccessedMembers(TrimAnnotations.CommandAnnotations)] TOptions> + : GlobalCommand where TOptions : BaseDocumentDbOptions, new() +{ + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.ConnectionString); + } + + protected override TOptions BindOptions(ParseResult parseResult) + { + return new TOptions + { + ConnectionString = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.ConnectionString.Name) + }; + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/DropCollectionCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/DropCollectionCommand.cs new file mode 100644 index 0000000000..4a61790822 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/DropCollectionCommand.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Collection; + +public sealed class DropCollectionCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "d0e1f2a3-b4c5-4d0e-7f8a-9b0c1d2e3f4a"; + + public override string Name => "drop_collection"; + + public override string Description => "Drop a collection from a database"; + + public override string Title => "Drop Collection"; + + public override ToolMetadata Metadata => new() + { + Destructive = true, + Idempotent = false, + OpenWorld = false, + Secret = false, + LocalRequired = false, + ReadOnly = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + command.Options.Add(DocumentDbOptionDefinitions.CollectionName); + } + + protected override DropCollectionOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + options.CollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.CollectionName.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + DropCollectionOptions? options = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + options = BindOptions(parseResult); + + var service = context.GetService(); + + var result = await service.DropCollectionAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to drop collection: {CollectionName} from database: {DbName}", options?.CollectionName, options?.DbName); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/RenameCollectionCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/RenameCollectionCommand.cs new file mode 100644 index 0000000000..4b4f7ffada --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/RenameCollectionCommand.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Collection; + +public sealed class RenameCollectionCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "c9d0e1f2-a3b4-4c9d-6e7f-8a9b0c1d2e3f"; + + public override string Name => "rename_collection"; + + public override string Description => "Rename a collection"; + + public override string Title => "Rename Collection"; + + public override ToolMetadata Metadata => new() + { + Destructive = true, + Idempotent = false, + OpenWorld = false, + Secret = false, + LocalRequired = false, + ReadOnly = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + command.Options.Add(DocumentDbOptionDefinitions.CollectionName); + command.Options.Add(DocumentDbOptionDefinitions.NewCollectionName); + } + + protected override RenameCollectionOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + options.CollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.CollectionName.Name); + options.NewCollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.NewCollectionName.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + RenameCollectionOptions? options = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + options = BindOptions(parseResult); + + var service = context.GetService(); + + var result = await service.RenameCollectionAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, options.NewCollectionName!, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to rename collection from {OldName} to {NewName} in database: {DbName}", options?.CollectionName, options?.NewCollectionName, options?.DbName); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/SampleDocumentsCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/SampleDocumentsCommand.cs new file mode 100644 index 0000000000..5dd391d4d5 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Collection/SampleDocumentsCommand.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Collection; + +public sealed class SampleDocumentsCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "e1f2a3b4-c5d6-4e1f-8a9b-0c1d2e3f4a5b"; + + public override string Name => "sample_documents"; + + public override string Description => "Retrieve sample documents from a specific collection. Useful for understanding data schema and query generation"; + + public override string Title => "Sample Documents"; + + public override ToolMetadata Metadata => new() + { + Destructive = false, + Idempotent = false, + OpenWorld = false, + Secret = false, + LocalRequired = false, + ReadOnly = true + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + command.Options.Add(DocumentDbOptionDefinitions.CollectionName); + command.Options.Add(DocumentDbOptionDefinitions.SampleSize); + } + + protected override SampleDocumentsOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + options.CollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.CollectionName.Name); + options.SampleSize = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.SampleSize.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + SampleDocumentsOptions? options = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + options = BindOptions(parseResult); + + var service = context.GetService(); + + var result = await service.SampleDocumentsAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, options.SampleSize, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to sample documents from collection: {CollectionName} in database: {DbName} with sample size: {SampleSize}", options?.CollectionName, options?.DbName, options?.SampleSize); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Database/DropDatabaseCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Database/DropDatabaseCommand.cs new file mode 100644 index 0000000000..631b7f842c --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Database/DropDatabaseCommand.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Database; + +public sealed class DropDatabaseCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "a7b8c9d0-e1f2-4a7b-4c5d-6e7f8a9b0c1d"; + + public override string Name => "drop_database"; + + public override string Description => "Drop a DocumentDB database, removing all of its collections and data."; + + public override string Title => "Drop Database"; + + public override ToolMetadata Metadata => new() + { + Destructive = true, + Idempotent = false, + OpenWorld = false, + ReadOnly = false, + Secret = false, + LocalRequired = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + } + + protected override DropDatabaseOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + DropDatabaseOptions? options = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + options = BindOptions(parseResult); + var dbName = options.DbName!; + + var service = context.GetService(); + + var result = await service.DropDatabaseAsync(options.ConnectionString!, dbName, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to drop database: {DbName}", options?.DbName); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Database/ListDatabasesCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Database/ListDatabasesCommand.cs new file mode 100644 index 0000000000..875cc94dc1 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Database/ListDatabasesCommand.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; +using Microsoft.Mcp.Core.Models.Option; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Database; + +public sealed class ListDatabasesCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "d4e5f6a7-b8c9-4d4e-1f2a-3b4c5d6e7f8a"; + + public override string Name => "list_databases"; + + public override string Description => "List DocumentDB databases. If --db-name is omitted, returns all database names. If --db-name is provided, returns detailed information for that database."; + + public override string Title => "List Databases"; + + public override ToolMetadata Metadata => new() + { + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + Secret = false, + LocalRequired = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName.AsOptional()); + } + + protected override ListDatabasesOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + ListDatabasesOptions? options = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + options = BindOptions(parseResult); + var dbName = options.DbName; + + var service = context.GetService(); + + var result = await service.GetDatabasesAsync(options.ConnectionString!, dbName, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to get DocumentDB database details. Database: {DbName}", options?.DbName); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbHelpers.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbHelpers.cs new file mode 100644 index 0000000000..749d28ed9a --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbHelpers.cs @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; + +namespace Azure.Mcp.Tools.DocumentDb.Commands; + +internal static class DocumentDbHelpers +{ + public static BsonDocument? ParseBsonDocument(string? json) + { + if (string.IsNullOrWhiteSpace(json)) + return null; + + try + { + return BsonDocument.Parse(json); + } + catch (Exception ex) + { + throw new ArgumentException("The provided value is not valid BSON/JSON document content.", nameof(json), ex); + } + } + + public static BsonDocument? ParseBsonDocument(object? value) + { + if (value == null) + return null; + + if (value is string str) + return ParseBsonDocument(str); + + if (value is BsonDocument doc) + return doc; + + try + { + var json = DocumentDbResponseHelper.SerializeToJson(value); + return BsonDocument.Parse(json); + } + catch (Exception ex) + { + throw new InvalidOperationException($"The value of type '{value.GetType().FullName}' could not be converted to a BSON document.", ex); + } + } + + public static List? ParseBsonDocumentList(string? json) + { + if (string.IsNullOrWhiteSpace(json)) + return null; + + try + { + var bsonArray = BsonSerializer.Deserialize(json); + return bsonArray.Select(item => item.AsBsonDocument).ToList(); + } + catch (Exception ex) + { + throw new ArgumentException("The provided value is not valid BSON/JSON array content.", nameof(json), ex); + } + } + + public static List? ParseBsonDocumentList(object? value) + { + if (value == null) + return null; + + if (value is string str) + return ParseBsonDocumentList(str); + + if (value is List list) + return list; + + try + { + var json = DocumentDbResponseHelper.SerializeToJson(value); + return ParseBsonDocumentList(json); + } + catch (Exception ex) + { + throw new InvalidOperationException($"The value of type '{value.GetType().FullName}' could not be converted to a BSON document list.", ex); + } + } + + public static bool ParseBoolean(string? value, bool defaultValue = false) + { + if (string.IsNullOrWhiteSpace(value)) + return defaultValue; + + if (bool.TryParse(value, out var result)) + return result; + + // Handle common string representations + return value.Trim().ToLowerInvariant() switch + { + "true" or "1" or "yes" => true, + "false" or "0" or "no" => false, + _ => defaultValue + }; + } + + public static int ParseInt(string? value, int defaultValue = 0) + { + if (string.IsNullOrWhiteSpace(value)) + return defaultValue; + + return int.TryParse(value, out var result) ? result : defaultValue; + } + + public static string SerializeBsonToJson(BsonDocument document) + { + var jsonWriterSettings = new JsonWriterSettings { OutputMode = JsonOutputMode.RelaxedExtendedJson }; + return document.ToJson(jsonWriterSettings); + } + + public static string SerializeBsonToJson(object obj) + { + if (obj is BsonDocument doc) + return SerializeBsonToJson(doc); + + return DocumentDbResponseHelper.SerializeToJson(obj); + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbJsonContext.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbJsonContext.cs new file mode 100644 index 0000000000..225ca9ded4 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbJsonContext.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; +using Azure.Mcp.Tools.DocumentDb.Models; +using MongoDB.Bson; + +namespace Azure.Mcp.Tools.DocumentDb.Commands; + +[JsonSourceGenerationOptions( + WriteIndented = false, + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] +[JsonSerializable(typeof(string))] +[JsonSerializable(typeof(object))] +[JsonSerializable(typeof(int))] +[JsonSerializable(typeof(bool))] +[JsonSerializable(typeof(long))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(List>))] +[JsonSerializable(typeof(List>))] +[JsonSerializable(typeof(System.Text.Json.JsonElement))] +internal partial class DocumentDbJsonContext : JsonSerializerContext; + +/// +/// Helper class for creating ResponseResult from JSON strings +/// +internal static class DocumentDbResponseHelper +{ + public static Microsoft.Mcp.Core.Models.Command.ResponseResult CreateFromJson(string json) + { + // Parse the JSON string to a JsonElement to get proper serialization + var element = System.Text.Json.JsonSerializer.Deserialize(json, DocumentDbJsonContext.Default.JsonElement); + return Microsoft.Mcp.Core.Models.Command.ResponseResult.Create(element, DocumentDbJsonContext.Default.JsonElement); + } + + public static string SerializeToJson(object value) + { + return value switch + { + // Handle BsonDocument by converting to JSON first + MongoDB.Bson.BsonDocument bsonDoc => bsonDoc.ToJson(new MongoDB.Bson.IO.JsonWriterSettings { OutputMode = MongoDB.Bson.IO.JsonOutputMode.RelaxedExtendedJson }), + List bsonList => "[" + string.Join(",", bsonList.Select(doc => doc.ToJson(new MongoDB.Bson.IO.JsonWriterSettings { OutputMode = MongoDB.Bson.IO.JsonOutputMode.RelaxedExtendedJson }))) + "]", + + // Handle standard types + Dictionary dict => System.Text.Json.JsonSerializer.Serialize(dict, DocumentDbJsonContext.Default.DictionaryStringObject), + List> list => System.Text.Json.JsonSerializer.Serialize(list, DocumentDbJsonContext.Default.ListDictionaryStringObject), + List strList => System.Text.Json.JsonSerializer.Serialize(strList, DocumentDbJsonContext.Default.ListString), + string str => System.Text.Json.JsonSerializer.Serialize(str, DocumentDbJsonContext.Default.String), + int i => System.Text.Json.JsonSerializer.Serialize(i, DocumentDbJsonContext.Default.Int32), + long l => System.Text.Json.JsonSerializer.Serialize(l, DocumentDbJsonContext.Default.Int64), + bool b => System.Text.Json.JsonSerializer.Serialize(b, DocumentDbJsonContext.Default.Boolean), + System.Text.Json.JsonElement element => System.Text.Json.JsonSerializer.Serialize(element, DocumentDbJsonContext.Default.JsonElement), + + // Handle IEnumerable (LINQ results) + System.Collections.Generic.IEnumerable enumStr => System.Text.Json.JsonSerializer.Serialize(enumStr.ToList(), DocumentDbJsonContext.Default.ListString), + + _ => throw new NotSupportedException($"Type {value.GetType().FullName} is not supported for AOT serialization. Please add it to DocumentDbJsonContext.") + }; + } + + public static T? DeserializeFromJson(string json) where T : class + { + // Only supports object type for AOT compatibility + if (typeof(T) == typeof(object)) + { + return System.Text.Json.JsonSerializer.Deserialize(json, DocumentDbJsonContext.Default.Object) as T; + } + + throw new NotSupportedException($"Type {typeof(T).Name} is not supported. Only 'object' type is AOT-compatible."); + } + + /// + /// Processes a DocumentDb service response and applies it to the command context. + /// + /// The command context to update. + /// The service response. + public static void ProcessResponse(Microsoft.Mcp.Core.Models.Command.CommandContext context, DocumentDbResponse response) + { + context.Response.Status = response.StatusCode; + + if (response.Success) + { + // For success with no data, create an empty result with the message + var dataToSerialize = response.Data ?? new Dictionary + { + ["message"] = response.Message + }; + context.Response.Results = CreateFromJson(SerializeToJson(dataToSerialize)); + } + else + { + context.Response.Message = response.Message ?? "Unknown error"; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbOptionDefinitions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbOptionDefinitions.cs new file mode 100644 index 0000000000..4d45eea5e9 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/DocumentDbOptionDefinitions.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; + +namespace Azure.Mcp.Tools.DocumentDb.Commands; + +internal static class DocumentDbOptionDefinitions +{ + public static readonly Option ConnectionString = new("--connection-string") + { + Description = "Azure DocumentDB connection string used for this request.", + Required = true + }; + + public static readonly Option DbName = new("--db-name") + { + Description = "Database name", + Required = true + }; + + public static readonly Option CollectionName = new("--collection-name") + { + Description = "Collection name", + Required = true + }; + + public static readonly Option ResourceType = CreateResourceTypeOption(); + + public static readonly Option NewCollectionName = new("--new-collection-name") + { + Description = "New collection name", + Required = true + }; + + public static readonly Option SampleSize = new("--sample-size") + { + Description = "Number of documents to sample", + DefaultValueFactory = _ => 10 + }; + + public static readonly Option Query = new("--query") + { + Description = "Query filter in JSON format" + }; + + public static readonly Option Options = new("--options") + { + Description = "Query options" + }; + + public static readonly Option Document = new("--document") + { + Description = "Document to insert", + Required = true + }; + + public static readonly Option Documents = new("--documents") + { + Description = "Documents to insert", + Required = true + }; + + public static readonly Option Filter = new("--filter") + { + Description = "Filter for update/delete", + Required = true + }; + + public static readonly Option Update = new("--update") + { + Description = "Update operations", + Required = true + }; + + public static readonly Option Upsert = new("--upsert") + { + Description = "Create document if it doesn't exist", + DefaultValueFactory = _ => false + }; + + public static readonly Option Pipeline = new("--pipeline") + { + Description = "Aggregation pipeline", + Required = true + }; + + public static readonly Option AllowDiskUse = new("--allow-disk-use") + { + Description = "Allow pipeline stages to write to disk", + DefaultValueFactory = _ => false + }; + + public static readonly Option Keys = new("--keys") + { + Description = "Index keys", + Required = true + }; + + public static readonly Option IndexName = new("--index-name") + { + Description = "Index name", + Required = true + }; + + public static readonly Option Ops = new("--ops") + { + Description = "Filter for current operations" + }; + + private static Option CreateResourceTypeOption() + { + var option = new Option("--resource-type") + { + Description = "Resource type to retrieve statistics for. Valid values: collection, database, index.", + Required = true + }; + + option.AcceptOnlyFromAmong("collection", "database", "index"); + return option; + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/CreateIndexCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/CreateIndexCommand.cs new file mode 100644 index 0000000000..f39549a8bb --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/CreateIndexCommand.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Index; + +public sealed class CreateIndexCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "a5b6c7d8-e9f0-4a5b-2c3d-4e5f6a7b8c9d"; + + public override string Name => "create_index"; + + public override string Description => "Create an index on a collection"; + + public override string Title => "Create Index"; + + public override ToolMetadata Metadata => new() + { + Destructive = true, + Idempotent = false, + OpenWorld = false, + ReadOnly = false, + LocalRequired = false, + Secret = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + command.Options.Add(DocumentDbOptionDefinitions.CollectionName); + command.Options.Add(DocumentDbOptionDefinitions.Keys); + command.Options.Add(DocumentDbOptionDefinitions.Options); + } + + protected override CreateIndexOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + options.CollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.CollectionName.Name); + options.Keys = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.Keys.Name); + options.Options = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.Options.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + CreateIndexOptions? commandOptions = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = commandOptions = BindOptions(parseResult); + + var service = context.GetService(); + + var keys = DocumentDbHelpers.ParseBsonDocument(options.Keys); + if (keys == null) + { + throw new ArgumentException("Invalid keys format"); + } + + var indexOptions = DocumentDbHelpers.ParseBsonDocument(options.Options); + + DocumentDbResponse result = await service.CreateIndexAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, keys, indexOptions, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to create index on collection: {CollectionName}, database: {DbName}, keys: {Keys}", commandOptions?.CollectionName, commandOptions?.DbName, commandOptions?.Keys); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/DropIndexCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/DropIndexCommand.cs new file mode 100644 index 0000000000..67849acbc3 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/DropIndexCommand.cs @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Index; + +public sealed class DropIndexCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "c7d8e9f0-a1b2-4c7d-4e5f-6a7b8c9d0e1f"; + + public override string Name => "drop_index"; + + public override string Description => "Drop an index from a collection"; + + public override string Title => "Drop Index"; + + public override ToolMetadata Metadata => new() + { + Destructive = true, + Idempotent = false, + OpenWorld = false, + ReadOnly = false, + LocalRequired = false, + Secret = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + command.Options.Add(DocumentDbOptionDefinitions.CollectionName); + command.Options.Add(DocumentDbOptionDefinitions.IndexName); + } + + protected override DropIndexOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + options.CollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.CollectionName.Name); + options.IndexName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.IndexName.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + DropIndexOptions? commandOptions = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = commandOptions = BindOptions(parseResult); + + var service = context.GetService(); + + DocumentDbResponse result = await service.DropIndexAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, options.IndexName!, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to drop index: {IndexName} from collection: {CollectionName}, database: {DbName}", commandOptions?.IndexName, commandOptions?.CollectionName, commandOptions?.DbName); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/ListIndexesCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/ListIndexesCommand.cs new file mode 100644 index 0000000000..54fd61343a --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Index/ListIndexesCommand.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Index; + +public sealed class ListIndexesCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "b6c7d8e9-f0a1-4b6c-3d4e-5f6a7b8c9d0e"; + + public override string Name => "list_indexes"; + + public override string Description => "List all indexes on a collection"; + + public override string Title => "List Indexes"; + + public override ToolMetadata Metadata => new() + { + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + LocalRequired = false, + Secret = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + command.Options.Add(DocumentDbOptionDefinitions.CollectionName); + } + + protected override ListIndexesOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + options.CollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.CollectionName.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + ListIndexesOptions? commandOptions = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = commandOptions = BindOptions(parseResult); + + var service = context.GetService(); + + DocumentDbResponse result = await service.ListIndexesAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to list indexes on collection: {CollectionName}, database: {DbName}", commandOptions?.CollectionName, commandOptions?.DbName); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Others/CurrentOpsCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Others/CurrentOpsCommand.cs new file mode 100644 index 0000000000..d9913d5cfc --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Others/CurrentOpsCommand.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Others; + +public sealed class CurrentOpsCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "e9f0a1b2-c3d4-4e9f-6a7b-8c9d0e1f2a3b"; + + public override string Name => "current_ops"; + + public override string Description => "Get information about current DocumentDB operations"; + + public override string Title => "Current Operations"; + + public override ToolMetadata Metadata => new() + { + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + LocalRequired = false, + Secret = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.Ops); + } + + protected override CurrentOpsOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.Ops = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.Ops.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + CurrentOpsOptions? commandOptions = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = commandOptions = BindOptions(parseResult); + + var service = context.GetService(); + + var filter = DocumentDbHelpers.ParseBsonDocument(options.Ops); + + DocumentDbResponse result = await service.GetCurrentOpsAsync(options.ConnectionString!, filter, cancellationToken); + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to get current operations with filter: {Ops}", commandOptions?.Ops); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Others/GetStatsCommand.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Others/GetStatsCommand.cs new file mode 100644 index 0000000000..a9855e4d1c --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Commands/Others/GetStatsCommand.cs @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Core.Commands; +using Azure.Mcp.Core.Extensions; +using Azure.Mcp.Tools.DocumentDb.Options; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Models.Command; +using Microsoft.Mcp.Core.Models.Option; + +namespace Azure.Mcp.Tools.DocumentDb.Commands.Others; + +public sealed class GetStatsCommand(ILogger logger) + : BaseDocumentDbCommand() +{ + private readonly ILogger _logger = logger; + + public override string Id => "73d37b66-e26e-4cd0-b401-cff1f9f09d8e"; + + public override string Name => "get_stats"; + + public override string Description => "Get statistics for a DocumentDB collection, database, or index by resource type."; + + public override string Title => "Get Statistics"; + + public override ToolMetadata Metadata => new() + { + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + Secret = false, + LocalRequired = false + }; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(DocumentDbOptionDefinitions.ResourceType); + command.Options.Add(DocumentDbOptionDefinitions.DbName); + command.Options.Add(DocumentDbOptionDefinitions.CollectionName.AsOptional()); + } + + protected override GetStatsOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.ResourceType = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.ResourceType.Name); + options.DbName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.DbName.Name); + options.CollectionName = parseResult.GetValueOrDefault(DocumentDbOptionDefinitions.CollectionName.Name); + return options; + } + + public override async Task ExecuteAsync( + CommandContext context, + ParseResult parseResult, + CancellationToken cancellationToken) + { + GetStatsOptions? commandOptions = null; + + try + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = commandOptions = BindOptions(parseResult); + + if ((options.ResourceType is "collection" or "index") && string.IsNullOrWhiteSpace(options.CollectionName)) + { + context.Response.Status = HttpStatusCode.BadRequest; + context.Response.Message = $"--collection-name is required when --resource-type is '{options.ResourceType}'."; + return context.Response; + } + + var service = context.GetService(); + var result = options.ResourceType switch + { + "collection" => await service.GetCollectionStatsAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, cancellationToken), + "database" => await service.GetDatabaseStatsAsync(options.ConnectionString!, options.DbName!, cancellationToken), + "index" => await service.GetIndexStatsAsync(options.ConnectionString!, options.DbName!, options.CollectionName!, cancellationToken), + _ => throw new InvalidOperationException($"Unsupported resource type '{options.ResourceType}'.") + }; + + DocumentDbResponseHelper.ProcessResponse(context, result); + + return context.Response; + } + catch (Exception ex) + { + _logger.LogError( + ex, + "Failed to get {ResourceType} statistics for database: {DbName}, collection: {CollectionName}", + commandOptions?.ResourceType, + commandOptions?.DbName, + commandOptions?.CollectionName); + HandleException(context, ex); + return context.Response; + } + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/DocumentDbSetup.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/DocumentDbSetup.cs new file mode 100644 index 0000000000..6c211f0142 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/DocumentDbSetup.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +using Azure.Mcp.Tools.DocumentDb.Commands.Collection; +using Azure.Mcp.Tools.DocumentDb.Commands.Database; +using Azure.Mcp.Tools.DocumentDb.Commands.Index; +using Azure.Mcp.Tools.DocumentDb.Commands.Others; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Mcp.Core.Areas; +using Microsoft.Mcp.Core.Commands; + +namespace Azure.Mcp.Tools.DocumentDb; + +public sealed class DocumentDbSetup : IAreaSetup +{ + public string Name => "documentdb"; + public string Title => "Azure DocumentDB (with MongoDB compatibility)"; + + public void ConfigureServices(IServiceCollection services) + { + services.AddSingleton(); + + // Index commands + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + + // Database commands + services.AddSingleton(); + services.AddSingleton(); + + // Collection commands + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + + // Other commands + services.AddSingleton(); + services.AddSingleton(); + } + + public CommandGroup RegisterCommands(IServiceProvider serviceProvider) + { + // Create DocumentDB root command group + var documentDb = new CommandGroup( + Name, + "Azure DocumentDB index, database, collection, and diagnostics operations for Azure DocumentDB.", + Title); + + var index = new CommandGroup( + "index", + "Manage indexes and inspect index-related diagnostics by providing a DocumentDB connection string per request."); + var database = new CommandGroup( + "database", + "Inspect and manage DocumentDB databases by providing a DocumentDB connection string per request."); + var collection = new CommandGroup( + "collection", + "Manage DocumentDB collections by providing a DocumentDB connection string per request."); + var others = new CommandGroup( + "others", + "Inspect DocumentDB statistics and diagnostic operations by providing a DocumentDB connection string per request."); + + documentDb.AddSubGroup(index); + documentDb.AddSubGroup(database); + documentDb.AddSubGroup(collection); + documentDb.AddSubGroup(others); + + // Index commands + var createIndexCommand = serviceProvider.GetRequiredService(); + var listIndexesCommand = serviceProvider.GetRequiredService(); + var dropIndexCommand = serviceProvider.GetRequiredService(); + + // Database commands + var listDatabasesCommand = serviceProvider.GetRequiredService(); + var dropDatabaseCommand = serviceProvider.GetRequiredService(); + + // Collection commands + var renameCollectionCommand = serviceProvider.GetRequiredService(); + var dropCollectionCommand = serviceProvider.GetRequiredService(); + var sampleDocumentsCommand = serviceProvider.GetRequiredService(); + + // Other commands + var getStatsCommand = serviceProvider.GetRequiredService(); + var currentOpsCommand = serviceProvider.GetRequiredService(); + + index.AddCommand(createIndexCommand.Name, createIndexCommand); + index.AddCommand(listIndexesCommand.Name, listIndexesCommand); + index.AddCommand(dropIndexCommand.Name, dropIndexCommand); + + database.AddCommand(listDatabasesCommand.Name, listDatabasesCommand); + database.AddCommand(dropDatabaseCommand.Name, dropDatabaseCommand); + + collection.AddCommand(renameCollectionCommand.Name, renameCollectionCommand); + collection.AddCommand(dropCollectionCommand.Name, dropCollectionCommand); + collection.AddCommand(sampleDocumentsCommand.Name, sampleDocumentsCommand); + + others.AddCommand(getStatsCommand.Name, getStatsCommand); + others.AddCommand(currentOpsCommand.Name, currentOpsCommand); + + return documentDb; + } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/GlobalUsings.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/GlobalUsings.cs new file mode 100644 index 0000000000..b41cc886b4 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/GlobalUsings.cs @@ -0,0 +1,4 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +global using System.CommandLine; diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Models/DocumentDbResponse.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Models/DocumentDbResponse.cs new file mode 100644 index 0000000000..834143a678 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Models/DocumentDbResponse.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; + +namespace Azure.Mcp.Tools.DocumentDb.Models; + +/// +/// Represents a unified response structure for all DocumentDb MCP commands. +/// +public class DocumentDbResponse +{ + /// + /// Gets or sets a value indicating whether the operation was successful. + /// + public bool Success { get; set; } + + /// + /// Gets or sets the HTTP status code of the operation. + /// + public HttpStatusCode StatusCode { get; set; } + + /// + /// Gets or sets the message (error or informational) from the operation. + /// + public string? Message { get; set; } + + /// + /// Gets or sets the response data payload from the operation. + /// + public object? Data { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/BaseDocumentDbOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/BaseDocumentDbOptions.cs new file mode 100644 index 0000000000..f81c87ca5d --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/BaseDocumentDbOptions.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Mcp.Core.Options; + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class BaseDocumentDbOptions : GlobalOptions +{ + public string? ConnectionString { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/CreateIndexOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/CreateIndexOptions.cs new file mode 100644 index 0000000000..5260b6857f --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/CreateIndexOptions.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class CreateIndexOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } + + public string? CollectionName { get; set; } + + public string? Keys { get; set; } + + public string? Options { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/CurrentOpsOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/CurrentOpsOptions.cs new file mode 100644 index 0000000000..fbba9cfd7b --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/CurrentOpsOptions.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class CurrentOpsOptions : BaseDocumentDbOptions +{ + public string? Ops { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropCollectionOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropCollectionOptions.cs new file mode 100644 index 0000000000..23d8b1214b --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropCollectionOptions.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class DropCollectionOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } + + public string? CollectionName { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropDatabaseOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropDatabaseOptions.cs new file mode 100644 index 0000000000..a040ab7317 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropDatabaseOptions.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class DropDatabaseOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropIndexOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropIndexOptions.cs new file mode 100644 index 0000000000..712fbb2268 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/DropIndexOptions.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class DropIndexOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } + + public string? CollectionName { get; set; } + + public string? IndexName { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/GetStatsOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/GetStatsOptions.cs new file mode 100644 index 0000000000..36565db041 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/GetStatsOptions.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public sealed class GetStatsOptions : BaseDocumentDbOptions +{ + public string? ResourceType { get; set; } + + public string? DbName { get; set; } + + public string? CollectionName { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/ListDatabasesOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/ListDatabasesOptions.cs new file mode 100644 index 0000000000..2195f39e25 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/ListDatabasesOptions.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class ListDatabasesOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/ListIndexesOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/ListIndexesOptions.cs new file mode 100644 index 0000000000..582e715843 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/ListIndexesOptions.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class ListIndexesOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } + + public string? CollectionName { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/RenameCollectionOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/RenameCollectionOptions.cs new file mode 100644 index 0000000000..04b4a393a9 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/RenameCollectionOptions.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class RenameCollectionOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } + + public string? CollectionName { get; set; } + + public string? NewCollectionName { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Options/SampleDocumentsOptions.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/SampleDocumentsOptions.cs new file mode 100644 index 0000000000..db0c16c1b0 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Options/SampleDocumentsOptions.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Mcp.Tools.DocumentDb.Options; + +public class SampleDocumentsOptions : BaseDocumentDbOptions +{ + public string? DbName { get; set; } + + public string? CollectionName { get; set; } + + public int SampleSize { get; set; } = 10; +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Services/DocumentDbService.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Services/DocumentDbService.cs new file mode 100644 index 0000000000..b9a3c9c85b --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Services/DocumentDbService.cs @@ -0,0 +1,690 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Models; +using Microsoft.Extensions.Logging; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Driver; + +namespace Azure.Mcp.Tools.DocumentDb.Services; + +public sealed class DocumentDbService(ILogger logger) : IDocumentDbService +{ + private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + private static readonly JsonWriterSettings s_jsonWriterSettings = new() { OutputMode = JsonOutputMode.RelaxedExtendedJson }; + + #region Index Management + + public async Task CreateIndexAsync(string connectionString, string databaseName, string collectionName, BsonDocument keys, BsonDocument? options = null, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(collectionName, nameof(collectionName)); + ArgumentNullException.ThrowIfNull(keys); + + try + { + var collection = GetCollection(connectionString, databaseName, collectionName); + var createIndexOptions = CreateIndexOptions(options); + var model = new CreateIndexModel(new BsonDocumentIndexKeysDefinition(keys), createIndexOptions); + var indexName = await collection.Indexes.CreateOneAsync(model, cancellationToken: cancellationToken); + + _logger.LogInformation("Created index {IndexName} on {DatabaseName}.{CollectionName}", indexName, databaseName, collectionName); + + return Success( + "Index created successfully", + new Dictionary + { + ["index_name"] = indexName, + ["keys"] = BsonDocumentToJson(keys), + ["options"] = BsonDocumentToJson(options) + }); + } + catch (MongoCommandException ex) when (ex.Code == 26) + { + _logger.LogWarning("Database '{DatabaseName}' not found", databaseName); + return Failure(HttpStatusCode.BadRequest, $"Database '{databaseName}' not found"); + } + catch (MongoCommandException ex) when (ex.CodeName == "NamespaceNotFound") + { + _logger.LogWarning("Collection '{CollectionName}' not found in database '{DatabaseName}'", collectionName, databaseName); + return Failure(HttpStatusCode.BadRequest, $"Collection '{collectionName}' not found"); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access creating index on {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error creating index on {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to create index: {ex.Message}"); + } + } + + public async Task ListIndexesAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(collectionName, nameof(collectionName)); + + try + { + var collection = GetCollection(connectionString, databaseName, collectionName); + var indexes = await collection.Indexes.List(cancellationToken: cancellationToken).ToListAsync(cancellationToken); + + return Success( + "Indexes retrieved successfully", + new Dictionary + { + ["indexes"] = BsonDocumentListToJson(indexes), + ["count"] = indexes.Count + }); + } + catch (MongoCommandException ex) when (ex.Code == 26) + { + _logger.LogWarning("Database '{DatabaseName}' not found", databaseName); + return Failure(HttpStatusCode.BadRequest, $"Database '{databaseName}' not found"); + } + catch (MongoCommandException ex) when (ex.CodeName == "NamespaceNotFound") + { + _logger.LogWarning("Collection '{CollectionName}' not found in database '{DatabaseName}'", collectionName, databaseName); + return Failure(HttpStatusCode.BadRequest, $"Collection '{collectionName}' not found"); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access listing indexes for {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error listing indexes for {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to list indexes: {ex.Message}"); + } + } + + public async Task DropIndexAsync(string connectionString, string databaseName, string collectionName, string indexName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(collectionName, nameof(collectionName)); + ValidateParameter(indexName, nameof(indexName)); + + try + { + var collection = GetCollection(connectionString, databaseName, collectionName); + await collection.Indexes.DropOneAsync(indexName, cancellationToken: cancellationToken); + + _logger.LogInformation("Dropped index {IndexName} from {DatabaseName}.{CollectionName}", indexName, databaseName, collectionName); + + return Success( + $"Index '{indexName}' dropped successfully", + new Dictionary + { + ["index_name"] = indexName + }); + } + catch (MongoCommandException ex) when (ex.Code == 26) + { + _logger.LogWarning("Database '{DatabaseName}' not found", databaseName); + return Failure(HttpStatusCode.BadRequest, $"Database '{databaseName}' not found"); + } + catch (MongoCommandException ex) when (ex.CodeName == "NamespaceNotFound") + { + _logger.LogWarning("Collection '{CollectionName}' not found in database '{DatabaseName}'", collectionName, databaseName); + return Failure(HttpStatusCode.BadRequest, $"Collection '{collectionName}' not found"); + } + catch (MongoCommandException ex) when (ex.CodeName == "IndexNotFound") + { + _logger.LogWarning("Index '{IndexName}' not found in {DatabaseName}.{CollectionName}", indexName, databaseName, collectionName); + return Failure(HttpStatusCode.BadRequest, $"Index '{indexName}' not found"); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access dropping index {IndexName} from {DatabaseName}.{CollectionName}", indexName, databaseName, collectionName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error dropping index {IndexName} from {DatabaseName}.{CollectionName}", indexName, databaseName, collectionName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to drop index: {ex.Message}"); + } + } + + public async Task GetIndexStatsAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(collectionName, nameof(collectionName)); + + try + { + var collection = GetCollection(connectionString, databaseName, collectionName); + var pipeline = new[] + { + new BsonDocument("$indexStats", new BsonDocument()) + }; + + var stats = await collection.Aggregate(pipeline, cancellationToken: cancellationToken).ToListAsync(cancellationToken); + + return Success( + "Index statistics retrieved successfully", + new Dictionary + { + ["stats"] = BsonDocumentListToJson(stats), + ["count"] = stats.Count + }); + } + catch (MongoCommandException ex) when (ex.Code == 26) + { + _logger.LogWarning("Database '{DatabaseName}' not found", databaseName); + return Failure(HttpStatusCode.BadRequest, $"Database '{databaseName}' not found"); + } + catch (MongoCommandException ex) when (ex.CodeName == "NamespaceNotFound") + { + _logger.LogWarning("Collection '{CollectionName}' not found in database '{DatabaseName}'", collectionName, databaseName); + return Failure(HttpStatusCode.BadRequest, $"Collection '{collectionName}' not found"); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access getting index stats for {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting index stats for {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to get index stats: {ex.Message}"); + } + } + + public async Task GetCurrentOpsAsync(string connectionString, BsonDocument? filter = null, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + + try + { + var adminDb = CreateClient(connectionString).GetDatabase("admin"); + var command = new BsonDocument("currentOp", 1); + + if (filter != null && filter.ElementCount > 0) + { + foreach (var element in filter) + { + if (string.Equals(element.Name, "currentOp", StringComparison.Ordinal)) + { + return Failure(HttpStatusCode.BadRequest, "The 'currentOp' filter field is reserved and cannot be overridden."); + } + + command[element.Name] = element.Value; + } + } + + var result = await adminDb.RunCommandAsync(command, cancellationToken: cancellationToken); + + return Success( + "Current operations retrieved successfully", + new Dictionary + { + ["operations"] = BsonDocumentToJson(result) + }); + } + catch (MongoCommandException ex) when (ex.Code == 26) + { + _logger.LogWarning("Admin database not found"); + return Failure(HttpStatusCode.BadRequest, "Admin database not found"); + } + catch (MongoCommandException ex) when (ex.CodeName == "NamespaceNotFound") + { + _logger.LogWarning("Namespace not found for current operations"); + return Failure(HttpStatusCode.BadRequest, "Namespace not found"); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access getting current operations"); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting current operations"); + return Failure(HttpStatusCode.InternalServerError, $"Failed to get current operations: {ex.Message}"); + } + } + + #endregion + + #region Database Management + + public async Task GetDatabasesAsync(string connectionString, string? dbName = null, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + + try + { + var client = CreateClient(connectionString); + var databaseNames = await client.ListDatabaseNames(cancellationToken: cancellationToken).ToListAsync(cancellationToken); + + if (!string.IsNullOrWhiteSpace(dbName) && !databaseNames.Contains(dbName, StringComparer.Ordinal)) + { + return Failure(HttpStatusCode.NotFound, $"Database '{dbName}' was not found."); + } + + List> databases; + if (string.IsNullOrWhiteSpace(dbName)) + { + databases = databaseNames + .Select(databaseName => new Dictionary + { + ["name"] = databaseName + }) + .ToList(); + } + else + { + databases = [await GetDatabaseInfoAsync(client, dbName, cancellationToken)]; + } + + return Success( + string.IsNullOrWhiteSpace(dbName) + ? "Databases retrieved successfully." + : $"Database '{dbName}' retrieved successfully.", + databases); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access listing databases"); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error listing databases. Database: {DatabaseName}", dbName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to list databases: {ex.Message}"); + } + } + + public async Task GetDatabaseStatsAsync(string connectionString, string dbName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(dbName, nameof(dbName)); + + try + { + var client = CreateClient(connectionString); + + if (!await DatabaseExistsAsync(client, dbName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Database '{dbName}' was not found."); + } + + var database = client.GetDatabase(dbName); + var stats = await database.RunCommandAsync(new BsonDocument("dbStats", 1), cancellationToken: cancellationToken); + + return Success($"Database statistics for '{dbName}' retrieved successfully.", stats); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access getting stats for database {DatabaseName}", dbName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting database stats for {DatabaseName}", dbName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to get database stats: {ex.Message}"); + } + } + + public async Task DropDatabaseAsync(string connectionString, string dbName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(dbName, nameof(dbName)); + + try + { + var client = CreateClient(connectionString); + + if (!await DatabaseExistsAsync(client, dbName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Database '{dbName}' was not found."); + } + + await client.DropDatabaseAsync(dbName, cancellationToken); + + _logger.LogInformation("Dropped DocumentDB database {DatabaseName}", dbName); + + return Success( + $"Database '{dbName}' dropped successfully.", + new Dictionary + { + ["name"] = dbName, + ["deleted"] = true + }); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access dropping database {DatabaseName}", dbName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error dropping database {DatabaseName}", dbName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to drop database: {ex.Message}"); + } + } + + #endregion + + #region Collection Operations + + public async Task GetCollectionStatsAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(collectionName, nameof(collectionName)); + + try + { + var client = CreateClient(connectionString); + + if (!await DatabaseExistsAsync(client, databaseName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Database '{databaseName}' was not found."); + } + + if (!await CollectionExistsAsync(client, databaseName, collectionName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Collection '{collectionName}' was not found in database '{databaseName}'."); + } + + var database = client.GetDatabase(databaseName); + var command = new BsonDocument { { "collStats", collectionName } }; + var stats = await database.RunCommandAsync(command, cancellationToken: cancellationToken); + + return Success($"Collection statistics for '{collectionName}' retrieved successfully.", stats); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access getting stats for collection {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting collection stats for {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to get collection stats: {ex.Message}"); + } + } + + public async Task RenameCollectionAsync(string connectionString, string databaseName, string oldName, string newName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(oldName, nameof(oldName)); + ValidateParameter(newName, nameof(newName)); + + try + { + var client = CreateClient(connectionString); + + if (!await DatabaseExistsAsync(client, databaseName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Database '{databaseName}' was not found."); + } + + if (!await CollectionExistsAsync(client, databaseName, oldName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Collection '{oldName}' was not found in database '{databaseName}'."); + } + + if (await CollectionExistsAsync(client, databaseName, newName, cancellationToken)) + { + return Failure(HttpStatusCode.Conflict, $"Collection '{newName}' already exists in database '{databaseName}'."); + } + + var database = client.GetDatabase(databaseName); + await database.RenameCollectionAsync(oldName, newName, cancellationToken: cancellationToken); + + _logger.LogInformation("Renamed collection {OldName} to {NewName} in database {DatabaseName}", oldName, newName, databaseName); + + return Success( + $"Collection '{oldName}' was renamed to '{newName}' successfully.", + new Dictionary + { + ["databaseName"] = databaseName, + ["oldName"] = oldName, + ["newName"] = newName, + ["renamed"] = true + }); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access renaming collection {OldName} to {NewName} in database {DatabaseName}", oldName, newName, databaseName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error renaming collection {OldName} to {NewName} in database {DatabaseName}", oldName, newName, databaseName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to rename collection: {ex.Message}"); + } + } + + public async Task DropCollectionAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(collectionName, nameof(collectionName)); + + try + { + var client = CreateClient(connectionString); + + if (!await DatabaseExistsAsync(client, databaseName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Database '{databaseName}' was not found."); + } + + if (!await CollectionExistsAsync(client, databaseName, collectionName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Collection '{collectionName}' was not found in database '{databaseName}'."); + } + + var database = client.GetDatabase(databaseName); + await database.DropCollectionAsync(collectionName, cancellationToken); + + _logger.LogWarning("Dropped collection {CollectionName} from database {DatabaseName}", collectionName, databaseName); + + return Success( + $"Collection '{collectionName}' dropped successfully.", + new Dictionary + { + ["databaseName"] = databaseName, + ["collectionName"] = collectionName, + ["deleted"] = true + }); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access dropping collection {CollectionName} from database {DatabaseName}", collectionName, databaseName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error dropping collection {CollectionName} from database {DatabaseName}", collectionName, databaseName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to drop collection: {ex.Message}"); + } + } + + public async Task SampleDocumentsAsync(string connectionString, string databaseName, string collectionName, int sampleSize = 10, CancellationToken cancellationToken = default) + { + ValidateParameter(connectionString, nameof(connectionString)); + ValidateParameter(databaseName, nameof(databaseName)); + ValidateParameter(collectionName, nameof(collectionName)); + + try + { + var client = CreateClient(connectionString); + + if (!await DatabaseExistsAsync(client, databaseName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Database '{databaseName}' was not found."); + } + + if (!await CollectionExistsAsync(client, databaseName, collectionName, cancellationToken)) + { + return Failure(HttpStatusCode.NotFound, $"Collection '{collectionName}' was not found in database '{databaseName}'."); + } + + var database = client.GetDatabase(databaseName); + var collection = database.GetCollection(collectionName); + + var pipeline = new[] + { + new BsonDocument("$sample", new BsonDocument("size", sampleSize)) + }; + + var documents = await collection.Aggregate(pipeline, cancellationToken: cancellationToken).ToListAsync(cancellationToken); + + return Success($"Retrieved {documents.Count} sample document(s) from collection '{collectionName}'.", documents); + } + catch (MongoAuthenticationException ex) + { + _logger.LogWarning(ex, "Unauthorized access sampling documents from collection {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.Unauthorized, $"Unauthorized access: {ex.Message}"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error sampling documents from collection {DatabaseName}.{CollectionName}", databaseName, collectionName); + return Failure(HttpStatusCode.InternalServerError, $"Failed to sample documents: {ex.Message}"); + } + } + + #endregion + + #region Helper Functions + + private static async Task DatabaseExistsAsync(MongoClient client, string dbName, CancellationToken cancellationToken) + { + var databaseNames = await client.ListDatabaseNames(cancellationToken: cancellationToken).ToListAsync(cancellationToken); + return databaseNames.Contains(dbName, StringComparer.Ordinal); + } + + private static async Task CollectionExistsAsync(MongoClient client, string dbName, string collectionName, CancellationToken cancellationToken) + { + var database = client.GetDatabase(dbName); + var collectionNames = await database.ListCollectionNames(cancellationToken: cancellationToken).ToListAsync(cancellationToken); + return collectionNames.Contains(collectionName, StringComparer.Ordinal); + } + + private static async Task> GetDatabaseInfoAsync(MongoClient client, string dbName, CancellationToken cancellationToken) + { + var database = client.GetDatabase(dbName); + var collectionNames = await database.ListCollectionNames(cancellationToken: cancellationToken).ToListAsync(cancellationToken); + + var collections = new List>(collectionNames.Count); + foreach (var collectionName in collectionNames) + { + var collection = database.GetCollection(collectionName); + var documentCount = await collection.CountDocumentsAsync(FilterDefinition.Empty, cancellationToken: cancellationToken); + + collections.Add(new Dictionary + { + ["name"] = collectionName, + ["documentCount"] = documentCount + }); + } + + return new Dictionary + { + ["name"] = dbName, + ["collectionCount"] = collectionNames.Count, + ["collections"] = collections + }; + } + + private static IMongoCollection GetCollection(string connectionString, string databaseName, string collectionName) + { + return CreateClient(connectionString) + .GetDatabase(databaseName) + .GetCollection(collectionName); + } + + private static MongoClient CreateClient(string connectionString) + { + var settings = MongoClientSettings.FromConnectionString(connectionString); + settings.ServerSelectionTimeout = TimeSpan.FromSeconds(10); + return new MongoClient(settings); + } + + private static CreateIndexOptions CreateIndexOptions(BsonDocument? options) + { + var createIndexOptions = new CreateIndexOptions(); + + if (options == null) + { + return createIndexOptions; + } + + if (options.Contains("unique")) + { + createIndexOptions.Unique = options["unique"].AsBoolean; + } + + if (options.Contains("name")) + { + createIndexOptions.Name = options["name"].AsString; + } + + if (options.Contains("sparse")) + { + createIndexOptions.Sparse = options["sparse"].AsBoolean; + } + + if (options.Contains("expireAfterSeconds")) + { + createIndexOptions.ExpireAfter = TimeSpan.FromSeconds(options["expireAfterSeconds"].ToInt32()); + } + + return createIndexOptions; + } + + private static string? BsonDocumentToJson(BsonDocument? doc) + { + return doc?.ToJson(s_jsonWriterSettings); + } + + private static List BsonDocumentListToJson(List docs) + { + return docs.Select(doc => doc.ToJson(s_jsonWriterSettings)).ToList(); + } + + private static DocumentDbResponse Success(string message, object? data = null) + { + return new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = message, + Data = data + }; + } + + private static DocumentDbResponse Failure(HttpStatusCode statusCode, string message) + { + return new DocumentDbResponse + { + Success = false, + StatusCode = statusCode, + Message = message, + Data = null + }; + } + + private static void ValidateParameter(string? value, string paramName) + { + if (string.IsNullOrWhiteSpace(value)) + { + throw new ArgumentException($"{paramName} cannot be null or empty", paramName); + } + } + + #endregion +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/src/Services/IDocumentDbService.cs b/tools/Azure.Mcp.Tools.DocumentDb/src/Services/IDocumentDbService.cs new file mode 100644 index 0000000000..ac53d73073 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/src/Services/IDocumentDbService.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Mcp.Tools.DocumentDb.Models; +using MongoDB.Bson; + +namespace Azure.Mcp.Tools.DocumentDb.Services; + +public interface IDocumentDbService +{ + // Index Management + Task CreateIndexAsync(string connectionString, string databaseName, string collectionName, BsonDocument keys, BsonDocument? options = null, CancellationToken cancellationToken = default); + Task ListIndexesAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default); + Task DropIndexAsync(string connectionString, string databaseName, string collectionName, string indexName, CancellationToken cancellationToken = default); + Task GetIndexStatsAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default); + Task GetCurrentOpsAsync(string connectionString, BsonDocument? filter = null, CancellationToken cancellationToken = default); + + // Database Management + Task GetDatabasesAsync(string connectionString, string? dbName = null, CancellationToken cancellationToken = default); + Task GetDatabaseStatsAsync(string connectionString, string dbName, CancellationToken cancellationToken = default); + Task DropDatabaseAsync(string connectionString, string dbName, CancellationToken cancellationToken = default); + + // Collection Operations + Task GetCollectionStatsAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default); + Task RenameCollectionAsync(string connectionString, string databaseName, string oldName, string newName, CancellationToken cancellationToken = default); + Task DropCollectionAsync(string connectionString, string databaseName, string collectionName, CancellationToken cancellationToken = default); + Task SampleDocumentsAsync(string connectionString, string databaseName, string collectionName, int sampleSize = 10, CancellationToken cancellationToken = default); +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/Azure.Mcp.Tools.DocumentDb.LiveTests.csproj b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/Azure.Mcp.Tools.DocumentDb.LiveTests.csproj new file mode 100644 index 0000000000..fa0e6b88e4 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/Azure.Mcp.Tools.DocumentDb.LiveTests.csproj @@ -0,0 +1,17 @@ + + + true + Exe + + + + + + + + + + + + + \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/DocumentDbCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/DocumentDbCommandTests.cs new file mode 100644 index 0000000000..910d95e2fe --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/DocumentDbCommandTests.cs @@ -0,0 +1,483 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using Microsoft.Mcp.Tests; +using Microsoft.Mcp.Tests.Client; +using MongoDB.Bson; +using MongoDB.Driver; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.LiveTests; + +public class DocumentDbCommandTests(ITestOutputHelper output, LiveServerFixture serverFixture) + : CommandTestsBase(output, serverFixture) +{ + private const string TestDatabaseName = "test"; + private const string CollectionName = "items"; + private static bool _testDataInitialized; + private static readonly SemaphoreSlim InitLock = new(1, 1); + + private string ConnectionString => Settings.DeploymentOutputs["DOCUMENTDB_CONNECTION_STRING"]; + + public override async ValueTask InitializeAsync() + { + await LoadSettingsAsync(); + + Assert.SkipWhen(TestMode != Microsoft.Mcp.Tests.Helpers.TestMode.Live, + "DocumentDb index tests are live-only and do not support record/playback mode"); + + SetArguments("server", "start", "--mode", "all", "--dangerously-disable-elicitation"); + await base.InitializeAsync(); + + if (_testDataInitialized) + { + return; + } + + await InitLock.WaitAsync(); + + try + { + if (_testDataInitialized) + { + return; + } + + await SeedTestDatabaseAsync(); + _testDataInitialized = true; + } + finally + { + InitLock.Release(); + } + } + + [Fact] + public async Task Should_list_indexes_with_connection_string() + { + var result = await CallToolAsync( + "documentdb_index_list_indexes", + new() + { + { "connection-string", ConnectionString }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName } + }); + + var indexesArray = result.AssertProperty("indexes"); + Assert.Equal(JsonValueKind.Array, indexesArray.ValueKind); + Assert.NotEmpty(indexesArray.EnumerateArray()); + } + + [Fact] + public async Task Should_create_and_drop_index_with_connection_string() + { + var indexName = $"value_1_mcp_{Guid.NewGuid():N}"; + + var createResult = await CallToolAsync( + "documentdb_index_create_index", + new() + { + { "connection-string", ConnectionString }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName }, + { "keys", "{\"value\":1}" }, + { "options", $"{{\"name\":\"{indexName}\"}}" } + }); + + Assert.Equal(indexName, createResult.AssertProperty("index_name").GetString()); + + var listResult = await CallToolAsync( + "documentdb_index_list_indexes", + new() + { + { "connection-string", ConnectionString }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName } + }); + + Assert.Contains(listResult.AssertProperty("indexes").EnumerateArray(), element => + element.GetString()?.Contains(indexName, StringComparison.Ordinal) == true); + + var dropResult = await CallToolAsync( + "documentdb_index_drop_index", + new() + { + { "connection-string", ConnectionString }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName }, + { "index-name", indexName } + }); + + Assert.Equal(indexName, dropResult.AssertProperty("index_name").GetString()); + } + + [Fact] + public async Task Should_get_index_stats_with_connection_string() + { + var indexName = $"category_1_mcp_{Guid.NewGuid():N}"; + + await CallToolAsync( + "documentdb_index_create_index", + new() + { + { "connection-string", ConnectionString }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName }, + { "keys", "{\"category\":1}" }, + { "options", $"{{\"name\":\"{indexName}\"}}" } + }); + + var statsResult = await CallToolAsync( + "documentdb_others_get_stats", + new() + { + { "connection-string", ConnectionString }, + { "resource-type", "index" }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName } + }); + + var stats = statsResult.AssertProperty("stats"); + Assert.Equal(JsonValueKind.Array, stats.ValueKind); + Assert.True(stats.EnumerateArray().Any()); + + await CallToolAsync( + "documentdb_index_drop_index", + new() + { + { "connection-string", ConnectionString }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName }, + { "index-name", indexName } + }); + } + + [Fact] + public async Task Should_list_all_databases() + { + var result = await CallToolAsync( + "documentdb_database_list_databases", + new() + { + { "connection-string", ConnectionString } + }); + + Assert.NotNull(result); + Assert.Equal(JsonValueKind.Array, result.Value.ValueKind); + Assert.NotEmpty(result.Value.EnumerateArray()); + + foreach (var database in result.Value.EnumerateArray()) + { + var name = database.AssertProperty("name"); + Assert.False(string.IsNullOrWhiteSpace(name.GetString())); + } + } + + [Fact] + public async Task Should_get_single_database_details_when_db_name_is_provided() + { + var result = await CallToolAsync( + "documentdb_database_list_databases", + new() + { + { "connection-string", ConnectionString }, + { "db-name", "test" } + }); + + Assert.NotNull(result); + Assert.Equal(JsonValueKind.Array, result.Value.ValueKind); + + var database = Assert.Single(result.Value.EnumerateArray()); + var name = database.AssertProperty("name"); + Assert.Equal("test", name.GetString()); + + var collectionCount = database.AssertProperty("collectionCount"); + Assert.True(collectionCount.GetInt32() >= 1); + + var collections = database.AssertProperty("collections"); + Assert.Equal(JsonValueKind.Array, collections.ValueKind); + Assert.NotEmpty(collections.EnumerateArray()); + } + + [Fact] + public async Task Should_get_database_statistics() + { + var result = await CallToolAsync( + "documentdb_others_get_stats", + new() + { + { "connection-string", ConnectionString }, + { "resource-type", "database" }, + { "db-name", "test" } + }); + + Assert.NotNull(result); + + var database = result.Value.AssertProperty("db"); + Assert.Equal("test", database.GetString()); + + var collections = result.Value.AssertProperty("collections"); + Assert.True(collections.GetInt32() >= 1); + } + + [Fact] + public async Task Should_drop_database() + { + const string databaseName = "dropme"; + + await CreateCollectionWithDocumentsAsync( + databaseName, + CollectionName, + [new BsonDocument { { "name", "drop-item" }, { "value", 1 } }]); + + var result = await CallToolAsync( + "documentdb_database_drop_database", + new() + { + { "connection-string", ConnectionString }, + { "db-name", databaseName } + }); + + Assert.NotNull(result); + + var name = result.Value.AssertProperty("name"); + Assert.Equal(databaseName, name.GetString()); + + var deleted = result.Value.AssertProperty("deleted"); + Assert.True(deleted.GetBoolean()); + } + + private async Task SeedTestDatabaseAsync() + { + const int maxAttempts = 3; + Exception? lastException = null; + + for (var attempt = 1; attempt <= maxAttempts; attempt++) + { + try + { + Output.WriteLine($"Seeding DocumentDB index test data (attempt {attempt}/{maxAttempts})..."); + + var client = new MongoClient(ConnectionString); + var database = client.GetDatabase(TestDatabaseName); + + var existingCollections = await (await database.ListCollectionNamesAsync()).ToListAsync(); + if (!existingCollections.Contains(CollectionName, StringComparer.Ordinal)) + { + await database.CreateCollectionAsync(CollectionName); + } + + var collection = database.GetCollection(CollectionName); + await collection.DeleteManyAsync(Builders.Filter.Empty); + await collection.InsertManyAsync([ + new BsonDocument { { "name", "item1" }, { "value", 100 }, { "category", "A" } }, + new BsonDocument { { "name", "item2" }, { "value", 200 }, { "category", "B" } }, + new BsonDocument { { "name", "item3" }, { "value", 300 }, { "category", "A" } } + ]); + + Output.WriteLine("DocumentDB index test data seeded successfully."); + return; + } + catch (Exception ex) + { + lastException = ex; + + if (attempt == maxAttempts) + { + break; + } + + Output.WriteLine($"DocumentDB seeding attempt {attempt} failed: {ex.Message}"); + await Task.Delay(TimeSpan.FromSeconds(10)); + } + } + + throw new InvalidOperationException("Failed to seed DocumentDB index test database.", lastException); + } + + [Fact] + public async Task Should_get_collection_statistics() + { + var result = await CallToolAsync( + "documentdb_others_get_stats", + new() + { + { "connection-string", ConnectionString }, + { "resource-type", "collection" }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName } + }); + + Assert.NotNull(result); + + var ns = result.Value.AssertProperty("ns"); + Assert.Equal($"{TestDatabaseName}.{CollectionName}", ns.GetString()); + + var count = result.Value.AssertProperty("count"); + Assert.True(count.GetInt32() >= 3); + } + + [Fact] + public async Task Should_sample_documents_from_collection() + { + const int sampleSize = 2; + + var result = await CallToolAsync( + "documentdb_collection_sample_documents", + new() + { + { "connection-string", ConnectionString }, + { "db-name", TestDatabaseName }, + { "collection-name", CollectionName }, + { "sample-size", sampleSize.ToString() } + }); + + Assert.NotNull(result); + Assert.Equal(JsonValueKind.Array, result.Value.ValueKind); + + var samples = result.Value.EnumerateArray().ToList(); + Assert.NotEmpty(samples); + Assert.True(samples.Count <= sampleSize); + + foreach (var sample in samples) + { + Assert.Equal(JsonValueKind.Object, sample.ValueKind); + } + } + + [Fact] + public async Task Should_rename_collection() + { + var databaseName = CreateUniqueName("rename-db-"); + var collectionName = CreateUniqueName("old-"); + var newCollectionName = CreateUniqueName("new-"); + + try + { + await CreateCollectionWithDocumentsAsync( + databaseName, + collectionName, + [new BsonDocument { { "name", "rename-item" }, { "value", 1 } }]); + + var result = await CallToolAsync( + "documentdb_collection_rename_collection", + new() + { + { "connection-string", ConnectionString }, + { "db-name", databaseName }, + { "collection-name", collectionName }, + { "new-collection-name", newCollectionName } + }); + + Assert.NotNull(result); + + var resultDatabaseName = result.Value.AssertProperty("databaseName"); + Assert.Equal(databaseName, resultDatabaseName.GetString()); + + var oldName = result.Value.AssertProperty("oldName"); + Assert.Equal(collectionName, oldName.GetString()); + + var newName = result.Value.AssertProperty("newName"); + Assert.Equal(newCollectionName, newName.GetString()); + + var renamed = result.Value.AssertProperty("renamed"); + Assert.True(renamed.GetBoolean()); + + Assert.False(await CollectionExistsAsync(databaseName, collectionName)); + Assert.True(await CollectionExistsAsync(databaseName, newCollectionName)); + } + finally + { + await DeleteDatabaseIfExistsAsync(databaseName); + } + } + + [Fact] + public async Task Should_drop_collection() + { + var databaseName = CreateUniqueName("drop-db-"); + var collectionName = CreateUniqueName("drop-col-"); + + try + { + await CreateCollectionWithDocumentsAsync( + databaseName, + collectionName, + [new BsonDocument { { "name", "drop-item" }, { "value", 1 } }]); + + var result = await CallToolAsync( + "documentdb_collection_drop_collection", + new() + { + { "connection-string", ConnectionString }, + { "db-name", databaseName }, + { "collection-name", collectionName } + }); + + Assert.NotNull(result); + + var resultDatabaseName = result.Value.AssertProperty("databaseName"); + Assert.Equal(databaseName, resultDatabaseName.GetString()); + + var resultCollectionName = result.Value.AssertProperty("collectionName"); + Assert.Equal(collectionName, resultCollectionName.GetString()); + + var deleted = result.Value.AssertProperty("deleted"); + Assert.True(deleted.GetBoolean()); + + Assert.False(await CollectionExistsAsync(databaseName, collectionName)); + } + finally + { + await DeleteDatabaseIfExistsAsync(databaseName); + } + } + + private static string CreateUniqueName(string prefix) + { + return $"{prefix}{Guid.NewGuid():N}"; + } + + private async Task CreateCollectionWithDocumentsAsync(string databaseName, string collectionName, IEnumerable documents) + { + var client = new MongoClient(ConnectionString); + var database = client.GetDatabase(databaseName); + + var existingCollections = await (await database.ListCollectionNamesAsync()).ToListAsync(); + if (!existingCollections.Contains(collectionName, StringComparer.Ordinal)) + { + await database.CreateCollectionAsync(collectionName); + } + + var collection = database.GetCollection(collectionName); + await collection.DeleteManyAsync(Builders.Filter.Empty); + + var documentsList = documents.ToList(); + if (documentsList.Count > 0) + { + await collection.InsertManyAsync(documentsList); + } + } + + private async Task CollectionExistsAsync(string databaseName, string collectionName) + { + var client = new MongoClient(ConnectionString); + var database = client.GetDatabase(databaseName); + var collections = await (await database.ListCollectionNamesAsync()).ToListAsync(); + + return collections.Contains(collectionName, StringComparer.Ordinal); + } + + private async Task DeleteDatabaseIfExistsAsync(string databaseName) + { + var client = new MongoClient(ConnectionString); + var databases = await (await client.ListDatabaseNamesAsync()).ToListAsync(); + + if (databases.Contains(databaseName, StringComparer.Ordinal)) + { + await client.DropDatabaseAsync(databaseName); + } + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/assets.json b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/assets.json new file mode 100644 index 0000000000..f838934cca --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.LiveTests/assets.json @@ -0,0 +1,6 @@ +{ + "AssetsRepo": "Azure/azure-sdk-assets", + "AssetsRepoPrefixPath": "", + "TagPrefix": "Azure.Mcp.Tools.DocumentDb.LiveTests", + "Tag": "" +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Azure.Mcp.Tools.DocumentDb.UnitTests.csproj b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Azure.Mcp.Tools.DocumentDb.UnitTests.csproj new file mode 100644 index 0000000000..7b8c62121b --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Azure.Mcp.Tools.DocumentDb.UnitTests.csproj @@ -0,0 +1,17 @@ + + + true + Exe + + + + + + + + + + + + + \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/DropCollectionCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/DropCollectionCommandTests.cs new file mode 100644 index 0000000000..4e4715f0ca --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/DropCollectionCommandTests.cs @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Collection; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Collection; + +public class DropCollectionCommandTests +{ + private const string ConnectionString = "mongodb://localhost:27017"; + + private readonly IServiceProvider _serviceProvider; + private readonly IDocumentDbService _documentDbService; + private readonly ILogger _logger; + private readonly DropCollectionCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public DropCollectionCommandTests() + { + _documentDbService = Substitute.For(); + _logger = Substitute.For>(); + _command = new(_logger); + _commandDefinition = _command.GetCommand(); + _serviceProvider = new ServiceCollection() + .AddSingleton(_documentDbService) + .BuildServiceProvider(); + _context = new(_serviceProvider); + } + + [Fact] + public async Task ExecuteAsync_DropsCollection_WhenCollectionExists() + { + // Arrange + var dbName = "testdb"; + var collectionName = "testcollection"; + + _documentDbService.DropCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Collection '{collectionName}' dropped successfully.", + Data = new Dictionary + { + ["databaseName"] = dbName, + ["collectionName"] = collectionName, + ["deleted"] = true + } + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenCollectionNotFound() + { + // Arrange + var dbName = "testdb"; + var collectionName = "nonexistent"; + + _documentDbService.DropCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Collection '{collectionName}' was not found in database '{dbName}'." + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message.ToLower()); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenDatabaseNotFound() + { + // Arrange + var dbName = "nonexistentdb"; + var collectionName = "testcollection"; + + _documentDbService.DropCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Database '{dbName}' was not found." + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message.ToLower()); + } + + [Theory] + [InlineData("--connection-string", ConnectionString, "--db-name", "testdb")] + [InlineData("--connection-string", ConnectionString, "--collection-name", "testcollection")] + public async Task ExecuteAsync_Returns400_WhenRequiredParametersAreMissing(params string[] args) + { + // Arrange & Act + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse(args), TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLower()); + } + + [Fact] + public async Task ExecuteAsync_Returns500_OnUnexpectedException() + { + // Arrange + var dbName = "testdb"; + var collectionName = "testcollection"; + var expectedError = "Unexpected error occurred"; + + _documentDbService.DropCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.InternalServerError, + Message = $"Failed to drop collection: {expectedError}" + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.InternalServerError, response.Status); + Assert.Contains(expectedError, response.Message); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/RenameCollectionCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/RenameCollectionCommandTests.cs new file mode 100644 index 0000000000..1c2acf7c6f --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/RenameCollectionCommandTests.cs @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Collection; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Collection; + +public class RenameCollectionCommandTests +{ + private const string ConnectionString = "mongodb://localhost:27017"; + + private readonly IServiceProvider _serviceProvider; + private readonly IDocumentDbService _documentDbService; + private readonly ILogger _logger; + private readonly RenameCollectionCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public RenameCollectionCommandTests() + { + _documentDbService = Substitute.For(); + _logger = Substitute.For>(); + _command = new(_logger); + _commandDefinition = _command.GetCommand(); + _serviceProvider = new ServiceCollection() + .AddSingleton(_documentDbService) + .BuildServiceProvider(); + _context = new(_serviceProvider); + } + + [Fact] + public async Task ExecuteAsync_RenamesCollection_WhenCollectionExists() + { + // Arrange + var dbName = "testdb"; + var collectionName = "oldname"; + var newCollectionName = "newname"; + + _documentDbService.RenameCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(newCollectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Collection '{collectionName}' was renamed to '{newCollectionName}' successfully.", + Data = new Dictionary + { + ["databaseName"] = dbName, + ["oldName"] = collectionName, + ["newName"] = newCollectionName, + ["renamed"] = true + } + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--new-collection-name", newCollectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenCollectionNotFound() + { + // Arrange + var dbName = "testdb"; + var collectionName = "nonexistent"; + var newCollectionName = "newname"; + + _documentDbService.RenameCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(newCollectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Collection '{collectionName}' was not found in database '{dbName}'." + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--new-collection-name", newCollectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message.ToLower()); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenDatabaseNotFound() + { + // Arrange + var dbName = "nonexistentdb"; + var collectionName = "testcollection"; + var newCollectionName = "newname"; + + _documentDbService.RenameCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(newCollectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Database '{dbName}' was not found." + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--new-collection-name", newCollectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message.ToLower()); + } + + [Fact] + public async Task ExecuteAsync_Returns409_WhenNewNameAlreadyExists() + { + // Arrange + var dbName = "testdb"; + var collectionName = "oldname"; + var newCollectionName = "existingname"; + var expectedError = $"Collection '{newCollectionName}' already exists in database '{dbName}'."; + + _documentDbService.RenameCollectionAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(newCollectionName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.Conflict, + Message = $"Collection '{newCollectionName}' already exists in database '{dbName}'." + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--new-collection-name", newCollectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.Conflict, response.Status); + Assert.Contains(expectedError, response.Message); + } + + [Theory] + [InlineData("--connection-string", ConnectionString, "--db-name", "testdb", "--collection-name", "oldname")] + [InlineData("--connection-string", ConnectionString, "--db-name", "testdb", "--new-collection-name", "newname")] + [InlineData("--connection-string", ConnectionString, "--collection-name", "oldname", "--new-collection-name", "newname")] + public async Task ExecuteAsync_Returns400_WhenRequiredParametersAreMissing(params string[] args) + { + // Arrange & Act + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse(args), TestContext.Current.CancellationToken); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLower()); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/SampleDocumentsCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/SampleDocumentsCommandTests.cs new file mode 100644 index 0000000000..81bcd7e03a --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Collection/SampleDocumentsCommandTests.cs @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Collection; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using MongoDB.Bson; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Collection; + +public class SampleDocumentsCommandTests +{ + private const string ConnectionString = "mongodb://localhost:27017"; + + private readonly IServiceProvider _serviceProvider; + private readonly IDocumentDbService _documentDbService; + private readonly ILogger _logger; + private readonly SampleDocumentsCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public SampleDocumentsCommandTests() + { + _documentDbService = Substitute.For(); + _logger = Substitute.For>(); + _command = new(_logger); + _commandDefinition = _command.GetCommand(); + _serviceProvider = new ServiceCollection() + .AddSingleton(_documentDbService) + .BuildServiceProvider(); + _context = new(_serviceProvider); + } + + [Fact] + public async Task ExecuteAsync_ReturnsSampleDocuments_WhenDocumentsExist() + { + // Arrange + var dbName = "testdb"; + var collectionName = "testcollection"; + var sampleSize = 5; + var expectedDocuments = new List + { + new() { { "_id", ObjectId.GenerateNewId() }, { "name", "doc1" } }, + new() { { "_id", ObjectId.GenerateNewId() }, { "name", "doc2" } } + }; + + _documentDbService.SampleDocumentsAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(sampleSize), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Retrieved {expectedDocuments.Count} sample document(s) from collection '{collectionName}'.", + Data = expectedDocuments + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--sample-size", sampleSize.ToString() + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_ReturnsEmptyList_WhenNoDocumentsExist() + { + // Arrange + var dbName = "testdb"; + var collectionName = "emptycollection"; + var emptyList = new List(); + + _documentDbService.SampleDocumentsAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Any(), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Retrieved 0 sample document(s) from collection '{collectionName}'.", + Data = emptyList + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_UsesDefaultSampleSize_WhenNotProvided() + { + // Arrange + var dbName = "testdb"; + var collectionName = "testcollection"; + var defaultSampleSize = 10; + var emptyList = new List(); + + _documentDbService.SampleDocumentsAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(defaultSampleSize), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Retrieved 0 sample document(s) from collection '{collectionName}'.", + Data = emptyList + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + await _documentDbService.Received(1).SampleDocumentsAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(defaultSampleSize), + Arg.Any()); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenCollectionNotFound() + { + // Arrange + var dbName = "testdb"; + var collectionName = "nonexistent"; + + _documentDbService.SampleDocumentsAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Any(), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Collection '{collectionName}' was not found in database '{dbName}'." + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message.ToLower()); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenDatabaseNotFound() + { + // Arrange + var dbName = "nonexistentdb"; + var collectionName = "testcollection"; + + _documentDbService.SampleDocumentsAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Any(), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Database '{dbName}' was not found." + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message.ToLower()); + } + + [Theory] + [InlineData("--connection-string", ConnectionString, "--db-name", "testdb")] + [InlineData("--connection-string", ConnectionString, "--collection-name", "testcollection")] + public async Task ExecuteAsync_Returns400_WhenRequiredParametersAreMissing(params string[] args) + { + // Arrange & Act + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse(args), TestContext.Current.CancellationToken); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLower()); + } + + [Theory] + [InlineData(1)] + [InlineData(100)] + [InlineData(1000)] + public async Task ExecuteAsync_HandlesVariousSampleSizes(int sampleSize) + { + // Arrange + var dbName = "testdb"; + var collectionName = "testcollection"; + var emptyList = new List(); + + _documentDbService.SampleDocumentsAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Is(collectionName), + Arg.Is(sampleSize), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Retrieved 0 sample document(s) from collection '{collectionName}'.", + Data = emptyList + }); + + var args = _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--sample-size", sampleSize.ToString() + ]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Database/DropDatabaseCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Database/DropDatabaseCommandTests.cs new file mode 100644 index 0000000000..d2a871dcab --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Database/DropDatabaseCommandTests.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Database; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Database; + +public class DropDatabaseCommandTests +{ + private const string ConnectionString = "mongodb://localhost:27017"; + + private readonly IServiceProvider _serviceProvider; + private readonly IDocumentDbService _documentDbService; + private readonly ILogger _logger; + private readonly DropDatabaseCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public DropDatabaseCommandTests() + { + _documentDbService = Substitute.For(); + _logger = Substitute.For>(); + _command = new(_logger); + _commandDefinition = _command.GetCommand(); + _serviceProvider = new ServiceCollection() + .AddSingleton(_documentDbService) + .BuildServiceProvider(); + _context = new(_serviceProvider); + } + + [Fact] + public async Task ExecuteAsync_DropsDatabase_WhenDatabaseExists() + { + // Arrange + var dbName = "testdb"; + + _documentDbService.DropDatabaseAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Database '{dbName}' dropped successfully", + Data = new Dictionary + { + ["name"] = dbName, + ["deleted"] = true + } + }); + + var args = _commandDefinition.Parse(["--connection-string", ConnectionString, "--db-name", dbName]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenDatabaseNotFound() + { + // Arrange + var dbName = "nonexistentdb"; + + _documentDbService.DropDatabaseAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Database '{dbName}' was not found." + }); + + var args = _commandDefinition.Parse(["--connection-string", ConnectionString, "--db-name", dbName]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message.ToLower()); + } + + [Fact] + public async Task ExecuteAsync_Returns400_WhenDbNameIsMissing() + { + // Arrange & Act + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([]), TestContext.Current.CancellationToken); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLower()); + } + + [Fact] + public async Task ExecuteAsync_Returns500_OnUnexpectedError() + { + // Arrange + var dbName = "testdb"; + var expectedError = "Unexpected error occurred"; + + _documentDbService.DropDatabaseAsync( + Arg.Is(ConnectionString), + Arg.Is(dbName), + Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.InternalServerError, + Message = $"Failed to drop database: {expectedError}" + }); + + var args = _commandDefinition.Parse(["--connection-string", ConnectionString, "--db-name", dbName]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.InternalServerError, response.Status); + Assert.Contains("Failed to drop database", response.Message); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Database/ListDatabasesCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Database/ListDatabasesCommandTests.cs new file mode 100644 index 0000000000..84ee0ff6f0 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Database/ListDatabasesCommandTests.cs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Database; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Database; + +public class ListDatabasesCommandTests +{ + private const string ConnectionString = "mongodb://localhost:27017"; + + private readonly IServiceProvider _serviceProvider; + private readonly IDocumentDbService _documentDbService; + private readonly ILogger _logger; + private readonly ListDatabasesCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public ListDatabasesCommandTests() + { + _documentDbService = Substitute.For(); + _logger = Substitute.For>(); + _command = new(_logger); + _commandDefinition = _command.GetCommand(); + _serviceProvider = new ServiceCollection() + .AddSingleton(_documentDbService) + .BuildServiceProvider(); + _context = new(_serviceProvider); + } + + [Fact] + public async Task ExecuteAsync_ReturnsDatabases_WhenDatabasesExist() + { + // Arrange + var expectedDatabases = new List> + { + new() + { + ["name"] = "database1" + }, + new() + { + ["name"] = "database2" + } + }; + + _documentDbService.GetDatabasesAsync(ConnectionString, null, Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Databases retrieved successfully.", + Data = expectedDatabases + }); + + var args = _commandDefinition.Parse(["--connection-string", ConnectionString]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + + await _documentDbService.Received(1).GetDatabasesAsync(ConnectionString, null, Arg.Any()); + } + + [Fact] + public async Task ExecuteAsync_ReturnsSingleDatabase_WhenDbNameIsProvided() + { + // Arrange + const string dbName = "database1"; + + _documentDbService.GetDatabasesAsync(ConnectionString, dbName, Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = $"Database '{dbName}' retrieved successfully.", + Data = new List> + { + new() + { + ["name"] = dbName, + ["collectionCount"] = 1, + ["collections"] = new List> + { + new() { ["name"] = "items", ["documentCount"] = 42L } + } + } + } + }); + + var args = _commandDefinition.Parse(["--connection-string", ConnectionString, "--db-name", dbName]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + + await _documentDbService.Received(1).GetDatabasesAsync(ConnectionString, dbName, Arg.Any()); + } + + [Fact] + public async Task ExecuteAsync_Returns404_WhenDatabaseIsMissing() + { + // Arrange + const string dbName = "missingdb"; + + _documentDbService.GetDatabasesAsync(ConnectionString, dbName, Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.NotFound, + Message = $"Database '{dbName}' was not found." + }); + + var args = _commandDefinition.Parse(["--connection-string", ConnectionString, "--db-name", dbName]); + + // Act + var response = await _command.ExecuteAsync(_context, args, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.NotFound, response.Status); + Assert.Contains("not found", response.Message, StringComparison.OrdinalIgnoreCase); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/CreateIndexCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/CreateIndexCommandTests.cs new file mode 100644 index 0000000000..e9f2675491 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/CreateIndexCommandTests.cs @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Index; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using MongoDB.Bson; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Index; + +public class CreateIndexCommandTests +{ + private readonly IDocumentDbService _documentDbService; + private readonly CreateIndexCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public CreateIndexCommandTests() + { + _documentDbService = Substitute.For(); + _command = new(Substitute.For>()); + _commandDefinition = _command.GetCommand(); + _context = new(new ServiceCollection().AddSingleton(_documentDbService).BuildServiceProvider()); + } + + [Fact] + public async Task ExecuteAsync_CreatesIndex_WhenValidKeysProvided() + { + const string connectionString = "mongodb://localhost:27017"; + const string dbName = "testdb"; + const string collectionName = "testcollection"; + const string keys = "{\"status\": 1}"; + + _documentDbService.CreateIndexAsync(connectionString, dbName, collectionName, Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Index created successfully", + Data = new Dictionary { ["index_name"] = "status_1" } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--keys", keys]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_CreatesIndexWithOptions_WhenOptionsProvided() + { + const string connectionString = "mongodb://localhost:27017"; + const string dbName = "testdb"; + const string collectionName = "testcollection"; + const string keys = "{\"email\": 1}"; + const string options = "{\"unique\": true}"; + + _documentDbService.CreateIndexAsync(connectionString, dbName, collectionName, Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Index created successfully", + Data = new Dictionary { ["index_name"] = "email_1" } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", dbName, + "--collection-name", collectionName, + "--keys", keys, + "--options", options]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns400_WhenCollectionNotFound() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.CreateIndexAsync(connectionString, "testdb", "nonexistent", Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.BadRequest, + Message = "Collection 'nonexistent' not found" + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", "testdb", + "--collection-name", "nonexistent", + "--keys", "{\"status\": 1}"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("not found", response.Message); + } + + [Theory] + [InlineData("--connection-string", "mongodb://localhost:27017", "--db-name", "testdb", "--collection-name", "coll")] + [InlineData("--connection-string", "mongodb://localhost:27017", "--db-name", "testdb", "--keys", "{\"a\":1}")] + [InlineData("--connection-string", "mongodb://localhost:27017", "--collection-name", "coll", "--keys", "{\"a\":1}")] + public async Task ExecuteAsync_Returns400_WhenRequiredParametersAreMissing(params string[] args) + { + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse(args), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLowerInvariant()); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/DropIndexCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/DropIndexCommandTests.cs new file mode 100644 index 0000000000..865d23a4f8 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/DropIndexCommandTests.cs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Index; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Index; + +public class DropIndexCommandTests +{ + private readonly IDocumentDbService _documentDbService; + private readonly DropIndexCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public DropIndexCommandTests() + { + _documentDbService = Substitute.For(); + _command = new(Substitute.For>()); + _commandDefinition = _command.GetCommand(); + _context = new(new ServiceCollection().AddSingleton(_documentDbService).BuildServiceProvider()); + } + + [Fact] + public async Task ExecuteAsync_DropsIndex_WhenIndexExists() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.DropIndexAsync(connectionString, "testdb", "testcollection", "status_1", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Index dropped successfully", + Data = new Dictionary { ["index_name"] = "status_1" } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", "testdb", + "--collection-name", "testcollection", + "--index-name", "status_1"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns400_WhenCollectionNotFound() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.DropIndexAsync(connectionString, "testdb", "nonexistent", "status_1", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.BadRequest, + Message = "Collection 'nonexistent' not found" + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", "testdb", + "--collection-name", "nonexistent", + "--index-name", "status_1"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("not found", response.Message); + } + + [Fact] + public async Task ExecuteAsync_Returns400_WhenIndexDoesNotExist() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.DropIndexAsync(connectionString, "testdb", "testcollection", "nonexistent_index", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.BadRequest, + Message = "Index 'nonexistent_index' not found" + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", "testdb", + "--collection-name", "testcollection", + "--index-name", "nonexistent_index"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("not found", response.Message); + } + + [Theory] + [InlineData("--connection-string", "mongodb://localhost:27017", "--db-name", "testdb", "--collection-name", "coll")] + [InlineData("--connection-string", "mongodb://localhost:27017", "--db-name", "testdb", "--index-name", "idx")] + [InlineData("--connection-string", "mongodb://localhost:27017", "--collection-name", "coll", "--index-name", "idx")] + public async Task ExecuteAsync_Returns400_WhenRequiredParametersAreMissing(params string[] args) + { + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse(args), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLowerInvariant()); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/ListIndexesCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/ListIndexesCommandTests.cs new file mode 100644 index 0000000000..850ae59313 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Index/ListIndexesCommandTests.cs @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Index; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Index; + +public class ListIndexesCommandTests +{ + private readonly IDocumentDbService _documentDbService; + private readonly ListIndexesCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public ListIndexesCommandTests() + { + _documentDbService = Substitute.For(); + _command = new(Substitute.For>()); + _commandDefinition = _command.GetCommand(); + _context = new(new ServiceCollection().AddSingleton(_documentDbService).BuildServiceProvider()); + } + + [Fact] + public async Task ExecuteAsync_ReturnsIndexes_WhenIndexesExist() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.ListIndexesAsync(connectionString, "testdb", "testcollection", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Indexes retrieved successfully", + Data = new Dictionary + { + ["indexes"] = new List { "{\"name\":\"_id_\"}" }, + ["count"] = 1 + } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", "testdb", + "--collection-name", "testcollection"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns400_WhenCollectionNotFound() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.ListIndexesAsync(connectionString, "testdb", "nonexistent", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.BadRequest, + Message = "Collection 'nonexistent' not found" + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", "testdb", + "--collection-name", "nonexistent"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("not found", response.Message); + } + + [Fact] + public async Task ExecuteAsync_Returns401_WhenUnauthorized() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.ListIndexesAsync(connectionString, "testdb", "testcollection", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.Unauthorized, + Message = "Unauthorized access" + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--db-name", "testdb", + "--collection-name", "testcollection"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.Unauthorized, response.Status); + Assert.Contains("Unauthorized access", response.Message); + } + + [Theory] + [InlineData("--connection-string", "mongodb://localhost:27017", "--db-name", "testdb")] + [InlineData("--connection-string", "mongodb://localhost:27017", "--collection-name", "testcollection")] + public async Task ExecuteAsync_Returns400_WhenRequiredParametersAreMissing(params string[] args) + { + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse(args), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLowerInvariant()); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Others/CurrentOpsCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Others/CurrentOpsCommandTests.cs new file mode 100644 index 0000000000..445fedc827 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Others/CurrentOpsCommandTests.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Others; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using MongoDB.Bson; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Others; + +public class CurrentOpsCommandTests +{ + private readonly IDocumentDbService _documentDbService; + private readonly CurrentOpsCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public CurrentOpsCommandTests() + { + _documentDbService = Substitute.For(); + _command = new(Substitute.For>()); + _commandDefinition = _command.GetCommand(); + _context = new(new ServiceCollection().AddSingleton(_documentDbService).BuildServiceProvider()); + } + + [Fact] + public async Task ExecuteAsync_ReturnsOps_WhenOpsExist() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.GetCurrentOpsAsync(connectionString, Arg.Any(), Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Current operations retrieved successfully", + Data = new Dictionary { ["operations"] = "{\"inprog\":[]}" } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_ReturnsFilteredOps_WhenFilterProvided() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.GetCurrentOpsAsync(connectionString, Arg.Any(), Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Current operations retrieved successfully", + Data = new Dictionary { ["operations"] = "{\"inprog\":[{\"op\":\"query\"}]}" } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString, + "--ops", "{\"op\":\"query\"}"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns500_WhenServiceFails() + { + const string connectionString = "mongodb://localhost:27017"; + + _documentDbService.GetCurrentOpsAsync(connectionString, Arg.Any(), Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = false, + StatusCode = HttpStatusCode.InternalServerError, + Message = "Failed to retrieve current operations" + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", connectionString]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.InternalServerError, response.Status); + Assert.Contains("Failed to retrieve current operations", response.Message); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Others/GetStatsCommandTests.cs b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Others/GetStatsCommandTests.cs new file mode 100644 index 0000000000..96065910dd --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/Azure.Mcp.Tools.DocumentDb.UnitTests/Others/GetStatsCommandTests.cs @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.CommandLine; +using System.Net; +using Azure.Mcp.Tools.DocumentDb.Commands.Others; +using Azure.Mcp.Tools.DocumentDb.Models; +using Azure.Mcp.Tools.DocumentDb.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Models.Command; +using MongoDB.Bson; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.DocumentDb.UnitTests.Others; + +public class GetStatsCommandTests +{ + private const string ConnectionString = "mongodb://localhost:27017"; + + private readonly IDocumentDbService _documentDbService; + private readonly GetStatsCommand _command; + private readonly CommandContext _context; + private readonly Command _commandDefinition; + + public GetStatsCommandTests() + { + _documentDbService = Substitute.For(); + _command = new(Substitute.For>()); + _commandDefinition = _command.GetCommand(); + _context = new(new ServiceCollection().AddSingleton(_documentDbService).BuildServiceProvider()); + } + + [Fact] + public async Task ExecuteAsync_ReturnsDatabaseStats_WhenResourceTypeIsDatabase() + { + _documentDbService.GetDatabaseStatsAsync(ConnectionString, "testdb", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Database statistics retrieved successfully", + Data = new BsonDocument { { "db", "testdb" }, { "collections", 5 } } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--resource-type", "database", + "--db-name", "testdb"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_ReturnsCollectionStats_WhenResourceTypeIsCollection() + { + _documentDbService.GetCollectionStatsAsync(ConnectionString, "testdb", "testcollection", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Collection statistics retrieved successfully", + Data = new BsonDocument { { "ns", "testdb.testcollection" }, { "count", 10 } } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--resource-type", "collection", + "--db-name", "testdb", + "--collection-name", "testcollection"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_ReturnsIndexStats_WhenResourceTypeIsIndex() + { + _documentDbService.GetIndexStatsAsync(ConnectionString, "testdb", "testcollection", Arg.Any()) + .Returns(new DocumentDbResponse + { + Success = true, + StatusCode = HttpStatusCode.OK, + Message = "Index statistics retrieved successfully", + Data = new Dictionary + { + ["stats"] = new List { "{\"name\":\"_id_\"}" }, + ["count"] = 1 + } + }); + + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--resource-type", "index", + "--db-name", "testdb", + "--collection-name", "testcollection"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.Status); + Assert.NotNull(response.Results); + } + + [Fact] + public async Task ExecuteAsync_Returns400_WhenCollectionNameMissingForCollectionStats() + { + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse([ + "--connection-string", ConnectionString, + "--resource-type", "collection", + "--db-name", "testdb"]), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("--collection-name is required", response.Message); + } + + [Theory] + [InlineData("--connection-string", ConnectionString, "--resource-type", "database")] + [InlineData("--connection-string", ConnectionString, "--db-name", "testdb")] + public async Task ExecuteAsync_Returns400_WhenRequiredParametersAreMissing(params string[] args) + { + var response = await _command.ExecuteAsync(_context, _commandDefinition.Parse(args), TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("required", response.Message.ToLowerInvariant()); + } +} \ No newline at end of file diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/test-resources-post.ps1 b/tools/Azure.Mcp.Tools.DocumentDb/tests/test-resources-post.ps1 new file mode 100644 index 0000000000..c279c90355 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/test-resources-post.ps1 @@ -0,0 +1,94 @@ +param( + [string] $TenantId, + [string] $TestApplicationId, + [string] $ResourceGroupName, + [string] $BaseName, + [hashtable] $DeploymentOutputs +) + +$ErrorActionPreference = "Stop" + +. "$PSScriptRoot/../../../eng/common/scripts/common.ps1" +. "$PSScriptRoot/../../../eng/scripts/helpers/TestResourcesHelpers.ps1" + +$testSettings = New-TestSettings @PSBoundParameters -OutputPath $PSScriptRoot + +# $testSettings contains: +# - TenantId +# - TenantName +# - SubscriptionId +# - SubscriptionName +# - ResourceGroupName +# - ResourceBaseName + +# $DeploymentOutputs keys are all UPPERCASE + +Write-Host "Test resources deployed successfully for DocumentDB" +Write-Host "Connection string saved to .testsettings.json" + +# Initialize test database and collections using MongoDB driver +Write-Host "Initializing test database and collections..." + +try { + # Check if mongosh is available + $mongoshPath = Get-Command mongosh -ErrorAction SilentlyContinue + + if ($null -eq $mongoshPath) { + Write-Warning "mongosh not found. Skipping database initialization." + Write-Warning "You may need to manually create the 'test' database and 'items' collection." + Write-Warning "Install mongosh from: https://www.mongodb.com/try/download/shell" + } else { + $connectionString = $DeploymentOutputs['DOCUMENTDB_CONNECTION_STRING'] + + # Wait for firewall rules to propagate + Write-Host "Waiting for firewall rules to propagate (30 seconds)..." + Start-Sleep -Seconds 30 + + # Create init script + $initScript = @" +use test +db.createCollection('items') +db.items.insertMany([ + { name: 'item1', value: 100, category: 'A' }, + { name: 'item2', value: 200, category: 'B' }, + { name: 'item3', value: 300, category: 'A' } +]) +print('Test database and collection initialized successfully') +"@ + + $scriptPath = Join-Path $env:TEMP "documentdb-init.js" + $initScript | Out-File -FilePath $scriptPath -Encoding UTF8 + + Write-Host "Running initialization script..." + $retries = 3 + $success = $false + + for ($i = 1; $i -le $retries; $i++) { + try { + Write-Host "Attempt $i of $retries..." + & mongosh "$connectionString" --file $scriptPath --quiet + $success = $true + Write-Host "Database initialization completed successfully" + break + } catch { + if ($i -lt $retries) { + Write-Warning "Connection failed, retrying in 10 seconds..." + Start-Sleep -Seconds 10 + } else { + throw + } + } + } + + Remove-Item $scriptPath -ErrorAction SilentlyContinue + + if (-not $success) { + Write-Warning "Database initialization failed after $retries attempts." + Write-Warning "You may need to manually initialize the database and collection." + } + } +} catch { + Write-Warning "Failed to initialize database: $_" + Write-Warning "Tests may fail if database and collections don't exist." + Write-Warning "You can manually run: mongosh `"$($DeploymentOutputs['DOCUMENTDB_CONNECTION_STRING'])`"" +} diff --git a/tools/Azure.Mcp.Tools.DocumentDb/tests/test-resources.bicep b/tools/Azure.Mcp.Tools.DocumentDb/tests/test-resources.bicep new file mode 100644 index 0000000000..0799ada072 --- /dev/null +++ b/tools/Azure.Mcp.Tools.DocumentDb/tests/test-resources.bicep @@ -0,0 +1,70 @@ +targetScope = 'resourceGroup' + +@minLength(3) +@maxLength(40) +@description('The base resource name.') +param baseName string = resourceGroup().name + +@description('The location of the resource. By default, this is the same as the resource group.') +param location string = 'westus' == resourceGroup().location ? 'westus2' : resourceGroup().location + +@description('Enable an additional public firewall rule for local development. Leave disabled for CI and shared test environments.') +param enablePublicIpRule bool = false + +@description('Start IP address for the optional public firewall rule used for local development.') +param allowedStartIpAddress string = '0.0.0.0' + +@description('End IP address for the optional public firewall rule used for local development.') +param allowedEndIpAddress string = '255.255.255.255' + +var administratorLogin = 'testadmin' +// Use a password without special characters that need URL encoding (! and @ cause issues) +var administratorLoginPassword = 'Pass${uniqueString(resourceGroup().id)}0rd' + +// DocumentDB (Azure Cosmos DB for MongoDB vCore) account +resource documentDbAccount 'Microsoft.DocumentDB/mongoClusters@2024-03-01-preview' = { + name: '${take(baseName, 30)}-ddb' + location: location + properties: { + administratorLogin: administratorLogin + administratorLoginPassword: administratorLoginPassword + serverVersion: '5.0' + nodeGroupSpecs: [ + { + kind: 'Shard' + sku: 'M30' + diskSizeGB: 128 + enableHa: false + nodeCount: 1 + } + ] + publicNetworkAccess: 'Enabled' + } +} + +// Allow access from Azure services (enables test proxy and Azure pipelines) +resource allowAzureServices 'Microsoft.DocumentDB/mongoClusters/firewallRules@2024-03-01-preview' = { + parent: documentDbAccount + name: 'AllowAzureServices' + properties: { + startIpAddress: '0.0.0.0' + endIpAddress: '0.0.0.0' + } +} + +// Optional public IP rule for local development. +// Keep disabled for CI/shared environments and prefer a narrow caller IP range when enabled. +resource allowPublicIpRange 'Microsoft.DocumentDB/mongoClusters/firewallRules@2024-03-01-preview' = if (enablePublicIpRule) { + parent: documentDbAccount + name: 'AllowPublicIpRange' + properties: { + startIpAddress: allowedStartIpAddress + endIpAddress: allowedEndIpAddress + } +} + +// Output the connection string (will be sanitized in tests) +// The connectionString property returns a template like: mongodb+srv://:@host... +// We need to replace and with actual credentials +output DOCUMENTDB_ENDPOINT string = documentDbAccount.properties.connectionString +output DOCUMENTDB_CONNECTION_STRING string = replace(replace(documentDbAccount.properties.connectionString, '', administratorLogin), '', administratorLoginPassword) \ No newline at end of file