
Commit 7564985

v2.18.3
Changes:
- Check for full completion.
- Add inference reward scale for streaming variance between chunks.
- Scale inference reward by cosine similarity with ground-truth logits.
- Add tests for inference.
- Add deps required for asyncio tests.
- Close all processes before exit.
2 parents 29eaf46 + 572974e commit 7564985
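The two reward changes in the message above can be illustrated with a small sketch. The code below is not the repository's implementation: the function names, the variance-based multiplier, and the clipping to [0, 1] are assumptions, shown only to make the idea concrete (penalize erratic streaming chunk timings, and scale the inference reward by cosine similarity between miner logits and ground-truth logits).

```python
import numpy as np


def streaming_variance_scale(chunk_timings: list[float], max_penalty: float = 0.5) -> float:
    """Multiplier in [1 - max_penalty, 1] that shrinks as inter-chunk timing variance grows."""
    if len(chunk_timings) < 2:
        return 1.0
    variance = float(np.var(chunk_timings))
    return 1.0 - max_penalty * variance / (variance + 1.0)


def logit_similarity_scale(miner_logits: np.ndarray, ground_truth_logits: np.ndarray) -> float:
    """Cosine similarity between flattened logit vectors, clipped to [0, 1]."""
    a, b = miner_logits.ravel(), ground_truth_logits.ravel()
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return max(0.0, float(np.dot(a, b)) / denom)


def scaled_inference_reward(
    base_reward: float,
    chunk_timings: list[float],
    miner_logits: np.ndarray,
    ground_truth_logits: np.ndarray,
) -> float:
    # Both factors are in [0, 1], so the reward can only be scaled down.
    return base_reward * streaming_variance_scale(chunk_timings) * logit_similarity_scale(miner_logits, ground_truth_logits)
```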

6 files changed: +434 -143 lines

neurons/validator.py

+79 -28
@@ -1,10 +1,14 @@
 import asyncio
+import atexit
 import os
 import signal
 import sys
+import time
 from multiprocessing.managers import AcquirerProxy
+from multiprocessing.synchronize import Event

 import netaddr
+import psutil
 import requests
 import torch
 import torch.multiprocessing as mp
@@ -102,6 +106,7 @@ def start_api(
     scoring_queue: list,
     reward_events: list,
     miners_dict: dict,
+    event_stop: Event,
 ):
     from prompting.api.api import start_scoring_api  # noqa: F401

@@ -124,7 +129,7 @@ async def start():
             logger.warning(f"Failed to serve scoring api to chain: {e}")
         await start_scoring_api(task_scorer, scoring_queue, reward_events, miners_dict)

-        while True:
+        while not event_stop.is_set():
             await asyncio.sleep(10)

     asyncio.run(start())
@@ -134,6 +139,7 @@ def start_task_sending_loop(
     task_queue: list,
     scoring_queue: list,
     miners_dict: dict,
+    event_stop: Event,
 ):
     async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
         from prompting.tasks.task_sending import TaskSender
@@ -142,7 +148,8 @@ async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
         task_sender = TaskSender()
         asyncio.create_task(task_sender.start(task_queue, scoring_queue, miners_dict, simultaneous_loops=1))
         logger.debug("Task sending loop started")
-        while True:
+
+        while not event_stop.is_set():
             await asyncio.sleep(5)
             logger.debug("Task sending loop is running")

@@ -155,13 +162,13 @@ async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
         raise


-def start_availability_checking_loop(miners_dict: dict):
+def start_availability_checking_loop(miners_dict: dict, event_stop: Event):
     async def spawn_loops(miners_dict: dict):
         from prompting.miner_availability.miner_availability import availability_checking_loop

         logger.info("Starting availability checking loop in validator...")
         asyncio.create_task(availability_checking_loop.start(miners_dict))
-        while True:
+        while not event_stop.is_set():
             await asyncio.sleep(5)
             logger.debug("Availability checking loop is running")

@@ -174,13 +181,13 @@ async def spawn_loops(miners_dict: dict):
         raise


-def start_weight_setter_loop(reward_events):
+def start_weight_setter_loop(reward_events, event_stop: Event):
     async def spawn_loops(reward_events):
         from prompting.weight_setting.weight_setter import weight_setter

         logger.info("Starting weight setter loop in validator...")
         asyncio.create_task(weight_setter.start(reward_events))
-        while True:
+        while not event_stop.is_set():
             await asyncio.sleep(5)
             logger.debug("Weight setter loop is running")

@@ -193,6 +200,34 @@ async def spawn_loops(reward_events):
         raise


+def health_check(parent_pid: int, event_stop: Event):
+    """Monitor parent process and kill all child processes in case of emergency."""
+    step = 0
+    while True:
+        try:
+            if not psutil.pid_exists(parent_pid):
+                event_stop.set()
+                logger.warning("Parent process died, killing all child processes")
+                os.killpg(0, signal.SIGKILL)
+
+            block = settings.shared_settings.block
+            if block - settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID] > 320 and step > 60:
+                event_stop.set()
+                last_update_block = settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID]
+                logger.warning(
+                    f"Metagraph hasn't been updated for {block - last_update_block} blocks. "
+                    f"Staled block: {block}, Last update: {last_update_block}"
+                )
+                os.killpg(0, signal.SIGKILL)
+            step += 1
+
+        except Exception as e:
+            logger.error(f"Failed to kill process group: {e}")
+        finally:
+            sys.exit(1)
+        time.sleep(60)
+
+
 async def main(
     cache_rewards: list | None = None,
     cache_scores: list | None = None,
@@ -208,6 +243,7 @@ async def main(
         mp_lock = manager.Lock()
         processes: list[mp.Process] = []
         tasks: list[asyncio.Task] = []
+        event_stop = mp.Event()

         model_scheduler = AsyncModelScheduler(llm_model_manager=ModelManager(), mp_lock=mp_lock, sync=True)

@@ -216,15 +252,19 @@ async def main(
         if settings.shared_settings.DEPLOY_SCORING_API and not settings.shared_settings.NEURON_DISABLE_SET_WEIGHTS:
             # Use multiprocessing to bypass API blocking issue
             api_process = mp.Process(
-                target=start_api, args=(scoring_queue, reward_events, miners_dict), name="APIProcess"
+                target=start_api,
+                args=(scoring_queue, reward_events, miners_dict, event_stop),
+                name="APIProcess",
+                daemon=True,
             )
             api_process.start()
             processes.append(api_process)

         availability_process = mp.Process(
             target=start_availability_checking_loop,
-            args=(miners_dict,),
+            args=(miners_dict, event_stop),
             name="AvailabilityProcess",
+            daemon=True,
         )
         availability_process.start()
         processes.append(availability_process)
@@ -243,62 +283,73 @@ async def main(

         sending_task = mp.Process(
             target=start_task_sending_loop,
-            args=(task_queue, scoring_queue, miners_dict),
+            args=(task_queue, scoring_queue, miners_dict, event_stop),
             name="SendingTaskProcess",
+            daemon=True,
         )
         sending_task.start()
         processes.append(sending_task)

         weight_setter_process = mp.Process(
             target=start_weight_setter_loop,
-            args=(reward_events,),
+            args=(reward_events, event_stop),
             name="WeightSetterProcess",
+            daemon=True,
         )
         weight_setter_process.start()
         processes.append(weight_setter_process)

-        GPUInfo.log_gpu_info()
+        health_check_process = mp.Process(
+            target=health_check,
+            args=(os.getpid(), event_stop),
+            name="HealthCheckProcess",
+            daemon=True,
+        )
+        health_check_process.start()
+        processes.append(health_check_process)

-        step = 0
+        GPUInfo.log_gpu_info()
         while True:
             await asyncio.sleep(30)
-            block = settings.shared_settings.block
-            if (
-                block - settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID] > 500
-                and step > 150
-            ):
-                last_update_block = settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID]
-                logger.warning(
-                    f"Metagraph hasn't been updated for {block - last_update_block} blocks. "
-                    f"Staled block: {block}, Last update: {last_update_block}"
-                )
-                break
-            step += 1

     except KeyboardInterrupt:
+        event_stop.set()
         logger.info("KeyboardInterrupt detected. Shutting down gracefully...")
     except Exception as e:
         logger.error(f"Main loop error: {e}")
         raise
     finally:
-        logger.warning("🚨 Force‑killing entire process‑group")
+        logger.warning("🚨 Force‑killing entire process‑group")

         # 1. Cancel in‑process tasks so they stop touching the Manager.
         for t in tasks:
             t.cancel()
         await asyncio.gather(*tasks, return_exceptions=True)
+        await asyncio.sleep(5)

         # 2. Manager cleanup *first* (so its socket vanishes).
         manager.shutdown()

         # 3. Sledgehammer.
-        if os.name == "posix":
+        try:
             os.killpg(0, signal.SIGKILL)
-        else:
-            logger.error(f"Unsupported OS: {os.name}")
+        except Exception as e:
+            logger.error(f"Failed to kill process group: {e}")
         sys.exit(1)


+def kill_process_group():
+    try:
+        os.killpg(os.getpgid(0), signal.SIGKILL)
+    except Exception as e:
+        logger.error(f"Failed to kill process group: {e}")
+
+
 # The main function parses the configuration and runs the validator.
 if __name__ == "__main__":
+    try:
+        os.setpgrp()
+        atexit.register(kill_process_group)
+    except BaseException:
+        logger.warning("Failed to set process group; emergency termination may not work.")
     asyncio.run(main())
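The validator changes above all follow one shutdown pattern: each child process receives a shared stop event and runs as a daemon, a watchdog process kills the whole process group if the parent dies, and the parent registers a process-group kill at exit. Below is a simplified, self-contained sketch of that pattern (POSIX only) using the standard library's multiprocessing instead of torch.multiprocessing; worker_loop and watchdog are illustrative names, not the validator's functions.

```python
import atexit
import multiprocessing as mp
import os
import signal
import time

import psutil


def worker_loop(event_stop) -> None:
    # Cooperative shutdown: the loop exits once the parent sets the event.
    while not event_stop.is_set():
        time.sleep(5)


def watchdog(parent_pid: int, event_stop) -> None:
    # Emergency shutdown: if the parent disappears, stop everything in the group.
    while not event_stop.is_set():
        if not psutil.pid_exists(parent_pid):
            event_stop.set()
            os.killpg(0, signal.SIGKILL)
        time.sleep(60)


if __name__ == "__main__":
    os.setpgrp()  # own process group, so killpg only hits this process tree
    atexit.register(lambda: os.killpg(os.getpgid(0), signal.SIGKILL))  # sledgehammer on exit, as in the diff
    event_stop = mp.Event()
    children = [
        mp.Process(target=worker_loop, args=(event_stop,), daemon=True),
        mp.Process(target=watchdog, args=(os.getpid(), event_stop), daemon=True),
    ]
    for child in children:
        child.start()
    try:
        time.sleep(30)  # stand-in for the validator's main loop
    except KeyboardInterrupt:
        pass
    finally:
        event_stop.set()
```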

prompting/llms/model_manager.py

+31 -52
@@ -49,15 +49,12 @@ async def __aexit__(self, exc_type, exc, tb):

 class ModelManager(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    always_active_models: list[ModelConfig] = []
     total_ram: float = settings.shared_settings.LLM_MODEL_RAM
     active_models: dict[ModelConfig, ReproducibleVLLM] = {}
+    loading_tasks: dict[ModelConfig, asyncio.Future] = {}
     used_ram: float = 0.0
     lock: ClassVar[AsyncRLock] = AsyncRLock()
-
-    async def load_always_active_models(self):
-        for model_config in self.always_active_models:
-            await self.load_model(model_config=model_config)
+    # lock: ClassVar[AsyncRLock] = asyncio.Lock()

     async def load_model(self, model_config: ModelConfig, force: bool = True) -> ReproducibleVLLM:
         """Load model into GPU.
@@ -69,56 +66,40 @@ async def load_model(self, model_config: ModelConfig, force: bool = True) -> Rep
             force: If enabled, will unload all other models.
         """
         async with self.lock:
-            if model_config in self.active_models.keys():
+            # Copy active models, since they will be modified in the loop.
+            active_models = set(self.active_models.keys())
+
+            if model_config in active_models:
                 logger.debug(f"Model {model_config.llm_model_id} is already loaded.")
                 return self.active_models[model_config]

             if force:
                 logger.debug(f"Forcing model {model_config.llm_model_id} to load.")
-                for active_model in list(self.active_models.keys()):
-                    if active_model in self.always_active_models:
-                        continue
+                for active_model in active_models:
                     logger.debug(f"Unloading {active_model.llm_model_id} to make room for {model_config.llm_model_id}")
-
                     await self._unload_model(active_model)
                 await self.cleanup()

-            retries_max = 1
-            retry_counter = 0
-            retry_delay = 15
-            while True:
-                try:
-                    GPUInfo.log_gpu_info()
-                    model_class = model_factory(model_config.llm_model_id)
-                    model = model_class(
-                        model_id=model_config.llm_model_id,
-                        device=settings.shared_settings.NEURON_DEVICE,
-                        sampling_params=settings.shared_settings.SAMPLING_PARAMS,
-                    )
-                    self.used_ram += model_config.min_ram
-                    logger.info(
-                        f"Model {model_config.llm_model_id} has been successfully loaded. "
-                        f"Approx. used VRAM: {self.used_ram:.0f}GB"
-                    )
-                    self.active_models[model_config] = model
-                    await asyncio.sleep(1.0)
-                    return model
-                except BaseException as e:
-                    if retry_counter > retries_max:
-                        logger.error(f"Failed to load model after {retries_max} retries. Terminating process")
-                        await self.cleanup()
-                        # In case of VRAM leak, raise an exception to terminate the process.
-                        raise MemoryError
-
-                    retry_counter += 1
-                    retry_delay += retry_counter
-                    await self.cleanup()
-                    logger.error(
-                        f"Failed to load model {model_config.llm_model_id}. Retrying in {retry_delay} seconds. "
-                        f"Error: {str(e)}"
-                    )
-                    logger.debug(f"Current active models: {self.active_models}")
-                    await asyncio.sleep(retry_delay)
+            try:
+                GPUInfo.log_gpu_info()
+                model_class = model_factory(model_config.llm_model_id)
+                model = model_class(
+                    model_id=model_config.llm_model_id,
+                    device=settings.shared_settings.NEURON_DEVICE,
+                    sampling_params=settings.shared_settings.SAMPLING_PARAMS,
+                )
+                self.active_models[model_config] = model
+                self.used_ram += model_config.min_ram
+                logger.info(
+                    f"Model {model_config.llm_model_id} has been successfully loaded. "
+                    f"Approx. used VRAM: {self.used_ram:.0f}GB"
+                )
+                await asyncio.sleep(1.0)
+                return model
+            except BaseException as e:
+                await self.cleanup()
+                # In case of VRAM leak, raise an exception to terminate the process.
+                raise MemoryError(f"Failed to load model {model_config.llm_model_id}: {e}")

     async def _cleanup_model(self, model_instance: ReproducibleVLLM, cpu_offload: bool = False):
         """Free VRAM from given model."""
@@ -144,12 +125,10 @@ async def _unload_model(self, model_config: ModelConfig):
             return

         try:
-            model_instance = self.active_models.pop(model_config)
-
-            # Record initial memory state for debugging.
             initial_free_memory = GPUInfo.free_memory
             logger.debug(f"Initial free GPU memory before unloading: {initial_free_memory} GB")
-
+            # async with self.rlock:
+            model_instance = self.active_models.pop(model_config)
             await self._cleanup_model(model_instance, cpu_offload=False)
             await self.cleanup()

@@ -167,13 +146,13 @@ async def _unload_model(self, model_config: ModelConfig):
     async def get_model(self, llm_model: ModelConfig | str) -> ReproducibleVLLM:
         async with self.lock:
             if not llm_model:
-                llm_model = list(self.active_models.keys())[0] if self.active_models else ModelZoo.get_random()
+                llm_model = next(iter(self.active_models.keys())) if self.active_models else ModelZoo.get_random()
             if isinstance(llm_model, str):
                 llm_model = ModelZoo.get_model_by_id(llm_model)
             if llm_model in self.active_models:
                 return self.active_models[llm_model]

-            return await self.load_model(llm_model, force=True)
+            return await self.load_model(llm_model)

     async def generate(
         self,
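With the retry loop removed, a failed load now propagates immediately as a MemoryError instead of being retried, so callers of get_model()/load_model() have to treat the first failure as fatal. A hypothetical caller sketch under that assumption (the import path mirrors the file shown above; the model id string is a placeholder):

```python
import asyncio

from prompting.llms.model_manager import ModelManager  # path as shown in this diff


async def warm_up(model_id: str):
    manager = ModelManager()
    try:
        # get_model() resolves a string id via ModelZoo and loads under the async lock;
        # load_model() no longer retries, so the first failure propagates here.
        return await manager.get_model(model_id)
    except MemoryError as exc:
        # Treat a failed load as fatal, mirroring the "terminate the process" comment above.
        raise SystemExit(f"Model load failed: {exc}")


# asyncio.run(warm_up("<model-id-from-ModelZoo>"))  # placeholder id
```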

prompting/llms/vllm_llm.py

+3 -7
@@ -191,21 +191,17 @@ def unload_model(self):
             if hasattr(self.model, "llm_engine") and hasattr(self.model.llm_engine, "driver_worker"):
                 del self.model.llm_engine.driver_worker
             if hasattr(self.model, "model"):
-                self.model = None
                 del self.model
             if hasattr(self.model, "tokenizer"):
-                self.tokenizer = None
                 del self.tokenizer

             gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
             if torch.distributed.is_initialized():
                 torch.distributed.destroy_process_group()
-
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
             logger.info("Successfully deleted the LLM pipeline and freed GPU memory")
-
-        except Exception as e:
+        except BaseException as e:
             logger.error(f"An error occurred during model unloading: {e}")
             gc.collect()
             if torch.cuda.is_available():
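The reordering above runs torch.distributed teardown before emptying the CUDA cache. A generic, standalone sketch of that cleanup order (free_gpu_resources is an illustrative name, not part of this repository):

```python
import gc

import torch


def free_gpu_resources(model) -> None:
    """Drop references, collect garbage, tear down distributed state, then empty the CUDA cache."""
    del model  # drops this function's reference; the caller must release theirs too
    gc.collect()  # let Python free the tensors that are no longer referenced
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()  # close NCCL/Gloo state before releasing cached blocks
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached allocations to the driver
```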

0 commit comments