Skip to content

Commit f9d9f53

Browse files
committed
Catch MemoryError in model manager
1 parent 5fd6a24 commit f9d9f53

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

neurons/validator.py

+2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ async def spawn_loops(task_queue: list, scoring_queue: list, reward_events: list
8080
logger.debug(
8181
f"Task Queue {len(task_queue)}. Scoring Queue {len(scoring_queue)}. Reward Events {len(reward_events)}"
8282
)
83+
if model_scheduler.memory_error is not None:
84+
raise model_scheduler.memory_error
8385
except asyncio.CancelledError:
8486
logger.info("spawn_loops received cancellation signal.")
8587
raise

prompting/llms/model_manager.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ class AsyncModelScheduler(AsyncLoopRunner):
255255
mp_lock: AcquirerProxy
256256
interval: int = 1200
257257
scoring_queue: list | None = None
258+
memory_error: MemoryError | None = None
258259

259260
model_config = ConfigDict(arbitrary_types_allowed=True)
260261

@@ -277,6 +278,9 @@ async def run_step(self):
277278
logger.info(f"Model {selected_model.llm_model_id} is already loaded.")
278279
return
279280

280-
await self.llm_model_manager.load_model(selected_model)
281+
try:
282+
await self.llm_model_manager.load_model(selected_model)
283+
except MemoryError as e:
284+
self.memory_error = e
281285
logger.debug(f"Active models: {self.llm_model_manager.active_models.keys()}")
282286
await asyncio.sleep(0.01)

prompting/tasks/task_sending.py

-5
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ class TaskSender(AsyncLoopRunner):
8282

8383
task_queue: list | None = None
8484
scoring_queue: list | None = None
85-
# subtensor: bt.Subtensor | None = None
8685
miners_dict: dict | None = None
8786

8887
class Config:
@@ -92,9 +91,6 @@ async def start(self, task_queue, scoring_queue, miners_dict, **kwargs):
9291
self.task_queue = task_queue
9392
self.scoring_queue = scoring_queue
9493
self.miners_dict = miners_dict
95-
96-
# # shared_settings is not initialised inside this process, meaning it cannot access any non-constants from here
97-
# self.subtensor = bt.subtensor(network=shared_settings.SUBTENSOR_NETWORK)
9894
return await super().start(**kwargs)
9995

10096
@property
@@ -133,7 +129,6 @@ async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
133129
task_id=task.task_id,
134130
)
135131
self.scoring_queue.append(scoring_config)
136-
# logger.debug(f"Scoring queue length: {len(self.scoring_queue)}")
137132

138133
# Log the step event.
139134
return ValidatorLoggingEvent(

shared/epistula.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,16 @@ async def query_miners(
149149
responses_valid = 0
150150
responses_error = 0
151151
responses_exception = 0
152-
exception_info: Exception | None = None
153152
results: list[SynapseStreamResult] = []
154153
for response, uid in zip(responses, uids):
155154
if isinstance(response, Exception):
156155
responses_exception += 1
157-
exception_info = response
158156
results.append(SynapseStreamResult(exception=str(response)))
159157
elif isinstance(response, tuple) and isinstance(response[0], ChatCompletion):
160158
if response and response[1]:
161159
responses_valid += 1
160+
else:
161+
responses_error += 1
162162
results.append(
163163
SynapseStreamResult(
164164
uid=uid,
@@ -174,8 +174,8 @@ async def query_miners(
174174

175175
logger.info(
176176
f"Responses success: {responses_valid}/{len(uids)}. "
177-
f"Responses exception: {responses_exception}/{len(uids)} [{exception_info}]. "
178-
f"Reponses error: {responses_error}/{len(uids)}"
177+
f"Responses exception: {responses_exception}/{len(uids)}. "
178+
f"Reponses invalid: {responses_error}/{len(uids)}"
179179
)
180180
return results
181181
except Exception as e:

0 commit comments

Comments (0)