Skip to content

Commit b191993

Browse files
committed
Add Linting
1 parent 5a0089e commit b191993

File tree

4 files changed

+12
-15
lines changed

4 files changed

+12
-15
lines changed

src/nvidia_resiliency_ext/attribution/mcp_integration/mcp_client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
from mcp.client.session import ClientSession
1717
from mcp.client.stdio import StdioServerParameters, stdio_client
1818

19-
from nvidia_resiliency_ext.attribution.mcp_integration.registry import (
20-
deserialize_result,
21-
)
19+
from nvidia_resiliency_ext.attribution.mcp_integration.registry import deserialize_result
2220

2321
logger = logging.getLogger(__name__)
2422

@@ -141,7 +139,7 @@ async def run_module(self, module_name: str, **kwargs) -> Dict[str, Any]:
141139
result_str = await self.call_tool(module_name, arguments)
142140
return deserialize_result(result_str)
143141

144-
async def get_result(self, result_id: str) -> Dict[str, Any]:
142+
async def get_result(self, result_id: str) -> Dict[str, Any]:
145143
"""
146144
Retrieve a cached result by ID.
147145

src/nvidia_resiliency_ext/attribution/mcp_integration/mcp_server.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import asyncio
1010
import json
1111
import logging
12-
import uuid
1312
from typing import Any, Dict, List, Optional
1413

1514
from mcp.server import Server
@@ -199,7 +198,7 @@ async def _handle_module_execution(
199198
"""Execute a single attribution module."""
200199
# Apply default values from input schema
201200
arguments_with_defaults = self.registry.apply_defaults(module_name, arguments)
202-
201+
203202
# Get or create module instance
204203
if module_name not in self.module_instances:
205204
# Convert arguments to argparse.Namespace
@@ -235,12 +234,11 @@ async def _handle_module_execution(
235234

236235
return [TextContent(type="text", text=serialize_result(response))]
237236

238-
239237
async def run(self):
240238
"""Run the MCP server."""
241239
import os
242240

243-
logger.info(f"Starting NVRX Attribution MCP Server")
241+
logger.info("Starting NVRX Attribution MCP Server")
244242
logger.info(f"Registered modules: {self.registry.list_modules()}, pid: {os.getpid()}")
245243

246244
async with stdio_server() as (read_stream, write_stream):
@@ -250,4 +248,4 @@ async def run(self):
250248

251249
def run_sync(self):
252250
"""Run the server synchronously."""
253-
asyncio.run(self.run())
251+
asyncio.run(self.run())

src/nvidia_resiliency_ext/attribution/mcp_integration/module_definitions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from nvidia_resiliency_ext.attribution.mcp_integration.registry import global_registry
1010
from nvidia_resiliency_ext.attribution.trace_analyzer.fr_attribution import CollectiveAnalyzer
1111

12+
1213
def register_all_modules():
1314
"""Register all NVRX attribution modules with the global registry."""
1415

@@ -71,6 +72,7 @@ def register_all_modules():
7172
dependencies=[],
7273
)
7374

75+
7476
def create_args_from_dict(module_name: str, config: dict) -> argparse.Namespace:
7577
"""
7678
Create an argparse.Namespace from a configuration dictionary.

src/nvidia_resiliency_ext/attribution/mcp_integration/registry.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
"""
55

66
import hashlib
7-
import inspect
87
import json
98
import logging
109
from dataclasses import asdict, dataclass, is_dataclass
11-
from typing import Any, Callable, Dict, List, Optional, Type
10+
from typing import Any, Dict, List, Optional, Type
1211

1312
from nvidia_resiliency_ext.attribution.base import NVRxAttribution
1413

@@ -110,20 +109,20 @@ def apply_defaults(self, module_name: str, arguments: Dict[str, Any]) -> Dict[st
110109
metadata = self._modules.get(module_name)
111110
if not metadata:
112111
return arguments
113-
112+
114113
# Create a copy to avoid modifying the original
115114
result = dict(arguments)
116-
115+
117116
# Get the properties from the input schema
118117
input_schema = metadata.input_schema
119118
properties = input_schema.get("properties", {})
120-
119+
121120
# Apply defaults for missing arguments
122121
for param_name, param_schema in properties.items():
123122
if param_name not in result and "default" in param_schema:
124123
result[param_name] = param_schema["default"]
125124
logger.debug(f"Applied default for {param_name}: {param_schema['default']}")
126-
125+
127126
return result
128127

129128
def cache_result(self, module_name: str, arguments: Dict[str, Any], result: Any):

0 commit comments

Comments
 (0)