Commit 58e664a

Kill processes properly; Minimize subtensor calls

1 parent 605ce3a commit 58e664a

6 files changed: +63 -43 lines changed

neurons/validator.py

+26-12
@@ -1,4 +1,6 @@
 import asyncio
+import os
+import signal
 import sys
 from multiprocessing.managers import AcquirerProxy
 
@@ -36,7 +38,7 @@ async def create_loop_process(
     reward_events: list,
     miners_dict: dict,
     mp_lock: AcquirerProxy,
-) -> None:
+):
     # Load settings and initialize external services.
     settings.shared_settings = settings.SharedSettings.load(mode="validator")
     if settings.shared_settings.WANDB_ON:
@@ -45,8 +47,10 @@ async def create_loop_process(
     # A list to keep references to all the tasks we spawn, so they can be cancelled later.
     all_tasks: list[asyncio.Task] = []
 
-    async def cleanup():
+    async def cleanup(model_scheduler):
         logger.info("Cleaning up resources...")
+        torch.distributed.destroy_process_group()
+        await model_scheduler.llm_model.cleanup()
         for t in all_tasks:
             t.cancel()
         await asyncio.gather(*all_tasks, return_exceptions=True)
@@ -88,12 +92,12 @@ async def spawn_loops(task_queue: list, scoring_queue: list, reward_events: list
         await spawn_loops(task_queue, scoring_queue, reward_events, miners_dict)
     except MemoryError as e:
         logger.error(f"MemoryError encountered. Terminating program: {e}")
-        await cleanup()
+        await cleanup(model_scheduler)
         sys.exit(1)
     except Exception as e:
         logger.exception(f"Terminating loop process: {e}")
     finally:
-        await cleanup()
+        await cleanup(model_scheduler)
 
 
 def start_api(
@@ -260,10 +264,10 @@ async def main(
         step = 0
         while True:
             await asyncio.sleep(30)
-            block = settings.shared_settings.SUBTENSOR.get_current_block()
+            block = settings.shared_settings.block
             if (
                 block - settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID] > 500
-                and step > 120
+                and step > 150
             ):
                 last_update_block = settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID]
                 logger.warning(
@@ -279,17 +283,27 @@ async def main(
         logger.error(f"Main loop error: {e}")
         raise
     finally:
+        logger.warning("🚨 Force-killing entire process-group")
+
+        # 1. Cancel in-process tasks so they stop touching the Manager.
         for t in tasks:
             t.cancel()
-        await asyncio.gather(*tasks)
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+        # 2. Manager cleanup *first* (so its socket vanishes).
+        manager.shutdown()
 
-        for process in processes:
-            if process.is_alive():
-                process.terminate()
-                process.join()
-        sys.exit(1)
+        # 3. Sledgehammer.
+        if os.name == "posix":
+            os.killpg(0, signal.SIGKILL)
+        else:
+            logger.error(f"Unsupported OS: {os.name}")
+            sys.exit(1)
 
 
 # The main function parses the configuration and runs the validator.
 if __name__ == "__main__":
+    if os.name == "posix":
+        # Become the leader of a new process group.
+        os.setpgrp()
     asyncio.run(main())
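The shutdown pattern above is worth spelling out: at startup the validator becomes the leader of a fresh process group (`os.setpgrp()`), so on shutdown a single `os.killpg(0, signal.SIGKILL)` reaps every child the multiprocessing layer spawned, including any that would have ignored `terminate()`. A minimal, self-contained sketch of the same idea; the worker body and sleep durations are illustrative, not from the commit:

import asyncio
import multiprocessing as mp
import os
import signal
import sys
import time


def worker() -> None:
    # Stand-in for a spawned loop process; it inherits the parent's process group.
    while True:
        time.sleep(1)


async def main() -> None:
    processes = [mp.Process(target=worker) for _ in range(2)]
    for p in processes:
        p.start()
    try:
        await asyncio.sleep(3)  # Stand-in for the real supervision loop.
    finally:
        if os.name == "posix":
            # SIGKILL the whole group: children and this process alike.
            os.killpg(0, signal.SIGKILL)
        sys.exit(1)  # Only reached on non-POSIX platforms.


if __name__ == "__main__":
    if os.name == "posix":
        os.setpgrp()  # Become the leader of a new process group.
    asyncio.run(main())

Note that `SIGKILL` to the group also kills the sending process itself, which is why the diff shuts the Manager down first and why its `sys.exit(1)` is only reachable on non-POSIX platforms.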

prompting/datasets/huggingface_github.py

+1-1
@@ -121,7 +121,7 @@ def next(self) -> HuggingFaceGithubDatasetEntry | None:
                 return self._process_entry(entry)
             except BaseException as e:
                 logger.debug(f"Failed to sample from shard, skipping: {e}")
-        raise ValueError(f"Failed to get sample from shard after {RETRIES} retries")
+        raise ValueError(f"Failed to get sample from shard after {RETRIES} retries.")
 
     def get(self) -> HuggingFaceGithubDatasetEntry:
         return self.next()

prompting/llms/model_manager.py

+10-6
@@ -1,4 +1,5 @@
 import asyncio
+from functools import partial
 import gc
 from multiprocessing.managers import AcquirerProxy
 from typing import ClassVar
@@ -81,19 +82,22 @@ async def load_model(self, model_config: ModelConfig, force: bool = True) -> Rep
             logger.debug(f"Unloading {active_model.llm_model_id} to make room for {model_config.llm_model_id}")
 
             await self._unload_model(active_model)
-            await self._vram_cleanup()
+            await self.cleanup()
 
         retries_max = 1
         retry_counter = 0
         retry_delay = 15
         while True:
             try:
                 GPUInfo.log_gpu_info()
-                model = model_factory(model_config.llm_model_id)(
+                # Wrap blocking model loading into thread.
+                loader = partial(
+                    model_factory(model_config.llm_model_id),
                     model_id=model_config.llm_model_id,
                     device=settings.shared_settings.NEURON_DEVICE,
                     sampling_params=settings.shared_settings.SAMPLING_PARAMS,
                 )
+                model: ReproducibleVLLM = await asyncio.to_thread(loader)
                 self.used_ram += model_config.min_ram
                 logger.info(
                     f"Model {model_config.llm_model_id} has been successfully loaded. "
@@ -105,13 +109,13 @@ async def load_model(self, model_config: ModelConfig, force: bool = True) -> Rep
             except BaseException as e:
                 if retry_counter > retries_max:
                     logger.error(f"Failed to load model after {retries_max} retries. Terminating process")
-                    await self._vram_cleanup()
+                    await self.cleanup()
                     # In case of VRAM leak, raise an exception to terminate the process.
                     raise MemoryError
 
                 retry_counter += 1
                 retry_delay += retry_counter
-                await self._vram_cleanup()
+                await self.cleanup()
                 logger.error(
                     f"Failed to load model {model_config.llm_model_id}. Retrying in {retry_delay} seconds. "
                     f"Error: {str(e)}"
@@ -150,7 +154,7 @@ async def _unload_model(self, model_config: ModelConfig):
         logger.debug(f"Initial free GPU memory before unloading: {initial_free_memory} GB")
 
         await self._cleanup_model(model_instance, cpu_offload=False)
-        await self._vram_cleanup()
+        await self.cleanup()
 
         memory_freed = GPUInfo.free_memory - initial_free_memory
         logger.info(f"Successfully unloaded model {model_config.llm_model_id}. Memory freed: {memory_freed:.2f} GB")
@@ -219,7 +223,7 @@ async def generate_logits(
             continue_last_message=continue_last_message,
         )
 
-    async def _vram_cleanup(self):
+    async def cleanup(self):
        """Perform VRAM clean-up."""
        for _, model in self.active_models.items():
            del model.model
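The `partial` + `asyncio.to_thread` change is the standard recipe for keeping an event loop responsive while a blocking call runs: bind the constructor arguments first, then hand the zero-argument callable to a worker thread. A runnable sketch under assumed names (`slow_load` and `heartbeat` stand in for vLLM weight loading and the validator's other coroutines):

import asyncio
import time
from functools import partial


def slow_load(model_id: str, device: str) -> str:
    # Stand-in for a blocking model constructor (e.g. multi-minute weight loading).
    time.sleep(2)
    return f"{model_id} on {device}"


async def heartbeat() -> None:
    for _ in range(4):
        print("event loop still alive")
        await asyncio.sleep(0.5)


async def main() -> None:
    # Bind arguments first, then push the blocking call onto a worker thread.
    loader = partial(slow_load, model_id="my-model", device="cuda")
    model, _ = await asyncio.gather(asyncio.to_thread(loader), heartbeat())
    print(model)


asyncio.run(main())

Without `to_thread`, the model load would stall every other coroutine in the loop process for its full duration.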

prompting/llms/vllm_llm.py

+3
@@ -209,6 +209,9 @@ def unload_model(self):
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
+    def __del__(self):
+        self.unload_model()
+
     @staticmethod
     def format_messages(messages: list[str] | list[dict[str, str]]) -> list[dict[str, str | list[dict[str, str]]]]:
         return messages
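Adding `__del__` ties GPU cleanup to object lifetime: when the last reference to a model wrapper is dropped (as in the manager's `del model.model`), `unload_model` runs automatically. A toy sketch of the pattern with a hypothetical class; note that CPython swallows exceptions raised inside finalizers, so the cleanup method should stay defensive:

class ModelWrapper:
    """Hypothetical stand-in for ReproducibleVLLM."""

    def __init__(self, model_id: str) -> None:
        self.model_id = model_id

    def unload_model(self) -> None:
        # The real method destroys vLLM state and empties the CUDA cache.
        print(f"unloading {self.model_id}")

    def __del__(self) -> None:
        # Finalizer: release resources even if a caller forgets to unload.
        self.unload_model()


wrapper = ModelWrapper("my-model")
del wrapper  # Refcount hits zero; CPython runs __del__ immediately.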

prompting/tasks/task_sending.py

+2-17
@@ -6,7 +6,6 @@
 
 from prompting.miner_availability.miner_availability import MinerAvailabilities
 
-# from prompting.rewards.scoring import task_scorer
 from prompting.rewards.scoring_config import ScoringConfig
 from prompting.tasks.base_task import BaseTextTask
 from prompting.tasks.inference import InferenceTask
@@ -77,9 +76,6 @@ async def collect_responses(task: BaseTextTask, miners_dict: dict) -> DendriteRe
 class TaskSender(AsyncLoopRunner):
     interval: int = 10
     _lock: asyncio.Lock = asyncio.Lock()
-    block_sync_last_time: float = 0
-    block_sync_interval: float = 300
-
     task_queue: list | None = None
     scoring_queue: list | None = None
     miners_dict: dict | None = None
@@ -93,17 +89,6 @@ async def start(self, task_queue, scoring_queue, miners_dict, **kwargs):
         self.miners_dict = miners_dict
         return await super().start(**kwargs)
 
-    @property
-    def block(self) -> int:
-        time_since_last_block = time.time() - self.block_sync_last_time
-        if time_since_last_block > self.block_sync_interval:
-            self._block = shared_settings.SUBTENSOR.get_current_block()
-            self.block_sync_last_time = time.time()
-            return self._block
-
-        blocks_passed = time_since_last_block // 12
-        return self._block + blocks_passed
-
     async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
         logger.info("Checking for tasks to be sent...")
         while len(self.scoring_queue) > shared_settings.SCORING_QUEUE_LENGTH_THRESHOLD:
@@ -124,15 +109,15 @@ async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
                 task=task,
                 response=response_event,
                 dataset_entry=task.dataset_entry,
-                block=self.block,
+                block=shared_settings.block,
                 step=self.step,
                 task_id=task.task_id,
             )
             self.scoring_queue.append(scoring_config)
 
         # Log the step event.
         return ValidatorLoggingEvent(
-            block=self.block,
+            block=shared_settings.block,
             step=self.step,
             step_time=timer.final_time,
             response_event=response_event,

shared/settings.py

+21-7
@@ -16,6 +16,7 @@
 import bittensor as bt
 import dotenv
 from bittensor.core.metagraph import Metagraph
+from bittensor.core.subtensor import Subtensor
 from loguru import logger
 from pydantic import Field, model_validator
 from pydantic_settings import BaseSettings
@@ -32,6 +33,9 @@ class SharedSettings(BaseSettings):
     _instance_mode: Optional[str] = None
     _last_metagraph: Metagraph = None
     _last_update_time: float = 0
+    _block_sync_last_time: float = 0
+    _block_sync_interval: float = 300
+    _subtensor: Subtensor | None = None
 
     mode: Literal["api", "validator", "miner", "mock"] = Field("validator", env="MODE")
     MOCK: bool = False
@@ -258,16 +262,20 @@ def WALLET(self):
         return bt.wallet(name=wallet_name, hotkey=hotkey)
 
     @cached_property
-    def SUBTENSOR(self) -> bt.subtensor:
+    def SUBTENSOR(self) -> Subtensor:
+        """Lazy subtensor initialization."""
+        if self._subtensor is not None:
+            return self._subtensor
         # TODO: Move chain-related stuff out of settings.
         subtensor_network = self.SUBTENSOR_NETWORK or os.environ.get("SUBTENSOR_NETWORK", "local")
         # bt_config = config()
         if subtensor_network.lower() == "local":
             subtensor_network = os.environ.get("SUBTENSOR_CHAIN_ENDPOINT")  # bt_config.subtensor.chain_endpoint or
         else:
-            subtensor_network = subtensor_network.lower()  # bt_config.subtensor.network or
+            subtensor_network = subtensor_network.lower()
         logger.info(f"Instantiating subtensor with network: {subtensor_network}")
-        return bt.subtensor(network=subtensor_network)
+        self._subtensor = Subtensor(network=subtensor_network)
+        return self._subtensor
 
     @property
     def METAGRAPH(self) -> Metagraph:
@@ -294,11 +302,17 @@ def UID(self) -> int:
         # TODO: Move chain-related stuff out of settings.
         return self.METAGRAPH.hotkeys.index(self.WALLET.hotkey.ss58_address)
 
-    @cached_property
-    def DENDRITE(self) -> bt.dendrite:
+    @property
+    def block(self) -> int:
         # TODO: Move chain-related stuff out of settings.
-        logger.info(f"Instantiating dendrite with wallet: {self.WALLET}")
-        return bt.dendrite(wallet=self.WALLET)
+        time_since_last_block = time.time() - self._block_sync_last_time
+        if time_since_last_block > self._block_sync_interval:
+            self._block = self.SUBTENSOR.get_current_block()
+            self._block_sync_last_time = time.time()
+            return self._block
+
+        blocks_passed = time_since_last_block // 12
+        return self._block + blocks_passed
 
 
 try:
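The relocated `block` property (moved here from `TaskSender`) is the heart of the "minimize subtensor calls" half of this commit: the chain is queried at most once per `_block_sync_interval` (300 s), and between syncs the height is extrapolated from Bittensor's nominal 12-second block time. A standalone sketch of the same caching logic, with illustrative names:

import time


class BlockEstimator:
    """Minimal sketch of the cached-block pattern (names are illustrative)."""

    SYNC_INTERVAL = 300.0  # Seconds between real chain queries.
    BLOCK_TIME = 12.0      # Nominal Bittensor block time in seconds.

    def __init__(self, get_current_block) -> None:
        self._get_current_block = get_current_block  # e.g. subtensor.get_current_block
        self._block = 0
        self._last_sync = 0.0

    @property
    def block(self) -> int:
        elapsed = time.time() - self._last_sync
        if elapsed > self.SYNC_INTERVAL:
            # Hit the chain at most once per interval.
            self._block = self._get_current_block()
            self._last_sync = time.time()
            return self._block
        # Otherwise extrapolate: roughly one new block every 12 seconds.
        return self._block + int(elapsed // self.BLOCK_TIME)


# Usage: estimator = BlockEstimator(subtensor.get_current_block); estimator.block

Combined with the cached `_subtensor` handle, this means `TaskSender` and the main loop no longer poll the chain independently; they share one connection and one rate-limited block counter.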
