feat: Add MCP integration #80
base: master
Conversation
Pull request overview
This PR adds Model Context Protocol (MCP) integration to dnet, enabling it to function as an MCP server for AI assistants like Claude Desktop and Cursor. The integration exposes a new /mcp endpoint with six tools for model management and inference, making dnet accessible through the MCP protocol.
Key Changes
- Added comprehensive MCP server implementation with tools for chat completion, model loading/unloading, and cluster status queries
- Integrated MCP server into existing FastAPI application using mounted ASGI app pattern with shared lifespan management (see the sketch after this list)
- Added fastmcp dependency to support the new functionality
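As a rough illustration of the mounted-ASGI-app pattern mentioned in the second bullet, here is a minimal sketch of such a mount with a shared lifespan. It assumes fastmcp's `http_app()` helper; the exact handler names and wiring in dnet may differ.

```python
# Minimal sketch of mounting a FastMCP server inside an existing FastAPI app
# while sharing its lifespan. Illustrative only, not the exact dnet code.
from fastapi import FastAPI
from fastmcp import FastMCP

mcp = FastMCP("dnet")      # in this PR the server comes from create_mcp_server(...)
mcp_app = mcp.http_app()   # ASGI app that serves the MCP protocol

# Reuse the MCP app's lifespan so its session manager starts and stops with FastAPI.
app = FastAPI(lifespan=mcp_app.lifespan)
app.mount("/mcp", mcp_app)
```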
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| src/dnet/api/mcp_handler.py | New file implementing MCP server with 6 tools (chat_completion, load_model, unload_model, list_models, get_status, get_cluster_details), custom error handling with JSON-RPC error codes, and 3 resources for protocol compliance |
| src/dnet/api/http_api.py | Integrates MCP server by mounting FastMCP app at /mcp path and using its lifespan manager for the main FastAPI application |
| pyproject.toml | Adds fastmcp as a new dependency for MCP protocol support |
| README.md | Documents MCP integration with setup instructions for Claude Desktop/Cursor using mcp-remote, and lists available tools |
src/dnet/api/mcp_handler.py
Outdated
```python
        Args:
            model: Model ID from catalog
            kv_bits: KV cache quantization
```
Copilot AI · Dec 14, 2025
Trailing whitespace detected in the docstring. The line ends with extra spaces after "quantization". This should be removed for consistency with code style.
Suggested change:
```suggestion
            kv_bits: KV cache quantization
```
src/dnet/api/mcp_handler.py
Outdated
```python
                stop=stops,
                repetition_penalty=repetition_penalty,
                stream=False,
        )
```
Copilot AI · Dec 14, 2025
The closing parenthesis of the ChatRequestModel construction is misaligned. It should either be aligned with the opening line (at the same indentation as `req = ChatRequestModel(`) or with the parameters inside. The current placement at 8 spaces is inconsistent with Python style conventions.
Suggested change:
```suggestion
            )
```
```python
        try:
            req = APILoadModelRequest(
                model=model,
                kv_bits=kv_bits,
                seq_len=seq_len,
            )
            if ctx:
                await ctx.info(f"Starting to load model: {req.model}")

            if model_manager.current_model_id == req.model:
                return f"Model '{req.model}' is already loaded."

            topology = cluster_manager.current_topology
            if topology is None or topology.model != req.model:
                if ctx:
                    await ctx.info("Preparing topology ...")

                await cluster_manager.scan_devices()
                if not cluster_manager.shards:
                    raise McpError(
                        -32002,
                        "No shards discovered. Check shard connectivity.",
                        data={"action": "check_shard_connectivity"}
                    )

                if ctx:
                    await ctx.info("Profiling cluster performance")

                model_config = get_model_config_json(req.model)
                embedding_size = int(model_config["hidden_size"])
                num_layers = int(model_config["num_hidden_layers"])

                batch_sizes = [1]
                profiles = await cluster_manager.profile_cluster(
                    req.model, embedding_size, 2, batch_sizes
                )
                if not profiles:
                    raise McpError(
                        -32603,
                        "Failed to collect device profiles. Check shard connectivity.",
                        data={
                            "step": "profiling",
                            "shards_count": len(cluster_manager.shards) if cluster_manager.shards else 0
                        }
                    )

                if ctx:
                    await ctx.info("Computing optimal layer distribution")

                model_profile_split = profile_model(
                    repo_id=req.model,
                    batch_sizes=batch_sizes,
                    sequence_length=req.seq_len,
                )
                model_profile = model_profile_split.to_model_profile()

                topology = await cluster_manager.solve_topology(
                    profiles, model_profile, req.model, num_layers, req.kv_bits
                )
                cluster_manager.current_topology = topology

                if ctx:
                    await ctx.info("Topology prepared")

            if ctx:
                await ctx.info("Loading model layers across shards...")
            api_props = await cluster_manager.discovery.async_get_own_properties()
            response = await model_manager.load_model(
                topology, api_props, inference_manager.grpc_port
            )

            if not response.success:
                error_msg = response.message or "Model loading failed"
                shard_errors = [
                    {"instance": s.instance, "message": s.message}
                    for s in response.shard_statuses
                    if not s.success
                ]
                raise McpError(
                    -32603,
                    f"Model loading failed: {error_msg}. "
                    f"{len(shard_errors)}/{len(response.shard_statuses)} shards failed.",
                    data={
                        "model": req.model,
                        "shard_errors": shard_errors,
                        "failed_shards": len(shard_errors),
                        "total_shards": len(response.shard_statuses)
                    }
                )

            if topology.devices:
                first_shard = topology.devices[0]
                await inference_manager.connect_to_ring(
                    first_shard.local_ip, first_shard.shard_port, api_props.local_ip
                )

            if ctx:
                await ctx.info(f"Model {req.model} loaded successfully across {len(response.shard_statuses)} shards")

            success_count = len([s for s in response.shard_statuses if s.success])
            return f"Model '{req.model}' loaded successfully. Loaded on {success_count}/{len(response.shard_statuses)} shards."
```
Copilot AI · Dec 14, 2025
The load_model function duplicates a significant amount of logic from http_api.py's load_model method (lines 154-231). This includes topology preparation, device scanning, profiling, and model loading. Consider extracting this shared logic into a separate service method that both endpoints can call to avoid code duplication and ensure consistent behavior.
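One possible shape for that extraction, as a hedged sketch only: the helper name `prepare_and_load_model` and the `notify` progress callback are hypothetical, and the body would be the flow currently duplicated in both files.

```python
# Hypothetical shared seam both endpoints could call; names are illustrative.
from typing import Awaitable, Callable, Optional

ProgressFn = Optional[Callable[[str], Awaitable[None]]]


async def prepare_and_load_model(
    req,                    # APILoadModelRequest
    cluster_manager,
    model_manager,
    inference_manager,
    notify: ProgressFn = None,
):
    """Device scan -> profiling -> topology solve -> model load, in one code path.

    The body would hold the logic currently duplicated between http_api.load_model
    and the MCP load_model tool; failures are raised as plain exceptions and
    translated to HTTP errors or McpError by the respective callers.
    """
    raise NotImplementedError  # body: the duplicated flow shown in this diff


# HTTP route (http_api.py): translate failures into HTTP responses.
#   response = await prepare_and_load_model(req, cluster, models, inference)
#
# MCP tool (mcp_handler.py): forward progress via ctx and wrap errors in McpError.
#   response = await prepare_and_load_model(
#       req, cluster, models, inference,
#       notify=(ctx.info if ctx else None),
#   )
```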
src/dnet/api/mcp_handler.py
Outdated
```python
                f"Failed to load model '{req.model}': {str(e)}",
                data={"model": req.model, "original_error": type(e).__name__}
```
Copilot AI · Dec 14, 2025
In the general exception handler, the code references req.model but req may not be defined if the ValidationError occurred during the APILoadModelRequest construction (line 159-163). This will cause a NameError. Consider using the model parameter instead, or wrapping the req construction in a separate try-except block.
Suggested change:
```suggestion
                f"Failed to load model '{model}': {str(e)}",
                data={"model": model, "original_error": type(e).__name__}
```
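The alternative mentioned in the comment above, validating the request in its own try/except so the rest of the handler can safely use `req`, might look roughly like this sketch (structure only; the names come from the code in this diff):

```python
# Sketch: build the request in a dedicated try/except so a later exception
# handler can reference req.model without risking a NameError.
try:
    req = APILoadModelRequest(model=model, kv_bits=kv_bits, seq_len=seq_len)
except ValidationError as e:
    raise McpError(-32602, f"Invalid load_model parameters: {e}",
                   data={"validation_errors": str(e)})

try:
    ...  # rest of the load flow, which may now reference req.model freely
except Exception as e:
    raise McpError(-32603, f"Failed to load model '{req.model}': {e}",
                   data={"model": req.model, "original_error": type(e).__name__})
```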
src/dnet/api/mcp_handler.py
Outdated
```python
        Args:
            model: Model ID from catalog
            kv_bits: KV cache quantization
            seq_len: Sequence length
```
Copilot AI · Dec 14, 2025
Trailing whitespace detected in the docstring. The line ends with extra spaces after "length". This should be removed for consistency with code style.
Suggested change:
```suggestion
            seq_len: Sequence length
```
src/dnet/api/mcp_handler.py
Outdated
```python
class McpError(Exception):
    """Custom MCP error with JSON-RPC 2.0 error codes.
```
Copilot AI · Dec 14, 2025
Trailing whitespace detected in the docstring. The docstring line ends with extra spaces after "error codes." This should be removed for consistency with code style.
Suggested change:
```suggestion
    """Custom MCP error with JSON-RPC 2.0 error codes.
```
src/dnet/api/mcp_handler.py
Outdated
```python
        try:
            req = APILoadModelRequest(
                model=model,
                kv_bits=kv_bits,
                seq_len=seq_len,
            )
            if ctx:
                await ctx.info(f"Starting to load model: {req.model}")

            if model_manager.current_model_id == req.model:
                return f"Model '{req.model}' is already loaded."
```
Copilot AI · Dec 14, 2025
The variable req is assigned but the original parameter model is still used instead of req.model on line 167. This creates confusion about which value is being checked. Either use model consistently with the function parameter, or use req.model after constructing the request object. The current approach constructs the request object but then ignores it for the initial check.
src/dnet/api/mcp_handler.py
Outdated
```python
                profiles = await cluster_manager.profile_cluster(
                    req.model, embedding_size, 2, batch_sizes
```
Copilot AI · Dec 14, 2025
The hardcoded value 2 for the third parameter to profile_cluster appears to be a magic number without explanation. Consider adding a comment explaining what this parameter represents or extracting it to a named constant to improve code clarity.
Suggested change:
```suggestion
                # The third argument to profile_cluster represents the number of ??? (TODO: replace with actual meaning)
                PROFILE_CLUSTER_NUM_ARG = 2  # TODO: Replace with a more descriptive name if possible
                profiles = await cluster_manager.profile_cluster(
                    req.model, embedding_size, PROFILE_CLUSTER_NUM_ARG, batch_sizes
```
```python
def create_mcp_server(
    inference_manager: InferenceManager,
    model_manager: ModelManager,
    cluster_manager: ClusterManager,
) -> FastMCP:
    """Create and configure the MCP server for dnet."""

    mcp = FastMCP("dnet")
    mcp.add_middleware(ErrorHandlingMiddleware())

    @mcp.custom_route("/mcp-health", methods=["GET"])
    async def mcp_health_check(request):
        """Health check endpoint for MCP server."""
        return JSONResponse({
            "status": "healthy",
            "service": "dnet-mcp",
            "model_loaded": model_manager.current_model_id is not None,
            "model": model_manager.current_model_id,
            "topology_configured": cluster_manager.current_topology is not None,
            "shards_discovered": len(cluster_manager.shards) if cluster_manager.shards else 0,
        })

    @mcp.tool()
    async def chat_completion(
        messages: list[dict[str, str]],
        model: str | None = None,
        temperature: float = 1.0,
        max_tokens: int = 2000,
        top_p: float = 1.0,
        top_k: int = -1,
        stop: str | list[str] | None = None,
        repetition_penalty: float = 1.0,
        ctx: Context | None = None,
    ) -> str:
        """Generate text using distributed LLM inference.
        Args:
            messages: Array of message objects with 'role' and 'content' fields.
                Each message should be a dict like: {"role": "user", "content": "Hello"}
            model: Model name (optional, uses currently loaded model if not specified)
            temperature: Sampling temperature (0-2), default is 1.0
            max_tokens: Maximum tokens to generate, default is 2000
            top_p: Nucleus sampling parameter (0-1), default is 1.0
            top_k: Top-k sampling parameter (-1 for disabled), default is -1
            stop: Stop sequences (string or list), default is None
            repetition_penalty: Repetition penalty (>=0), default is 1.0
        """
        if ctx:
            await ctx.info("Starting inference...")

        if not model_manager.current_model_id:
            raise McpError(
                -32000,
                "No model loaded. Please load a model first using load_model tool.",
                data={"action": "load_model"}
            )

        model_id = model or model_manager.current_model_id
        stops = [stop] if isinstance(stop, str) else (stop or [])

        try:
            msgs = [
                ChatMessage(**msg) if isinstance(msg, dict) else msg
                for msg in messages
            ]
            req = ChatRequestModel(
                messages=msgs,
                model=model_id,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                top_k=top_k,
                stop=stops,
                repetition_penalty=repetition_penalty,
                stream=False,
        )
            result = await inference_manager.chat_completions(req)
        except ValidationError as e:
            raise McpError(
                -32602,
                f"Invalid request parameters: {str(e)}",
                data={"validation_errors": str(e)}
            )
        except Exception as e:
            logger.exception("Error in chat_completion: %s", e)
            raise McpError(
                -32603,
                f"Inference failed: {str(e)}",
                data={"model": model_id, "original_error": type(e).__name__}
            )

        if not result.choices or not result.choices[0].message:
            raise McpError(-32603, "No content generated", data={"model": model_id})

        text = result.choices[0].message.content or ""
        if ctx:
            await ctx.info("Inference completed successfully")

        return text

    @mcp.tool()
    async def load_model(
        model: str,
        kv_bits: str = "8bit",
        seq_len: int = 4096,
        ctx: Context | None = None,
    ) -> str:
        """Load a model for distributed inference across the cluster.
        If a different model is already loaded, both models will stay in memory (old model
        is not automatically unloaded). If the same model is already loaded, returns early.
        Automatically prepares topology and discovers devices if needed.
        Args:
            model: Model ID from catalog
            kv_bits: KV cache quantization
            seq_len: Sequence length
        """
        try:
            req = APILoadModelRequest(
                model=model,
                kv_bits=kv_bits,
                seq_len=seq_len,
            )
            if ctx:
                await ctx.info(f"Starting to load model: {req.model}")

            if model_manager.current_model_id == req.model:
                return f"Model '{req.model}' is already loaded."

            topology = cluster_manager.current_topology
            if topology is None or topology.model != req.model:
                if ctx:
                    await ctx.info("Preparing topology ...")

                await cluster_manager.scan_devices()
                if not cluster_manager.shards:
                    raise McpError(
                        -32002,
                        "No shards discovered. Check shard connectivity.",
                        data={"action": "check_shard_connectivity"}
                    )

                if ctx:
                    await ctx.info("Profiling cluster performance")

                model_config = get_model_config_json(req.model)
                embedding_size = int(model_config["hidden_size"])
                num_layers = int(model_config["num_hidden_layers"])

                batch_sizes = [1]
                profiles = await cluster_manager.profile_cluster(
                    req.model, embedding_size, 2, batch_sizes
                )
                if not profiles:
                    raise McpError(
                        -32603,
                        "Failed to collect device profiles. Check shard connectivity.",
                        data={
                            "step": "profiling",
                            "shards_count": len(cluster_manager.shards) if cluster_manager.shards else 0
                        }
                    )

                if ctx:
                    await ctx.info("Computing optimal layer distribution")

                model_profile_split = profile_model(
                    repo_id=req.model,
                    batch_sizes=batch_sizes,
                    sequence_length=req.seq_len,
                )
                model_profile = model_profile_split.to_model_profile()

                topology = await cluster_manager.solve_topology(
                    profiles, model_profile, req.model, num_layers, req.kv_bits
                )
                cluster_manager.current_topology = topology

                if ctx:
                    await ctx.info("Topology prepared")

            if ctx:
                await ctx.info("Loading model layers across shards...")
            api_props = await cluster_manager.discovery.async_get_own_properties()
            response = await model_manager.load_model(
                topology, api_props, inference_manager.grpc_port
            )

            if not response.success:
                error_msg = response.message or "Model loading failed"
                shard_errors = [
                    {"instance": s.instance, "message": s.message}
                    for s in response.shard_statuses
                    if not s.success
                ]
                raise McpError(
                    -32603,
                    f"Model loading failed: {error_msg}. "
                    f"{len(shard_errors)}/{len(response.shard_statuses)} shards failed.",
                    data={
                        "model": req.model,
                        "shard_errors": shard_errors,
                        "failed_shards": len(shard_errors),
                        "total_shards": len(response.shard_statuses)
                    }
                )

            if topology.devices:
                first_shard = topology.devices[0]
                await inference_manager.connect_to_ring(
                    first_shard.local_ip, first_shard.shard_port, api_props.local_ip
                )

            if ctx:
                await ctx.info(f"Model {req.model} loaded successfully across {len(response.shard_statuses)} shards")

            success_count = len([s for s in response.shard_statuses if s.success])
            return f"Model '{req.model}' loaded successfully. Loaded on {success_count}/{len(response.shard_statuses)} shards."

        except ValidationError as e:
            raise McpError(
                -32602,
                f"Invalid load_model parameters: {str(e)}",
                data={"validation_errors": str(e)}
            )
        except McpError:
            raise
        except Exception as e:
            logger.exception("Error in load_model: %s", e)
            if ctx:
                await ctx.error(f"Failed to load model: {str(e)}")
            raise McpError(
                -32603,
                f"Failed to load model '{req.model}': {str(e)}",
                data={"model": req.model, "original_error": type(e).__name__}
            )

    @mcp.tool()
    async def unload_model(ctx: Context | None = None) -> str:
        """Unload the currently loaded model to free memory.
        Unloads the model from all shards and clears the topology. If no model is loaded, returns early.
        """
        if not model_manager.current_model_id:
            return "No model is currently loaded."

        model_name = model_manager.current_model_id
        if ctx:
            await ctx.info(f"Unloading model: {model_name}")

        await cluster_manager.scan_devices()
        shards = cluster_manager.shards
        response = await model_manager.unload_model(shards)

        if response.success:
            cluster_manager.current_topology = None
            if ctx:
                await ctx.info("Model unloaded successfully")
            return f"Model '{model_name}' unloaded successfully from all shards."
        else:
            shard_errors = [
                {"instance": s.instance, "message": s.message}
                for s in response.shard_statuses
                if not s.success
            ]
            raise McpError(
                -32603,
                "Model unloading failed",
                data={
                    "model": model_name,
                    "shard_errors": shard_errors,
                    "failed_shards": len(shard_errors),
                    "total_shards": len(response.shard_statuses)
                }
            )

    # Resources (for MCP protocol compliance)
    @mcp.resource("mcp://dnet/models")
    async def get_available_models() -> str:
        """List of models available in dnet catalog, organized by family and quantization."""
        return await _get_available_models_data()

    @mcp.resource("mcp://dnet/status")
    async def get_model_status() -> str:
        """Currently loaded model and cluster status information."""
        return await _get_model_status_data()

    @mcp.resource("mcp://dnet/cluster")
    async def get_cluster_info() -> str:
        """Detailed cluster information including devices and topology."""
        return await _get_cluster_info_data()

    # Tools that wrap resources (for Claude Desktop compatibility)
    @mcp.tool()
    async def list_models() -> str:
        """List all available models in the dnet catalog.
        Returns a formatted list of models organized by family and quantization.
        Use this to see what models you can load.
        """
        return await _get_available_models_data()

    @mcp.tool()
    async def get_status() -> str:
        """Get the current status of dnet including loaded model, topology, and cluster information.
        Returns detailed status about:
        - Currently loaded model (if any)
        - Topology configuration
        - Discovered shards in the cluster
        """
        return await _get_model_status_data()

    @mcp.tool()
    async def get_cluster_details() -> str:
        """Get detailed cluster information including shard details and topology breakdown.
        Returns comprehensive information about:
        - All discovered shards with their IPs and ports
        - Current topology configuration
        - Layer assignments across devices
        """
        return await _get_cluster_info_data()

    async def _get_available_models_data() -> str:
        models_by_family = defaultdict(list)
        for model in model_manager.available_models:
            models_by_family[model.alias].append(model)

        output_lines = ["Available Models in dnet Catalog:\n"]
        output_lines.append("=" * 60)

        for family_name in sorted(models_by_family.keys()):
            models = sorted(models_by_family[family_name], key=lambda m: m.id)
            output_lines.append(f"\n{family_name.upper()}")
            output_lines.append("-" * 60)

            by_quant = defaultdict(list)
            for model in models:
                by_quant[model.quantization].append(model)

            for quant in ["bf16", "fp16", "8bit", "4bit"]:
                if quant in by_quant:
                    quant_models = by_quant[quant]
                    quant_display = {
                        "bf16": "BF16 (Full precision)",
                        "fp16": "FP16 (Full precision)",
                        "8bit": "8-bit quantized",
                        "4bit": "4-bit quantized (smallest)",
                    }.get(quant, quant)
                    output_lines.append(f" {quant_display}:")
                    for model in quant_models:
                        output_lines.append(f" - {model.id}")

        output_lines.append("\n" + "=" * 60)
        output_lines.append(f"\nTotal: {len(model_manager.available_models)} models")
        output_lines.append("\nTo load a model, use the load_model tool with the full model ID.")

        return "\n".join(output_lines)

    async def _get_model_status_data() -> str:
        status_lines = ["dnet Status"]
        status_lines.append("=" * 60)

        if model_manager.current_model_id:
            status_lines.append(f"\n Model Loaded: {model_manager.current_model_id}")
        else:
            status_lines.append("\n No Model Loaded")

        topology = cluster_manager.current_topology
        if topology:
            status_lines.append(f"\n Topology:\n Model: {topology.model}\n Devices: {len(topology.devices)}\n Layers: {topology.num_layers}\n KV Cache: {topology.kv_bits}")

            if topology.assignments:
                status_lines.append(f"\n Layer Distribution:")
                for assignment in topology.assignments:
                    layers_str = ", ".join(
                        f"{r[0]}-{r[-1]}" if len(r) > 1 else str(r[0])
                        for r in assignment.layers
                    )
                    status_lines.append(
                        f" {assignment.instance}: layers [{layers_str}]"
                    )
        else:
            status_lines.append("\n Topology: Not configured")

        shards = cluster_manager.shards
        if shards:
            shard_names = ", ".join(sorted(shards.keys()))
            status_lines.append(f"\n Cluster:\n Discovered Shards: {len(shards)}\n Shard Names: {shard_names}")
        else:
            status_lines.append("\n Cluster: No shards discovered")

        status_lines.append("\n" + "=" * 60)

        return "\n".join(status_lines)

    async def _get_cluster_info_data() -> str:
        output_lines = ["dnet Cluster Information"]
        output_lines.append("=" * 60)

        shards = cluster_manager.shards
        if shards:
            output_lines.append(f"\n Shards ({len(shards)}):")
            for name, props in sorted(shards.items()):
                output_lines.append(f"\n {name}:\n IP: {props.local_ip}\n HTTP Port: {props.server_port}\n gRPC Port: {props.shard_port}\n Manager: {'Yes' if props.is_manager else 'No'}\n Busy: {'Yes' if props.is_busy else 'No'}")
        else:
            output_lines.append("\n No shards discovered")

        topology = cluster_manager.current_topology
        if topology:
            output_lines.append(f"\n Topology:\n Model: {topology.model}\n Total Layers: {topology.num_layers}\n KV Cache Bits: {topology.kv_bits}\n Devices: {len(topology.devices)}")

            if topology.assignments:
                output_lines.append(f"\n Layer Assignments:")
                for assignment in topology.assignments:
                    layers_flat = [
                        layer
                        for round_layers in assignment.layers
                        for layer in round_layers
                    ]
                    layers_str = ", ".join(map(str, sorted(layers_flat)))
                    output_lines.append(
                        f" {assignment.instance}: [{layers_str}] "
                        f"(window={assignment.window_size}, "
                        f"next={assignment.next_instance or 'N/A'})"
                    )
        else:
            output_lines.append("\n No topology configured")

        output_lines.append("\n" + "=" * 60)

        return "\n".join(output_lines)

    return mcp
```
Copilot AI · Dec 14, 2025
The new MCP handler functionality lacks test coverage. Given that comprehensive tests exist for other API endpoints in tests/subsystems/test_api_http_server.py, similar tests should be added for the MCP tools (chat_completion, load_model, unload_model) and resources. At minimum, tests should cover success cases, error handling (McpError raising), and edge cases like no model loaded.
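A starting point for such tests might look like the sketch below. This is hedged: it assumes fastmcp's in-memory `Client`, pytest-asyncio, and fake manager objects (`FakeInferenceManager`, `FakeModelManager`, `FakeClusterManager`) that the test suite would still have to provide.

```python
# Hypothetical test sketch for the MCP tools using fastmcp's in-memory client.
# The Fake* managers are stand-ins, not part of dnet today.
import pytest
from fastmcp import Client

from dnet.api.mcp_handler import create_mcp_server


def make_server():
    return create_mcp_server(
        inference_manager=FakeInferenceManager(),
        model_manager=FakeModelManager(current_model_id=None),
        cluster_manager=FakeClusterManager(),
    )


@pytest.mark.asyncio
async def test_expected_tools_are_registered():
    async with Client(make_server()) as client:
        tools = {t.name for t in await client.list_tools()}
        assert {"chat_completion", "load_model", "unload_model",
                "list_models", "get_status", "get_cluster_details"} <= tools


@pytest.mark.asyncio
async def test_chat_completion_fails_without_loaded_model():
    async with Client(make_server()) as client:
        # The tool raises an MCP error when no model is loaded.
        with pytest.raises(Exception):
            await client.call_tool(
                "chat_completion",
                {"messages": [{"role": "user", "content": "hi"}]},
            )
```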
pyproject.toml
Outdated
```toml
    "dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py",
    "rich>=13.0.0",
    "psutil>=5.9.0",
    "fastmcp",
```
Copilot AI · Dec 14, 2025
The fastmcp dependency is not version-pinned, while other dependencies in the project have specific version constraints (e.g., mlx-lm==0.28.2, huggingface-hub==0.24.0) or minimum versions (e.g., numpy>=1.24.0). Consider adding a version constraint to ensure reproducible builds and avoid potential breaking changes from future fastmcp releases.
Suggested change:
```suggestion
    "fastmcp==1.0.0",
```
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.
src/dnet/api/mcp_handler.py
Outdated
```python
class McpError(Exception):
    """Custom MCP error with JSON-RPC 2.0 error codes.
```
Copilot AI · Dec 18, 2025
Trailing whitespace after the docstring on line 20. This should be removed to maintain code cleanliness.
src/dnet/api/mcp_handler.py
Outdated
```python
            model: Model ID from catalog
            kv_bits: KV cache quantization
            seq_len: Sequence length
```
Copilot AI · Dec 18, 2025
Incomplete documentation for the kv_bits and seq_len parameters. The parameter descriptions are truncated and should provide complete information about accepted values and their meanings.
Suggested change:
```suggestion
            model: Model ID from catalog.
            kv_bits: KV cache quantization mode for the loaded model's KV cache, as a string.
                Typical values are backend-dependent (for example, ``"8bit"`` or ``"16bit"``);
                the default is ``"8bit"``. Only quantization modes supported by the cluster
                and model backend are valid.
            seq_len: Maximum sequence length (in tokens) to assume when profiling the model
                and computing the cluster topology. Must be a positive integer; defaults to
                ``4096``.
```
pyproject.toml
Outdated
```toml
    "dnet-p2p @ file://${PROJECT_ROOT}/lib/dnet-p2p/bindings/py",
    "rich>=13.0.0",
    "psutil>=5.9.0",
    "fastmcp",
```
Copilot AI · Dec 18, 2025
The fastmcp dependency lacks a version constraint. For reproducible builds and to avoid potential breaking changes, it's recommended to pin or constrain the version (e.g., "fastmcp>=0.1.0,<1.0.0" or "fastmcp==0.x.y").
Suggested change:
```suggestion
    "fastmcp>=0.1.0,<1.0.0",
```
src/dnet/api/http_api.py
Outdated
```python
        # Mount MCP server as ASGI app
        self.app.mount("/mcp", mcp_app)
```
Copilot AI · Dec 18, 2025
Extra blank line after the mount statement. While not critical, there's an unnecessary blank line at line 60 that should be removed to maintain consistent spacing in the codebase.
erhant left a comment:
- I agree with Copilot's comments; most of these are due to lack of formatting. Make sure you do the styling correctly (see `make format` and `make lint`).
- As Copilot also mentions, there is a great deal of duplicated logic between `mcp_handler` and `http_api`. We should share as much of it as possible from one place in the code.
- An integration test for the MCP server would be great.
Force-pushed from 23e7f04 to 93b3930 (Compare).
Adds Model Context Protocol (MCP) integration to enable dnet as an MCP server for AI assistants and applications.
Changes
- /mcp endpoint

Type of Change
Testing
Manually tested with Claude Desktop and Cursor - all tools are working correctly.
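For reference, a programmatic smoke test against a running instance could look like this sketch (hedged: it assumes fastmcp's HTTP client transport, and the URL and port are placeholders for your actual deployment):

```python
# Hypothetical smoke test for the mounted /mcp endpoint of a running dnet API.
# Replace the URL with your actual host and port.
import asyncio

from fastmcp import Client


async def main() -> None:
    async with Client("http://localhost:8080/mcp") as client:
        tools = await client.list_tools()
        print("MCP tools exposed by dnet:", [t.name for t in tools])


if __name__ == "__main__":
    asyncio.run(main())
```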
Related Issues
Closes #35