Skip to content

Commit 6811ab7

Browse files
committed
propagate agent_id where needed
1 parent 2c8b708 commit 6811ab7

8 files changed

Lines changed: 70 additions & 18 deletions

File tree

surogates/api/routes/sessions.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,12 @@ def _get_session_store(request: Request) -> SessionStore:
8080

8181

8282
async def _get_session_for_tenant(
83-
store: SessionStore, session_id: UUID, tenant: TenantContext
83+
request: Request,
84+
session_id: UUID,
85+
tenant: TenantContext,
8486
) -> Session:
85-
"""Fetch a session and verify it belongs to the tenant's org."""
87+
"""Fetch a session and verify it belongs to the tenant's org and this agent."""
88+
store = _get_session_store(request)
8689
try:
8790
session = await store.get_session(session_id)
8891
except SessionNotFoundError:
@@ -91,7 +94,8 @@ async def _get_session_for_tenant(
9194
detail=f"Session {session_id} not found.",
9295
)
9396

94-
if session.org_id != tenant.org_id:
97+
agent_id = request.app.state.settings.agent_id
98+
if session.org_id != tenant.org_id or session.agent_id != agent_id:
9599
raise HTTPException(
96100
status_code=status.HTTP_404_NOT_FOUND,
97101
detail=f"Session {session_id} not found.",
@@ -170,7 +174,7 @@ async def send_message(
170174
) -> SendMessageResponse:
171175
"""Send a user message to a session, triggering agent processing."""
172176
store = _get_session_store(request)
173-
session = await _get_session_for_tenant(store, session_id, tenant)
177+
session = await _get_session_for_tenant(request, session_id, tenant)
174178

175179
if session.status not in ("active", "idle", "failed", "paused"):
176180
raise HTTPException(
@@ -226,8 +230,7 @@ async def confirm_disclosure(
226230
enforcement is enabled. Typically called by the frontend after
227231
showing the AI disclosure notice to the user.
228232
"""
229-
store = _get_session_store(request)
230-
await _get_session_for_tenant(store, session_id, tenant)
233+
await _get_session_for_tenant(request, session_id, tenant)
231234

232235
governance = getattr(request.app.state, "governance_gate", None)
233236
if governance is not None:
@@ -241,8 +244,7 @@ async def get_session(
241244
tenant: TenantContext = Depends(get_current_tenant),
242245
) -> Session:
243246
"""Retrieve metadata for a single session."""
244-
store = _get_session_store(request)
245-
return await _get_session_for_tenant(store, session_id, tenant)
247+
return await _get_session_for_tenant(request, session_id, tenant)
246248

247249

248250
@router.get("/sessions", response_model=ListSessionsResponse)
@@ -252,8 +254,9 @@ async def list_sessions(
252254
limit: int = 50,
253255
offset: int = 0,
254256
) -> ListSessionsResponse:
255-
"""List the authenticated user's sessions (paginated, newest first)."""
257+
"""List the authenticated user's sessions for this agent, newest first."""
256258
store = _get_session_store(request)
259+
settings = request.app.state.settings
257260

258261
if limit < 1:
259262
limit = 1
@@ -265,6 +268,7 @@ async def list_sessions(
265268
sessions = await store.list_sessions(
266269
org_id=tenant.org_id,
267270
user_id=tenant.user_id,
271+
agent_id=settings.agent_id,
268272
limit=limit,
269273
offset=offset,
270274
)
@@ -285,7 +289,7 @@ async def pause_session(
285289
) -> Session:
286290
"""Pause an active session."""
287291
store = _get_session_store(request)
288-
session = await _get_session_for_tenant(store, session_id, tenant)
292+
session = await _get_session_for_tenant(request, session_id, tenant)
289293

290294
if session.status not in ("active", "processing", "paused"):
291295
raise HTTPException(
@@ -319,7 +323,7 @@ async def resume_session(
319323
) -> Session:
320324
"""Resume a paused session."""
321325
store = _get_session_store(request)
322-
session = await _get_session_for_tenant(store, session_id, tenant)
326+
session = await _get_session_for_tenant(request, session_id, tenant)
323327

324328
if session.status != "paused":
325329
raise HTTPException(
@@ -348,7 +352,7 @@ async def delete_session(
348352
) -> None:
349353
"""Archive (soft-delete) a session and delete its workspace storage."""
350354
store = _get_session_store(request)
351-
await _get_session_for_tenant(store, session_id, tenant)
355+
await _get_session_for_tenant(request, session_id, tenant)
352356

353357
await store.update_session_status(session_id, "archived")
354358

surogates/channels/identity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ async def get_or_create_channel_session(
113113
result = await db.execute(
114114
select(SessionRow)
115115
.where(SessionRow.user_id == user_id)
116+
.where(SessionRow.agent_id == agent_id)
116117
.where(SessionRow.channel == channel)
117118
.where(SessionRow.status.in_(["active", "processing", "paused"]))
118119
.where(

surogates/harness/tool_exec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ async def execute_single_tool(
569569
tool_name,
570570
tool_args,
571571
session_id=str(session.id),
572+
agent_id=session.agent_id,
572573
tenant=tenant,
573574
session_store=store,
574575
redis=redis,

surogates/jobs/reset_idle_sessions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ async def reset_idle_sessions(dry_run: bool = False) -> int:
517517
daily_at_hour = reset_cfg.at_hour if reset_cfg.mode in ("daily", "both") else None
518518
idle_sessions = await session_store.find_idle_sessions(
519519
idle_minutes=reset_cfg.idle_minutes,
520+
agent_id=settings.agent_id,
520521
daily_at_hour=daily_at_hour,
521522
mode=reset_cfg.mode,
522523
)

surogates/orchestrator/worker.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,17 @@ async def run_worker(settings: Settings) -> None:
167167
)
168168
configured_org_id = UUID(settings.org_id)
169169

170+
# Resolve the agent_id from config (required). Sessions belong to an
171+
# agent; a worker refuses to process sessions that belong to a
172+
# different agent.
173+
if not settings.agent_id:
174+
raise RuntimeError(
175+
"SUROGATES_AGENT_ID is not set. Each worker instance serves "
176+
"exactly one agent. Set agent_id in config.yaml or "
177+
"SUROGATES_AGENT_ID env var."
178+
)
179+
configured_agent_id = settings.agent_id
180+
170181
# 7. Harness factory -- creates a fully-wired AgentHarness for a given session.
171182
async def harness_factory(session_id: UUID) -> AgentHarness:
172183
"""Build an AgentHarness with all dependencies injected.
@@ -176,6 +187,15 @@ async def harness_factory(session_id: UUID) -> AgentHarness:
176187
# Load session to get user_id.
177188
session = await session_store.get_session(session_id)
178189

190+
# Refuse to process sessions that belong to a different agent —
191+
# defence-in-depth in case a foreign session id leaks into this
192+
# worker's queue.
193+
if session.agent_id != configured_agent_id:
194+
raise RuntimeError(
195+
f"session {session_id} belongs to agent {session.agent_id!r}, "
196+
f"this worker serves agent {configured_agent_id!r}"
197+
)
198+
179199
# Load org + user from DB.
180200
from sqlalchemy import select as sa_select
181201
from surogates.db.models import Org, User

surogates/session/store.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,19 @@ async def list_sessions(
136136
self,
137137
org_id: UUID,
138138
user_id: UUID,
139+
agent_id: str,
139140
*,
140141
limit: int = 50,
141142
offset: int = 0,
142143
) -> list[Session]:
143-
"""Return sessions for a user within an org, newest first."""
144+
"""Return sessions for a user within an org, scoped to one agent, newest first."""
144145
async with self._sf() as db:
145146
result = await db.execute(
146147
select(SessionRow)
147148
.where(
148149
SessionRow.org_id == org_id,
149150
SessionRow.user_id == user_id,
151+
SessionRow.agent_id == agent_id,
150152
SessionRow.status != "archived",
151153
)
152154
.order_by(SessionRow.created_at.desc())
@@ -508,6 +510,7 @@ async def get_pending_events(self, session_id: UUID) -> list[Event]:
508510
async def find_idle_sessions(
509511
self,
510512
idle_minutes: int,
513+
agent_id: str,
511514
*,
512515
daily_at_hour: int | None = None,
513516
mode: str = "idle",
@@ -557,13 +560,18 @@ async def find_idle_sessions(
557560
LEFT JOIN session_leases l
558561
ON l.session_id = s.id AND l.expires_at > now()
559562
WHERE s.status IN ('active', 'idle')
563+
AND s.agent_id = :agent_id
560564
AND l.session_id IS NULL
561565
AND {where_clause}
562566
ORDER BY s.updated_at ASC
563567
LIMIT :lim
564568
"""
565569

566-
params: dict[str, Any] = {"idle_minutes": idle_minutes, "lim": limit}
570+
params: dict[str, Any] = {
571+
"idle_minutes": idle_minutes,
572+
"agent_id": agent_id,
573+
"lim": limit,
574+
}
567575
if daily_at_hour is not None:
568576
params["at_hour"] = daily_at_hour
569577

surogates/tools/builtin/session_search.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ async def _list_recent_sessions(
266266
session_store: Any,
267267
org_id: UUID,
268268
user_id: UUID,
269+
agent_id: str,
269270
limit: int,
270271
current_session_id: UUID | None = None,
271272
) -> str:
@@ -274,6 +275,7 @@ async def _list_recent_sessions(
274275
sessions = await session_store.list_sessions(
275276
org_id=org_id,
276277
user_id=user_id,
278+
agent_id=agent_id,
277279
limit=limit + 5, # fetch extra to skip current
278280
)
279281

@@ -355,6 +357,7 @@ async def session_search(
355357
session_store: Any = None,
356358
org_id: UUID | None = None,
357359
user_id: UUID | None = None,
360+
agent_id: str = "",
358361
current_session_id: UUID | None = None,
359362
auxiliary_fn: Any | None = None,
360363
) -> str:
@@ -386,13 +389,21 @@ async def session_search(
386389
"error": "Tenant context (org_id/user_id) not available.",
387390
})
388391

392+
# Search stays within the current session's agent — cross-agent history is
393+
# not addressable.
394+
if not agent_id:
395+
return json.dumps({
396+
"success": False,
397+
"error": "agent_id is required to scope search to this agent.",
398+
})
399+
389400
limit = min(limit, 5) # Cap at 5 sessions to avoid excessive LLM calls
390401

391402
# Recent sessions mode: when query is empty, return metadata for recent sessions.
392403
# No LLM calls -- just DB queries for titles, previews, timestamps.
393404
if not query or not query.strip():
394405
return await _list_recent_sessions(
395-
session_store, org_id, user_id, limit, current_session_id,
406+
session_store, org_id, user_id, agent_id, limit, current_session_id,
396407
)
397408

398409
query = query.strip()
@@ -447,6 +458,7 @@ async def session_search(
447458
JOIN sessions s ON s.id = e.session_id
448459
WHERE s.org_id = :org_id
449460
AND s.user_id = :user_id
461+
AND s.agent_id = :agent_id
450462
AND s.status != 'archived'
451463
AND to_tsvector('english', COALESCE(e.data->>'content', '') || ' ' || COALESCE(e.data->>'result', ''))
452464
@@ plainto_tsquery('english', :query)
@@ -458,6 +470,7 @@ async def session_search(
458470
"query": query,
459471
"org_id": org_id,
460472
"user_id": user_id,
473+
"agent_id": agent_id,
461474
"limit": 50, # Get more matches to find unique sessions
462475
}
463476

@@ -740,6 +753,7 @@ async def _session_search_handler(
740753
tenant = kwargs.get("tenant", {})
741754
org_id = tenant.get("org_id") if isinstance(tenant, dict) else getattr(tenant, "org_id", None)
742755
user_id = tenant.get("user_id") if isinstance(tenant, dict) else getattr(tenant, "user_id", None)
756+
agent_id = kwargs.get("agent_id", "")
743757
current_session_id = kwargs.get("session_id")
744758
auxiliary_fn = kwargs.get("auxiliary_fn")
745759

@@ -758,6 +772,7 @@ async def _session_search_handler(
758772
session_store=store,
759773
org_id=org_id,
760774
user_id=user_id,
775+
agent_id=agent_id,
761776
current_session_id=current_session_id,
762777
auxiliary_fn=auxiliary_fn,
763778
)

tests/integration/test_session_store.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,19 @@ async def test_list_sessions(session_store, session_factory):
9797
ids.append(s.id)
9898

9999
# Fetch all
100-
sessions = await session_store.list_sessions(org_id, user_id)
100+
sessions = await session_store.list_sessions(org_id, user_id, "test-agent")
101101
returned_ids = {s.id for s in sessions}
102102
assert all(sid in returned_ids for sid in ids)
103103

104104
# Pagination: limit=2
105-
page1 = await session_store.list_sessions(org_id, user_id, limit=2)
105+
page1 = await session_store.list_sessions(
106+
org_id, user_id, "test-agent", limit=2
107+
)
106108
assert len(page1) == 2
107109

108110
# Pagination: offset=2
109111
page2 = await session_store.list_sessions(
110-
org_id, user_id, limit=10, offset=2
112+
org_id, user_id, "test-agent", limit=10, offset=2
111113
)
112114
assert len(page2) >= 1
113115

0 commit comments

Comments
 (0)