|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import io |
6 | | -from collections.abc import Callable, Sequence |
| 6 | +from collections.abc import Callable, Iterator, Sequence |
| 7 | +from contextlib import contextmanager |
7 | 8 | from pathlib import Path |
8 | 9 | from typing import Any |
9 | 10 |
|
| 11 | +from agents.sandbox.errors import SandboxError |
10 | 12 | from agents.sandbox.session.sandbox_client import BaseSandboxClient |
11 | 13 | from agents.sandbox.session.sandbox_session import SandboxSession |
12 | 14 |
|
|
34 | 36 | from temporalio.contrib.openai_agents.sandbox._temporal_activity_models import ( |
35 | 37 | ExecResult as ExecResultModel, |
36 | 38 | ) |
| 39 | +from temporalio.exceptions import ApplicationError |
| 40 | + |
| 41 | + |
| 42 | +@contextmanager |
| 43 | +def _translate_sandbox_errors() -> Iterator[None]: |
| 44 | + # Temporal retries every activity exception by default, so only a SandboxError |
| 45 | + # the library has classified as terminal (retryable is False) is turned into a |
| 46 | + # non-retryable ApplicationError. |
| 47 | + try: |
| 48 | + yield |
| 49 | + except SandboxError as e: |
| 50 | + if e.retryable is False: |
| 51 | + raise ApplicationError( |
| 52 | + str(e), type=str(e.error_code), non_retryable=True |
| 53 | + ) from e |
| 54 | + raise |
37 | 55 |
|
38 | 56 |
|
39 | 57 | class SandboxClientProvider: |
@@ -99,133 +117,147 @@ def _get_activities(self) -> Sequence[Callable[..., Any]]: |
99 | 117 |
|
100 | 118 | @activity.defn(name=f"{prefix}-sandbox_client_create") |
101 | 119 | async def create_session(args: CreateSessionArgs) -> SessionResult: |
102 | | - session = await self._client.create( |
103 | | - snapshot=args.snapshot_spec, |
104 | | - manifest=args.manifest, |
105 | | - options=args.client_options, |
106 | | - ) |
107 | | - self._sessions[str(session.state.session_id)] = session |
108 | | - return SessionResult( |
109 | | - state=session.state, supports_pty=session.supports_pty() |
110 | | - ) |
| 120 | + with _translate_sandbox_errors(): |
| 121 | + session = await self._client.create( |
| 122 | + snapshot=args.snapshot_spec, |
| 123 | + manifest=args.manifest, |
| 124 | + options=args.client_options, |
| 125 | + ) |
| 126 | + self._sessions[str(session.state.session_id)] = session |
| 127 | + return SessionResult( |
| 128 | + state=session.state, supports_pty=session.supports_pty() |
| 129 | + ) |
111 | 130 |
|
112 | 131 | @activity.defn(name=f"{prefix}-sandbox_client_resume") |
113 | 132 | async def resume_session(args: ResumeSessionArgs) -> SessionResult: |
114 | | - session = await self._client.resume(args.state) |
115 | | - self._sessions[str(session.state.session_id)] = session |
116 | | - return SessionResult( |
117 | | - state=session.state, supports_pty=session.supports_pty() |
118 | | - ) |
| 133 | + with _translate_sandbox_errors(): |
| 134 | + session = await self._client.resume(args.state) |
| 135 | + self._sessions[str(session.state.session_id)] = session |
| 136 | + return SessionResult( |
| 137 | + state=session.state, supports_pty=session.supports_pty() |
| 138 | + ) |
119 | 139 |
|
120 | 140 | @activity.defn(name=f"{prefix}-sandbox_client_delete") |
121 | 141 | async def delete_session(args: StopArgs) -> None: |
122 | | - session = await self._session(args) |
123 | | - await self._client.delete(session) |
124 | | - return None |
| 142 | + with _translate_sandbox_errors(): |
| 143 | + session = await self._session(args) |
| 144 | + await self._client.delete(session) |
| 145 | + return None |
125 | 146 |
|
126 | 147 | # -- Session-level operations (I/O and lifecycle) -- |
127 | 148 |
|
128 | 149 | @activity.defn(name=f"{prefix}-sandbox_session_exec") |
129 | 150 | async def exec_(args: ExecArgs) -> ExecResultModel: |
130 | | - session = await self._session(args) |
131 | | - result = await session.exec( |
132 | | - *args.command, |
133 | | - timeout=args.timeout, |
134 | | - shell=args.shell, |
135 | | - user=args.user, |
136 | | - ) |
137 | | - return ExecResultModel( |
138 | | - stdout=result.stdout, |
139 | | - stderr=result.stderr, |
140 | | - exit_code=result.exit_code, |
141 | | - ) |
| 151 | + with _translate_sandbox_errors(): |
| 152 | + session = await self._session(args) |
| 153 | + result = await session.exec( |
| 154 | + *args.command, |
| 155 | + timeout=args.timeout, |
| 156 | + shell=args.shell, |
| 157 | + user=args.user, |
| 158 | + ) |
| 159 | + return ExecResultModel( |
| 160 | + stdout=result.stdout, |
| 161 | + stderr=result.stderr, |
| 162 | + exit_code=result.exit_code, |
| 163 | + ) |
142 | 164 |
|
143 | 165 | @activity.defn(name=f"{prefix}-sandbox_session_read") |
144 | 166 | async def read(args: ReadArgs) -> ReadResult: |
145 | | - session = await self._session(args) |
146 | | - handle = await session.read(Path(args.path)) |
147 | | - return ReadResult(data=handle.read()) |
| 167 | + with _translate_sandbox_errors(): |
| 168 | + session = await self._session(args) |
| 169 | + handle = await session.read(Path(args.path)) |
| 170 | + return ReadResult(data=handle.read()) |
148 | 171 |
|
149 | 172 | @activity.defn(name=f"{prefix}-sandbox_session_write") |
150 | 173 | async def write(args: WriteArgs) -> None: |
151 | | - session = await self._session(args) |
152 | | - await session.write(Path(args.path), io.BytesIO(args.data)) |
153 | | - return None |
| 174 | + with _translate_sandbox_errors(): |
| 175 | + session = await self._session(args) |
| 176 | + await session.write(Path(args.path), io.BytesIO(args.data)) |
| 177 | + return None |
154 | 178 |
|
155 | 179 | @activity.defn(name=f"{prefix}-sandbox_session_running") |
156 | 180 | async def running(args: RunningArgs) -> RunningResult: |
157 | | - session = await self._session(args) |
158 | | - return RunningResult(is_running=await session.running()) |
| 181 | + with _translate_sandbox_errors(): |
| 182 | + session = await self._session(args) |
| 183 | + return RunningResult(is_running=await session.running()) |
159 | 184 |
|
160 | 185 | @activity.defn(name=f"{prefix}-sandbox_session_persist_workspace") |
161 | 186 | async def persist_workspace( |
162 | 187 | args: PersistWorkspaceArgs, |
163 | 188 | ) -> PersistWorkspaceResult: |
164 | | - session = await self._session(args) |
165 | | - stream = await session.persist_workspace() |
166 | | - return PersistWorkspaceResult(data=stream.read()) |
| 189 | + with _translate_sandbox_errors(): |
| 190 | + session = await self._session(args) |
| 191 | + stream = await session.persist_workspace() |
| 192 | + return PersistWorkspaceResult(data=stream.read()) |
167 | 193 |
|
168 | 194 | @activity.defn(name=f"{prefix}-sandbox_session_hydrate_workspace") |
169 | 195 | async def hydrate_workspace(args: HydrateWorkspaceArgs) -> None: |
170 | | - session = await self._session(args) |
171 | | - await session.hydrate_workspace(io.BytesIO(args.data)) |
172 | | - return None |
| 196 | + with _translate_sandbox_errors(): |
| 197 | + session = await self._session(args) |
| 198 | + await session.hydrate_workspace(io.BytesIO(args.data)) |
| 199 | + return None |
173 | 200 |
|
174 | 201 | @activity.defn(name=f"{prefix}-sandbox_session_pty_exec_start") |
175 | 202 | async def pty_exec_start(args: PtyExecStartArgs) -> PtyExecUpdateResult: |
176 | | - session = await self._session(args) |
177 | | - update = await session.pty_exec_start( |
178 | | - *args.command, |
179 | | - timeout=args.timeout, |
180 | | - shell=args.shell, |
181 | | - user=args.user, |
182 | | - tty=args.tty, |
183 | | - yield_time_s=args.yield_time_s, |
184 | | - max_output_tokens=args.max_output_tokens, |
185 | | - ) |
186 | | - return PtyExecUpdateResult( |
187 | | - process_id=update.process_id, |
188 | | - output=update.output, |
189 | | - exit_code=update.exit_code, |
190 | | - original_token_count=update.original_token_count, |
191 | | - ) |
| 203 | + with _translate_sandbox_errors(): |
| 204 | + session = await self._session(args) |
| 205 | + update = await session.pty_exec_start( |
| 206 | + *args.command, |
| 207 | + timeout=args.timeout, |
| 208 | + shell=args.shell, |
| 209 | + user=args.user, |
| 210 | + tty=args.tty, |
| 211 | + yield_time_s=args.yield_time_s, |
| 212 | + max_output_tokens=args.max_output_tokens, |
| 213 | + ) |
| 214 | + return PtyExecUpdateResult( |
| 215 | + process_id=update.process_id, |
| 216 | + output=update.output, |
| 217 | + exit_code=update.exit_code, |
| 218 | + original_token_count=update.original_token_count, |
| 219 | + ) |
192 | 220 |
|
193 | 221 | @activity.defn(name=f"{prefix}-sandbox_session_pty_write_stdin") |
194 | 222 | async def pty_write_stdin(args: PtyWriteStdinArgs) -> PtyExecUpdateResult: |
195 | | - session = await self._session(args) |
196 | | - update = await session.pty_write_stdin( |
197 | | - session_id=args.session_id, |
198 | | - chars=args.chars, |
199 | | - yield_time_s=args.yield_time_s, |
200 | | - max_output_tokens=args.max_output_tokens, |
201 | | - ) |
202 | | - return PtyExecUpdateResult( |
203 | | - process_id=update.process_id, |
204 | | - output=update.output, |
205 | | - exit_code=update.exit_code, |
206 | | - original_token_count=update.original_token_count, |
207 | | - ) |
| 223 | + with _translate_sandbox_errors(): |
| 224 | + session = await self._session(args) |
| 225 | + update = await session.pty_write_stdin( |
| 226 | + session_id=args.session_id, |
| 227 | + chars=args.chars, |
| 228 | + yield_time_s=args.yield_time_s, |
| 229 | + max_output_tokens=args.max_output_tokens, |
| 230 | + ) |
| 231 | + return PtyExecUpdateResult( |
| 232 | + process_id=update.process_id, |
| 233 | + output=update.output, |
| 234 | + exit_code=update.exit_code, |
| 235 | + original_token_count=update.original_token_count, |
| 236 | + ) |
208 | 237 |
|
209 | 238 | @activity.defn(name=f"{prefix}-sandbox_session_start") |
210 | 239 | async def start(args: StartArgs) -> None: |
211 | | - session = await self._session(args) |
212 | | - await session.start() |
213 | | - return None |
| 240 | + with _translate_sandbox_errors(): |
| 241 | + session = await self._session(args) |
| 242 | + await session.start() |
| 243 | + return None |
214 | 244 |
|
215 | 245 | @activity.defn(name=f"{prefix}-sandbox_session_stop") |
216 | 246 | async def session_stop(args: StopArgs) -> None: |
217 | | - session = await self._session(args) |
218 | | - await session.stop() |
219 | | - return None |
| 247 | + with _translate_sandbox_errors(): |
| 248 | + session = await self._session(args) |
| 249 | + await session.stop() |
| 250 | + return None |
220 | 251 |
|
221 | 252 | @activity.defn(name=f"{prefix}-sandbox_session_shutdown") |
222 | 253 | async def session_shutdown(args: StopArgs) -> None: |
223 | | - key = str(args.state.session_id) |
224 | | - session = self._sessions.get(key) |
225 | | - if session is not None: |
226 | | - await session.shutdown() |
227 | | - del self._sessions[key] |
228 | | - return None |
| 254 | + with _translate_sandbox_errors(): |
| 255 | + key = str(args.state.session_id) |
| 256 | + session = self._sessions.get(key) |
| 257 | + if session is not None: |
| 258 | + await session.shutdown() |
| 259 | + del self._sessions[key] |
| 260 | + return None |
229 | 261 |
|
230 | 262 | return [ |
231 | 263 | create_session, |
|
0 commit comments