Skip to content

Commit bb0af7d

Browse files
Better oauth verification flow
1 parent dcc517b commit bb0af7d

File tree

6 files changed

+184
-72
lines changed

6 files changed

+184
-72
lines changed

src/mcp_importer/api.py

Lines changed: 118 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownParameterType=false
22
import asyncio
3-
from collections.abc import Awaitable
3+
import contextlib
44
from enum import Enum
55
from pathlib import Path
66
from typing import Any
77

8+
from fastmcp import Client as FastMCPClient
89
from fastmcp import FastMCP
10+
from fastmcp.client.auth import OAuth
11+
from fastmcp.client.auth.oauth import FileTokenStorage
912
from loguru import logger as log
1013

1114
from src.config import Config, MCPServerConfig, get_config_json_path
@@ -22,7 +25,7 @@
2225
import_from_vscode,
2326
)
2427
from src.mcp_importer.merge import MergePolicy, merge_servers
25-
from src.oauth_manager import get_oauth_manager
28+
from src.oauth_manager import OAuthStatus, get_oauth_manager
2629

2730

2831
class CLIENT(str, Enum):
@@ -37,15 +40,16 @@ def __repr__(self) -> str:
3740
return str(self)
3841

3942

40-
def detect_clients() -> set[CLIENT]:
41-
detected: set[CLIENT] = set()
43+
def detect_clients() -> list[CLIENT]:
44+
detected: list[CLIENT] = []
4245
if _paths.detect_cursor_config_path() is not None:
43-
detected.add(CLIENT.CURSOR)
46+
detected.append(CLIENT.CURSOR)
4447
if _paths.detect_vscode_config_path() is not None:
45-
detected.add(CLIENT.VSCODE)
48+
detected.append(CLIENT.VSCODE)
4649
if _paths.detect_claude_code_config_path() is not None:
47-
detected.add(CLIENT.CLAUDE_CODE)
48-
return detected
50+
detected.append(CLIENT.CLAUDE_CODE)
51+
# Return clients sorted alphabetically by identifier
52+
return sorted(detected, key=lambda c: c.value)
4953

5054

5155
def import_from(client: CLIENT) -> list[MCPServerConfig]:
@@ -136,8 +140,105 @@ def verify_mcp_server(server: MCPServerConfig) -> bool: # noqa
136140
async def _verify_async() -> bool:
137141
if not server.command.strip():
138142
return False
143+
oauth_info = None
139144

140-
# Inline backend config and capability listing (no extra helpers)
145+
# If this is a remote server, consult OAuth requirement first. Only skip
146+
# verification when OAuth is actually required and no tokens are present.
147+
try:
148+
if server.is_remote_server():
149+
remote_url: str | None = server.get_remote_url()
150+
if remote_url:
151+
oauth_info = await get_oauth_manager().check_oauth_requirement(
152+
server.name, remote_url
153+
)
154+
if oauth_info.status != OAuthStatus.NOT_REQUIRED:
155+
# Token presence check
156+
storage = FileTokenStorage(
157+
server_url=remote_url, cache_dir=get_oauth_manager().cache_dir
158+
)
159+
tokens = await storage.get_tokens()
160+
no_tokens: bool = not tokens or (
161+
not getattr(tokens, "access_token", None)
162+
and not getattr(tokens, "refresh_token", None)
163+
)
164+
# Detect if inline headers are present in args (translated from config)
165+
has_inline_headers: bool = any(
166+
(a == "--header" or a.startswith("--header")) for a in server.args
167+
)
168+
if (
169+
oauth_info.status == OAuthStatus.NEEDS_AUTH
170+
and no_tokens
171+
and not has_inline_headers
172+
):
173+
log.info(
174+
"Skipping verification for remote server '{}' pending OAuth",
175+
server.name,
176+
)
177+
return True
178+
except Exception:
179+
# If token inspection fails, continue with normal verification path
180+
pass
181+
182+
# Remote servers
183+
if server.is_remote_server():
184+
remote_url = server.get_remote_url()
185+
if remote_url:
186+
# If inline headers are specified (e.g., API key), verify via proxy to honor headers
187+
has_inline_headers: bool = any(
188+
(a == "--header" or a.startswith("--header")) for a in server.args
189+
)
190+
if has_inline_headers:
191+
backend_cfg: dict[str, Any] = {
192+
"mcpServers": {
193+
server.name: {
194+
"command": server.command,
195+
"args": server.args,
196+
"env": server.env or {},
197+
**({"roots": server.roots} if server.roots else {}),
198+
}
199+
}
200+
}
201+
proxy: FastMCP[Any] | None = None
202+
host: FastMCP[Any] | None = None
203+
try:
204+
proxy = FastMCP.as_proxy(backend_cfg)
205+
host = FastMCP(name=f"open-edison-verify-host-{server.name}")
206+
host.mount(proxy, prefix=server.name)
207+
208+
async def _list_tools_only() -> Any:
209+
return await host._tool_manager.list_tools() # type: ignore[attr-defined]
210+
211+
await asyncio.wait_for(_list_tools_only(), timeout=10.0)
212+
return True
213+
except Exception as e:
214+
log.error(
215+
"MCP remote (headers) verification failed for '{}': {}", server.name, e
216+
)
217+
return False
218+
finally:
219+
for obj in (host, proxy):
220+
if isinstance(obj, FastMCP):
221+
with contextlib.suppress(Exception):
222+
result = obj.shutdown() # type: ignore[attr-defined]
223+
await asyncio.wait_for(result, timeout=2.0) # type: ignore[func-returns-value]
224+
# Otherwise, avoid triggering OAuth flows during verification
225+
try:
226+
if oauth_info is None:
227+
oauth_info = await get_oauth_manager().check_oauth_requirement(
228+
server.name, remote_url
229+
)
230+
# If OAuth is needed or we are already authenticated, don't initiate browser flows here
231+
if oauth_info.status in (OAuthStatus.NEEDS_AUTH, OAuthStatus.AUTHENTICATED):
232+
return True
233+
# NOT_REQUIRED: quick unauthenticated ping
234+
async with FastMCPClient(remote_url, auth=None) as client: # type: ignore
235+
await asyncio.wait_for(client.ping(), timeout=10.0)
236+
return True
237+
except Exception as e: # noqa: BLE001
238+
log.error("MCP remote verification failed for '{}': {}", server.name, e)
239+
return False
240+
241+
# Local/stdio servers: mount via proxy and perform a single light operation (tools only)
141242
backend_cfg: dict[str, Any] = {
142243
"mcpServers": {
143244
server.name: {
@@ -156,36 +257,20 @@ async def _verify_async() -> bool:
156257
host = FastMCP(name=f"open-edison-verify-host-{server.name}")
157258
host.mount(proxy, prefix=server.name)
158259

159-
async def _call_list(kind: str) -> Any:
160-
manager_name = {
161-
"tools": "_tool_manager",
162-
"resources": "_resource_manager",
163-
"prompts": "_prompt_manager",
164-
}[kind]
165-
manager = getattr(host, manager_name)
166-
return await getattr(manager, f"list_{kind}")()
167-
168-
await asyncio.wait_for(
169-
asyncio.gather(
170-
_call_list("tools"),
171-
_call_list("resources"),
172-
_call_list("prompts"),
173-
),
174-
timeout=30.0,
175-
)
260+
async def _list_tools_only() -> Any:
261+
return await host._tool_manager.list_tools() # type: ignore[attr-defined]
262+
263+
await asyncio.wait_for(_list_tools_only(), timeout=10.0)
176264
return True
177265
except Exception as e:
178266
log.error("MCP verification failed for '{}': {}", server.name, e)
179267
return False
180268
finally:
181-
try:
182-
for obj in (host, proxy):
183-
if isinstance(obj, FastMCP):
269+
for obj in (host, proxy):
270+
if isinstance(obj, FastMCP):
271+
with contextlib.suppress(Exception):
184272
result = obj.shutdown() # type: ignore[attr-defined]
185-
if isinstance(result, Awaitable):
186-
await result # type: ignore[func-returns-value]
187-
except Exception:
188-
pass
273+
await asyncio.wait_for(result, timeout=2.0) # type: ignore[func-returns-value]
189274

190275
return asyncio.run(_verify_async())
191276

@@ -209,10 +294,6 @@ async def _authorize_async() -> bool:
209294
oauth_manager = get_oauth_manager()
210295

211296
try:
212-
# Import lazily to avoid import-time side effects
213-
from fastmcp import Client as FastMCPClient # type: ignore
214-
from fastmcp.client.auth import OAuth # type: ignore
215-
216297
# Debug info prior to starting OAuth
217298
print(
218299
"[OAuth] Starting authorization",
@@ -244,8 +325,6 @@ async def _authorize_async() -> bool:
244325

245326
# Post-authorization token inspection (no secrets printed)
246327
try:
247-
from fastmcp.client.auth.oauth import FileTokenStorage # type: ignore
248-
249328
storage = FileTokenStorage(server_url=remote_url, cache_dir=oauth_manager.cache_dir)
250329
tokens = await storage.get_tokens()
251330
access_present = bool(getattr(tokens, "access_token", None)) if tokens else False
@@ -286,8 +365,6 @@ async def _check_async() -> bool:
286365
return False
287366

288367
try:
289-
from fastmcp.client.auth.oauth import FileTokenStorage # type: ignore
290-
291368
storage = FileTokenStorage(
292369
server_url=remote_url, cache_dir=get_oauth_manager().cache_dir
293370
)

src/mcp_importer/export_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55

66
from .exporters import ExportError, export_to_claude_code, export_to_cursor, export_to_vscode
77
from .paths import (
8+
detect_claude_code_config_path,
89
detect_cursor_config_path,
910
detect_vscode_config_path,
11+
get_default_claude_code_config_path,
1012
get_default_cursor_config_path,
1113
get_default_vscode_config_path,
1214
)
@@ -135,8 +137,6 @@ def _handle_vscode(args: argparse.Namespace) -> int:
135137

136138

137139
def _handle_claude_code(args: argparse.Namespace) -> int:
138-
from .paths import detect_claude_code_config_path, get_default_claude_code_config_path
139-
140140
detected = detect_claude_code_config_path()
141141
target_path: Path = detected if detected else get_default_claude_code_config_path()
142142

src/mcp_importer/parsers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pyright: reportUnknownArgumentType=false, reportUnknownVariableType=false, reportMissingImports=false, reportUnknownMemberType=false
22

33
import json
4+
import shlex
45
from pathlib import Path
56
from typing import Any, cast
67

@@ -67,12 +68,35 @@ def _coerce_server_entry(name: str, node: dict[str, Any], default_enabled: bool)
6768

6869
args: list[str] = [str(a) for a in args_raw]
6970

71+
# If command is provided as a full string with flags, split into program + args
72+
if command and (" " in command or command.endswith(("\t", "\n"))):
73+
try:
74+
parts = shlex.split(command)
75+
if parts:
76+
command = parts[0]
77+
# Prepend split args before any provided args to preserve order
78+
args = parts[1:] + args
79+
except Exception:
80+
# If shlex fails, keep original command/args
81+
pass
82+
7083
env_raw = node.get("env") or node.get("environment") or {}
7184
env: dict[str, str] = {}
7285
if isinstance(env_raw, dict):
7386
for k, v in env_raw.items():
7487
env[str(k)] = str(v)
7588

89+
# Support Cursor-style remote config: { "url": "...", "headers": {...} }
90+
# Translate to `npx mcp-remote <url> [--header Key: Value]*` so downstream verification works.
91+
url_val = node.get("url")
92+
if isinstance(url_val, str) and url_val:
93+
command = "npx"
94+
args = ["-y", "mcp-remote", url_val]
95+
headers_raw = node.get("headers")
96+
if isinstance(headers_raw, dict):
97+
for hk, hv in headers_raw.items():
98+
args.extend(["--header", f"{str(hk)}: {str(hv)}"])
99+
76100
enabled = bool(node.get("enabled", default_enabled))
77101

78102
roots_raw = node.get("roots") or node.get("rootPaths") or []

src/server.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
)
2525
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
2626
from fastapi.staticfiles import StaticFiles
27+
from fastmcp import Client as FastMCPClient
2728
from fastmcp import FastMCP
29+
from fastmcp.client.auth import OAuth
2830
from loguru import logger as log
2931
from pydantic import BaseModel, Field
3032

@@ -952,10 +954,6 @@ async def oauth_test_connection(
952954

953955
log.info(f"🔗 Testing connection to {server_name} at {remote_url}")
954956

955-
# Import FastMCP client for testing
956-
from fastmcp import Client as FastMCPClient
957-
from fastmcp.client.auth import OAuth
958-
959957
# Create OAuth auth object
960958
oauth = OAuth(
961959
mcp_url=remote_url,

src/setup_tui/main.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import asyncio
23

34
import questionary
45

@@ -13,6 +14,7 @@
1314
save_imported_servers,
1415
verify_mcp_server,
1516
)
17+
from src.oauth_manager import OAuthStatus, get_oauth_manager
1618

1719

1820
def show_welcome_screen(*, dry_run: bool = False) -> None:
@@ -54,30 +56,43 @@ def handle_mcp_source(
5456
print(f"Verifying the configuration for {config.name}... ")
5557
result = verify_mcp_server(config)
5658
if result:
57-
# If this is a remote server, check OAuth status and optionally authorize
59+
# For remote servers, only prompt if OAuth is actually required
5860
if config.is_remote_server():
59-
# Always check token presence; if absent, prompt for OAuth unless skipped
60-
tokens_present: bool = has_oauth_tokens(config)
61-
if not tokens_present:
62-
if skip_oauth:
63-
print(
64-
f"Skipping OAuth for {config.name} due to --skip-oauth (no tokens present). This server will not be imported."
61+
# Heuristic: if inline headers are present (e.g., API key), treat as not requiring OAuth
62+
has_inline_headers: bool = any(
63+
(a == "--header" or a.startswith("--header")) for a in config.args
64+
)
65+
if not has_inline_headers:
66+
# Prefer cached result from verification; only check if missing
67+
oauth_mgr = get_oauth_manager()
68+
info = oauth_mgr.get_server_info(config.name)
69+
if info is None:
70+
info = asyncio.run(
71+
oauth_mgr.check_oauth_requirement(config.name, config.get_remote_url())
6572
)
66-
continue
67-
68-
if questionary.confirm(
69-
f"{config.name} is a remote server and no OAuth credentials were found. Obtain credentials now?",
70-
default=True,
71-
).ask():
72-
success = authorize_server_oauth(config)
73-
if not success:
74-
print(
75-
f"Failed to obtain OAuth credentials for {config.name}. Skipping this server."
76-
)
77-
continue
78-
else:
79-
print(f"Skipping {config.name} per user choice.")
80-
continue
73+
74+
if info.status == OAuthStatus.NEEDS_AUTH:
75+
tokens_present: bool = has_oauth_tokens(config)
76+
if not tokens_present:
77+
if skip_oauth:
78+
print(
79+
f"Skipping OAuth for {config.name} due to --skip-oauth (OAuth required, no tokens). This server will not be imported."
80+
)
81+
continue
82+
83+
if questionary.confirm(
84+
f"{config.name} requires OAuth and no credentials were found. Obtain credentials now?",
85+
default=True,
86+
).ask():
87+
success = authorize_server_oauth(config)
88+
if not success:
89+
print(
90+
f"Failed to obtain OAuth credentials for {config.name}. Skipping this server."
91+
)
92+
continue
93+
else:
94+
print(f"Skipping {config.name} per user choice.")
95+
continue
8196

8297
verified_configs.append(config)
8398
else:

0 commit comments

Comments
 (0)