Skip to content

Commit a1e6b5e

Browse files
kevin1chunclaude
andcommitted
fix: address code review findings — security, correctness, formatting
- Fix token refresh body encoding: use urlencode() instead of manual string concatenation to prevent corruption of tokens with special chars - Fix refresh concurrency guard: share awaitable across concurrent 401s instead of returning None (matches TypeScript SDK behavior) - Set restrictive file permissions (0600) on encrypted token files - Apply _safe_segment validation to all URL builders that interpolate user-supplied parameters (defense-in-depth against path traversal) - Default order_stock time_in_force to "gfd" instead of "gtc" for safety - Use asyncio.to_thread for sync keyring calls to avoid blocking event loop - Fix trailing_peg payload to omit unused fields instead of sending nulls - Fix ruff formatting issues that caused CI failure - Add pagination + untrusted URL rejection tests for _http.py - Add parametrized path traversal tests for all URL builders Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bcb02e6 commit a1e6b5e

File tree

8 files changed

+177
-63
lines changed

8 files changed

+177
-63
lines changed

python/src/robinhood_agents/_auth.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
from __future__ import annotations
88

9-
from dataclasses import dataclass
9+
import asyncio
10+
from dataclasses import dataclass, field
1011
from time import time
1112
from typing import TYPE_CHECKING
13+
from urllib.parse import urlencode
1214

1315
import httpx
1416

@@ -31,7 +33,8 @@ class AuthState:
3133

3234
tokens: TokenData
3335
store: TokenStore
34-
refreshing: bool = False
36+
_refresh_lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)
37+
_refresh_task: asyncio.Task[str | None] | None = field(default=None, repr=False)
3538

3639

3740
async def _refresh_tokens(state: AuthState) -> str | None:
@@ -60,7 +63,7 @@ async def _refresh_tokens(state: AuthState) -> str | None:
6063
"Content-Type": "application/x-www-form-urlencoded; charset=utf-8",
6164
"X-Robinhood-API-Version": "1.431.4",
6265
},
63-
content="&".join(f"{k}={v}" for k, v in body.items()),
66+
content=urlencode(body),
6467
)
6568
except Exception:
6669
return None
@@ -99,16 +102,22 @@ async def _refresh_tokens(state: AuthState) -> str | None:
99102
def _create_refresh_callback(
100103
state: AuthState,
101104
) -> Callable[[], Awaitable[str | None]]:
102-
"""Create a 401-refresh callback with a concurrency guard."""
105+
"""Create a 401-refresh callback with a concurrency guard.
106+
107+
Concurrent 401s coalesce onto a single refresh attempt — all waiters
108+
get the same result, matching the TypeScript SDK's behaviour.
109+
"""
103110

104111
async def _refresh() -> str | None:
105-
if state.refreshing:
106-
return None # Another refresh is in progress
107-
state.refreshing = True
112+
async with state._refresh_lock:
113+
# If another coroutine already refreshed while we waited, return that result
114+
if state._refresh_task is not None:
115+
return await state._refresh_task
116+
state._refresh_task = asyncio.ensure_future(_refresh_tokens(state))
108117
try:
109-
return await _refresh_tokens(state)
118+
return await state._refresh_task
110119
finally:
111-
state.refreshing = False
120+
state._refresh_task = None
112121

113122
return _refresh
114123

@@ -152,7 +161,12 @@ async def logout(session: Session, state: AuthState | None) -> None:
152161
async with httpx.AsyncClient(timeout=5.0) as client:
153162
await client.post(
154163
"https://api.robinhood.com/oauth2/revoke_token/",
155-
content=f"client_id={_CLIENT_ID}&token={state.tokens.access_token}",
164+
content=urlencode(
165+
{
166+
"client_id": _CLIENT_ID,
167+
"token": state.tokens.access_token,
168+
}
169+
),
156170
headers={"Content-Type": "application/x-www-form-urlencoded; charset=utf-8"},
157171
)
158172
except Exception:

python/src/robinhood_agents/_client.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ async def order_stock(
544544
stop_price: float | None = None,
545545
trail_amount: float | None = None,
546546
trail_type: str | None = None,
547-
time_in_force: str | None = None,
547+
time_in_force: str = "gfd",
548548
extended_hours: bool = False,
549549
account_number: str | None = None,
550550
) -> StockOrder:
@@ -588,7 +588,7 @@ async def order_stock(
588588
"quantity": str(quantity),
589589
"type": order_type,
590590
"trigger": trigger,
591-
"time_in_force": "gfd" if is_fractional else (time_in_force or "gtc"),
591+
"time_in_force": "gfd" if is_fractional else time_in_force,
592592
"extended_hours": extended_hours,
593593
"ref_id": str(uuid.uuid4()),
594594
}
@@ -599,11 +599,12 @@ async def order_stock(
599599
payload["stop_price"] = str(stop_price)
600600
if trail_amount is not None:
601601
t_type = trail_type or "percentage"
602-
payload["trailing_peg"] = {
603-
"type": t_type,
604-
"percentage": str(trail_amount) if t_type != "amount" else None,
605-
"price": {"amount": str(trail_amount)} if t_type == "amount" else None,
606-
}
602+
peg: dict[str, object] = {"type": t_type}
603+
if t_type == "amount":
604+
peg["price"] = {"amount": str(trail_amount)}
605+
else:
606+
peg["percentage"] = str(trail_amount)
607+
payload["trailing_peg"] = peg
607608

608609
# Market buys get a 5% price collar
609610
if order_type == "market" and side == "buy" and trigger == "immediate":

python/src/robinhood_agents/_session.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
_DEFAULT_TIMEOUT = 16.0
1717

1818
# Trusted Robinhood origins for redirect safety.
19-
_TRUSTED_ORIGINS = frozenset({
20-
"https://api.robinhood.com",
21-
"https://nummus.robinhood.com",
22-
"https://robinhood.com",
23-
})
19+
_TRUSTED_ORIGINS = frozenset(
20+
{
21+
"https://api.robinhood.com",
22+
"https://nummus.robinhood.com",
23+
"https://robinhood.com",
24+
}
25+
)
2426

2527

2628
async def _safe_fetch(
@@ -104,7 +106,11 @@ async def _fetch_with_retry(
104106
) -> httpx.Response:
105107
"""Fetch with single-retry on 401."""
106108
resp = await _safe_fetch(
107-
self._client, method, url, headers=headers, content=content,
109+
self._client,
110+
method,
111+
url,
112+
headers=headers,
113+
content=content,
108114
timeout=timeout,
109115
)
110116

@@ -114,7 +120,11 @@ async def _fetch_with_retry(
114120
self._access_token = new_token
115121
headers = {**headers, "Authorization": f"Bearer {new_token}"}
116122
resp = await _safe_fetch(
117-
self._client, method, url, headers=headers, content=content,
123+
self._client,
124+
method,
125+
url,
126+
headers=headers,
127+
content=content,
118128
timeout=timeout,
119129
)
120130

@@ -123,7 +133,9 @@ async def _fetch_with_retry(
123133
async def get(self, url: str, params: dict[str, str] | None = None) -> httpx.Response:
124134
target = f"{url}?{urlencode(params)}" if params else url
125135
return await self._fetch_with_retry(
126-
"GET", target, headers=self._auth_headers(self._headers),
136+
"GET",
137+
target,
138+
headers=self._auth_headers(self._headers),
127139
)
128140

129141
async def post(
@@ -153,16 +165,20 @@ async def post(
153165
[(k, str(v)) for k, v in (body or {}).items()],
154166
)
155167

156-
req_timeout = (
157-
httpx.Timeout(timeout) if timeout and timeout != self._timeout else None
158-
)
168+
req_timeout = httpx.Timeout(timeout) if timeout and timeout != self._timeout else None
159169
return await self._fetch_with_retry(
160-
"POST", url, headers=headers, content=content, timeout=req_timeout,
170+
"POST",
171+
url,
172+
headers=headers,
173+
content=content,
174+
timeout=req_timeout,
161175
)
162176

163177
async def delete(self, url: str) -> httpx.Response:
164178
return await self._fetch_with_retry(
165-
"DELETE", url, headers=self._auth_headers(self._headers),
179+
"DELETE",
180+
url,
181+
headers=self._auth_headers(self._headers),
166182
)
167183

168184
async def close(self) -> None:

python/src/robinhood_agents/_token_store.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515

1616
from __future__ import annotations
1717

18+
import asyncio
1819
import base64
1920
import json
2021
import os
22+
import stat
2123
from dataclasses import dataclass
2224
from pathlib import Path
2325
from time import time
@@ -111,13 +113,18 @@ async def delete(self) -> None: ...
111113

112114

113115
class KeychainTokenStore:
114-
"""Store tokens in the OS keychain via the ``keyring`` library."""
116+
"""Store tokens in the OS keychain via the ``keyring`` library.
117+
118+
Keyring calls are synchronous and may block (especially on macOS with
119+
authorization prompts), so they are dispatched to a thread via
120+
``asyncio.to_thread`` to avoid blocking the event loop.
121+
"""
115122

116123
async def load(self) -> TokenData | None:
117124
try:
118125
import keyring
119126

120-
raw = keyring.get_password(KEYRING_SERVICE, KEYRING_TOKENS)
127+
raw = await asyncio.to_thread(keyring.get_password, KEYRING_SERVICE, KEYRING_TOKENS)
121128
if raw:
122129
data = json.loads(raw)
123130
return TokenData.from_dict(data)
@@ -128,13 +135,15 @@ async def load(self) -> TokenData | None:
128135
async def save(self, tokens: TokenData) -> None:
129136
import keyring
130137

131-
keyring.set_password(KEYRING_SERVICE, KEYRING_TOKENS, json.dumps(tokens.to_dict()))
138+
await asyncio.to_thread(
139+
keyring.set_password, KEYRING_SERVICE, KEYRING_TOKENS, json.dumps(tokens.to_dict())
140+
)
132141

133142
async def delete(self) -> None:
134143
try:
135144
import keyring
136145

137-
keyring.delete_password(KEYRING_SERVICE, KEYRING_TOKENS)
146+
await asyncio.to_thread(keyring.delete_password, KEYRING_SERVICE, KEYRING_TOKENS)
138147
except Exception:
139148
pass
140149

@@ -170,7 +179,9 @@ def _resolve_encryption_key() -> bytes:
170179
import keyring
171180

172181
keyring.set_password(
173-
KEYRING_SERVICE, KEYRING_ENCRYPTION_KEY, base64.b64encode(key).decode(),
182+
KEYRING_SERVICE,
183+
KEYRING_ENCRYPTION_KEY,
184+
base64.b64encode(key).decode(),
174185
)
175186
except Exception:
176187
pass # Key lives only in memory this session
@@ -232,6 +243,7 @@ async def save(self, tokens: TokenData) -> None:
232243
path = Path(self._file_path)
233244
path.parent.mkdir(parents=True, exist_ok=True)
234245
path.write_text(json.dumps(blob), "utf-8")
246+
path.chmod(stat.S_IRUSR | stat.S_IWUSR) # 0600 — owner only
235247

236248
async def delete(self) -> None:
237249
Path(self._file_path).unlink(missing_ok=True)

0 commit comments

Comments
 (0)