Skip to content

Commit 1abb78a

Browse files
committed
feat(network): add timeouts for subtensor operations to prevent indefinite hangs and improve error handling
1 parent 2d92b16 commit 1abb78a

File tree

1 file changed

+73
-41
lines changed

1 file changed

+73
-41
lines changed

grail/infrastructure/network.py

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212

1313
logger = logging.getLogger(__name__)
1414

15+
# Timeouts (in seconds) for subtensor lifecycle operations to prevent indefinite hangs.
# Override with the BT_CLOSE_TIMEOUT / BT_INIT_TIMEOUT environment variables.
# `or` (rather than the getenv default) also covers a variable that is set but
# empty, which would otherwise make float("") raise ValueError at import time.
CLOSE_TIMEOUT = float(os.getenv("BT_CLOSE_TIMEOUT") or "60.0")
INIT_TIMEOUT = float(os.getenv("BT_INIT_TIMEOUT") or "120.0")
18+
1519

1620
class ResilientSubtensor:
1721
"""
@@ -104,30 +108,40 @@ def _should_open_circuit(self) -> bool:
104108
circuit_threshold = object.__getattribute__(self, "_circuit_threshold")
105109
return failure_count >= circuit_threshold
106110

111+
async def _close_subtensor_safe(self, subtensor: Any, timeout: float = CLOSE_TIMEOUT) -> None:
    """Close ``subtensor`` best-effort, logging but not raising on failure.

    Args:
        subtensor: Object to close. Objects without a ``close`` attribute
            are ignored.
        timeout: Maximum number of seconds to wait when ``close()`` returns
            an awaitable.

    Raises:
        asyncio.CancelledError: Re-raised so task cancellation is never
            swallowed by the cleanup path.

    A timeout or any other ``Exception`` is logged at WARNING level and
    suppressed.
    """
    import inspect  # local import keeps this method self-contained

    if not hasattr(subtensor, "close"):
        return
    try:
        result = subtensor.close()
        # Inspect the *result* rather than the callable: a wrapped or
        # decorated ``close`` may return an awaitable without being a
        # coroutine function, and that awaitable must still be awaited.
        if inspect.isawaitable(result):
            await asyncio.wait_for(result, timeout=timeout)
        # NOTE: a purely synchronous close() has already run to completion
        # here and is therefore not bounded by ``timeout``.
    except asyncio.TimeoutError:
        logger.warning("⏱️ Timeout closing subtensor after %.0fs", timeout)
    except asyncio.CancelledError:
        raise
    except Exception as exc:
        logger.warning("Failed to close subtensor: %s", exc)
126+
107127
async def _restart_subtensor(self) -> None:
    """Restart subtensor connection with timeouts to prevent indefinite hangs.

    Closes the current subtensor, builds a new ``bt.async_subtensor`` for the
    same network (falling back to the ``BT_NETWORK`` environment variable,
    then ``"finney"``), and initializes it within ``INIT_TIMEOUT`` seconds.
    On success the new client is swapped in, ``_last_call_timestamp`` is
    refreshed and ``_reset_circuit_breaker()`` is called.

    Raises:
        asyncio.TimeoutError: If initialization exceeds ``INIT_TIMEOUT``.
        asyncio.CancelledError: If the surrounding task is cancelled.
        Exception: Any other error raised by ``initialize()``; it is
            re-raised after the half-initialized client has been closed.
    """
    logger.warning("🔄 Restarting subtensor connection...")
    subtensor = object.__getattribute__(self, "_subtensor")
    network = getattr(subtensor, "network", None) or os.getenv("BT_NETWORK", "finney")

    # Close the old client first to prevent resource leaks.
    # NOTE(review): if initialization below fails, ``_subtensor`` still
    # refers to this already-closed client until a later restart succeeds.
    await self._close_subtensor_safe(subtensor)

    new_subtensor = bt.async_subtensor(network=network)
    try:
        await asyncio.wait_for(new_subtensor.initialize(), timeout=INIT_TIMEOUT)
    except (Exception, asyncio.CancelledError):
        # Clean up on *any* failure, not only timeout/cancellation, so that a
        # connection error during initialize() does not leak the new client.
        logger.error("❌ Subtensor init failed, cleaning up")
        await self._close_subtensor_safe(new_subtensor, timeout=5.0)
        raise

    object.__setattr__(self, "_subtensor", new_subtensor)
    # A fresh connection counts as activity, so reset the idle clock.
    object.__setattr__(self, "_last_call_timestamp", time.time())
    self._reset_circuit_breaker()
    logger.info("✅ Subtensor connection restarted")
133147

@@ -136,13 +150,8 @@ async def restart(self) -> None:
136150
await self._restart_subtensor()
137151

138152
async def close(self) -> None:
    """Close the underlying subtensor connection with timeout."""
    # Read the attribute via object.__getattribute__, as the rest of the
    # class does, instead of through normal attribute access on ``self``.
    current = object.__getattribute__(self, "_subtensor")
    await self._close_subtensor_safe(current)
146155

147156
async def _handle_circuit_open(self, method_name: str, args: tuple) -> Any:
148157
"""Handle method call when circuit breaker is open."""
@@ -199,7 +208,8 @@ def _handle_all_retries_failed(self, method_name: str, args: tuple, retries: int
199208

200209
if self._should_open_circuit():
201210
self._open_circuit_breaker()
202-
asyncio.create_task(self._restart_subtensor())
211+
task = asyncio.create_task(self._restart_subtensor(), name="subtensor_restart")
212+
task.add_done_callback(self._on_restart_complete)
203213

204214
# Try to return cached metagraph as last resort
205215
if method_name == "metagraph" and args:
@@ -211,6 +221,14 @@ def _handle_all_retries_failed(self, method_name: str, args: tuple, retries: int
211221
logger.error("❌ %s failed after %d attempts", method_name, retries)
212222
raise TimeoutError(f"{method_name} failed after {retries} attempts")
213223

224+
def _on_restart_complete(self, task: asyncio.Task) -> None:
    """Done-callback for the background restart task: log its failure, if any.

    Args:
        task: The restart task this callback was attached to.

    A cancelled or not-yet-finished task is ignored silently.
    """
    try:
        failure = task.exception()
    except (asyncio.CancelledError, asyncio.InvalidStateError):
        # exception() raises instead of returning for cancelled / unfinished tasks.
        return
    if failure is not None:
        logger.error("⚠️ Background subtensor restart failed: %s", failure)
231+
214232
async def _call_with_retry(
215233
self, method_name: str, method: Any, args: tuple, kwargs: dict
216234
) -> Any:
@@ -227,21 +245,20 @@ async def _call_with_retry(
227245
if method_name == "metagraph":
228246
timeout = timeout * 2
229247

230-
# Double timeout if connection has been idle for 20+ seconds
231-
# Research-based threshold:
232-
# - Bittensor WebSocket auto-closes after 10s inactivity
233-
# - Substrate layer closes after ~60s inactivity
234-
# - 20s catches stale connections early without false positives
235-
# - Critical for upload worker (40-300s idle during R2 uploads)
248+
# Restart if connection idle for 60s+ (WebSocket likely stale)
236249
last_call_timestamp = object.__getattribute__(self, "_last_call_timestamp")
237250
idle_duration = time.time() - last_call_timestamp
238251
if idle_duration > 60.0:
239252
logger.warning(
240-
"⏰ Connection idle for %.1fs, restarting subtensor and doubling timeout for %s",
241-
idle_duration,
242-
method_name,
253+
"⏰ Connection idle for %.1fs, restarting for %s", idle_duration, method_name
243254
)
244-
await self._restart_subtensor()
255+
try:
256+
await self._restart_subtensor()
257+
except asyncio.CancelledError:
258+
raise
259+
except Exception as exc:
260+
logger.warning("⚠️ Restart failed (%s), doubling timeout", exc)
261+
timeout *= 2
245262

246263
# Retry loop
247264
for retry in range(retries):
@@ -380,12 +397,27 @@ async def create_subtensor(*, resilient: bool = True) -> bt.subtensor | Resilien
380397
logger.info("Connecting to Bittensor %s (network=%s)", label, network)
381398
subtensor = bt.async_subtensor(network=network)
382399

383-
await await_with_stall_log(
384-
subtensor.initialize(),
385-
label="subtensor.initialize",
386-
threshold_seconds=120.0,
387-
log=logger,
388-
)
400+
try:
401+
await asyncio.wait_for(
402+
await_with_stall_log(
403+
subtensor.initialize(),
404+
label="subtensor.initialize",
405+
threshold_seconds=60.0,
406+
log=logger,
407+
),
408+
timeout=INIT_TIMEOUT,
409+
)
410+
except asyncio.TimeoutError:
411+
logger.error("❌ Subtensor init timed out after %.0fs", INIT_TIMEOUT)
412+
if hasattr(subtensor, "close"):
413+
try:
414+
if asyncio.iscoroutinefunction(subtensor.close):
415+
await asyncio.wait_for(subtensor.close(), timeout=5.0)
416+
else:
417+
subtensor.close()
418+
except Exception:
419+
pass
420+
raise
389421

390422
if resilient:
391423
# Wrap with resilience layer for production use

0 commit comments

Comments
 (0)