Skip to content

Commit 65f7382

Browse files
[nightshift] Remove dead functions and orphaned imports (#4390)
> *Dead branches pruned—* > *the forest breathes easier,* > *roots drink deeper now.* ## Summary Removed 5 private functions (74 lines) that were defined but never called anywhere in the codebase, confirmed via whole-repo grep. Also removed 4 orphaned imports (`re`, `jax`, `jax.numpy`, and `get_scheduling_strategy`) that only served the dead code, bringing the total to 85 deleted lines. | File | Removed | Age | |---|---|---| | `rl/environments/inference_ctx/vllm.py` | `_convert_vllm_state_dict_to_trainer_keys` (39 lines), `_check_weight_differences` (12 lines) + imports `re`, `jax`, `jax.numpy` | Nov 2025 — leftover debugging helpers from initial vLLM RL integration | | `processing/classification/autoscaler.py` | `_cleanup_completed_futures` (14 lines) | Nov 2025 — superseded by `_result_collector_loop` | | `evaluation/evaluators/evaluator.py` | `_get_scheduling_strategy` (4 lines) + import `get_scheduling_strategy` | Nov 2025 — never called by any `Evaluator` subclass | | `profiling/ingest.py` | `_active_category` (5 lines) | Feb 2026 — superseded by `_active_device_category` |
1 parent f3f2c4d commit 65f7382

4 files changed

Lines changed: 0 additions & 85 deletions

File tree

lib/marin/src/marin/evaluation/evaluators/evaluator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Any
88

99
from fray.v1.cluster import Entrypoint, EnvironmentConfig, JobRequest, ResourceConfig, current_cluster
10-
from fray.v1.cluster.ray import get_scheduling_strategy
1110

1211
from marin.evaluation.evaluation_config import EvalTaskConfig
1312
from marin.utils import remove_tpu_lockfile_on_exit
@@ -58,11 +57,6 @@ class ModelConfig:
5857

5958

6059
class Evaluator(ABC):
61-
def _get_scheduling_strategy(self, resource_config: ResourceConfig | None):
62-
if resource_config is None:
63-
return None
64-
return get_scheduling_strategy(resource_config)
65-
6660
@abstractmethod
6761
def launch_evaluate_with_ray(
6862
self,

lib/marin/src/marin/processing/classification/autoscaler.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -373,22 +373,6 @@ def _dispatch_task(self, task: list[dict[str, Any]]) -> ray.ObjectRef:
373373
logger.debug(f"Assigning to actor: {actor} for task: {task}")
374374
return future
375375

376-
def _cleanup_completed_futures(self):
377-
"""Remove completed futures from tracking."""
378-
with self.actor_task_metadata_lock:
379-
for actor in self.actors:
380-
if actor in self.actor_futures:
381-
# Filter out completed futures
382-
pending = []
383-
for future in self.actor_futures[actor]:
384-
ready, _ = ray.wait([future], timeout=0)
385-
if not ready:
386-
pending.append(future)
387-
else:
388-
self.future_to_actor.pop(future, None)
389-
self.future_to_task.pop(future, None)
390-
self.actor_futures[actor] = pending
391-
392376
def shutdown(self):
393377
"""Shutdown the actor pool and clean up resources."""
394378
logger.info("Shutting down actor pool...")

lib/marin/src/marin/profiling/ingest.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,13 +1533,6 @@ def _estimate_periodicity(steps: list[int]) -> int | None:
15331533
return best_diff
15341534

15351535

1536-
def _active_category(active: dict[str, int]) -> str | None:
1537-
for category in ("communication", "compute", "stall", "host", "other"):
1538-
if active[category] > 0:
1539-
return category
1540-
return None
1541-
1542-
15431536
def _active_device_category(active: dict[str, int]) -> str | None:
15441537
for category in ("communication", "compute"):
15451538
if active[category] > 0:

lib/marin/src/marin/rl/environments/inference_ctx/vllm.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
import logging
66
import os
77
import time
8-
import re
9-
import jax
10-
import jax.numpy as jnp
118
import numpy as np
129
from enum import StrEnum
1310
from dataclasses import dataclass
@@ -150,59 +147,6 @@ def _get_llm_engine(inference_config: vLLMInferenceContextConfig):
150147
enforce_eager=inference_config.enforce_eager,
151148
)
152149

153-
def _convert_vllm_state_dict_to_trainer_keys(
154-
self, state_dict_trainer: dict, state_dict_vllm: dict, mapping: dict
155-
) -> dict:
156-
state_dict_vllm_with_trainer_keys = {}
157-
for src_path, _ in state_dict_trainer.items():
158-
src_key = ".".join(str(p) for p in src_path)
159-
160-
# Try to find a matching pattern
161-
matched = False
162-
for src_pattern, (dst_pattern, _) in mapping.items():
163-
164-
if not re.match(src_pattern, src_key):
165-
continue
166-
167-
match_layer_number = re.match(r".*layers\.(\d+).*", src_key)
168-
if match_layer_number:
169-
layer_number = int(match_layer_number.group(1))
170-
dst_path = []
171-
for part in dst_pattern.split("."):
172-
if part == "*":
173-
dst_path.append(layer_number)
174-
else:
175-
dst_path.append(part)
176-
dst_path = tuple(dst_path)
177-
if dst_path in state_dict_vllm:
178-
state_dict_vllm_with_trainer_keys[src_path] = state_dict_vllm[dst_path]
179-
matched = True
180-
break
181-
else:
182-
dst_path = tuple(dst_pattern.split("."))
183-
if dst_path in state_dict_vllm:
184-
state_dict_vllm_with_trainer_keys[src_path] = state_dict_vllm[dst_path]
185-
matched = True
186-
break
187-
188-
if not matched:
189-
print(f"Warning: No mapping found for {src_key}")
190-
191-
return state_dict_vllm_with_trainer_keys
192-
193-
def _check_weight_differences(self, state_dict: dict, state_dict_other: dict):
194-
for key in state_dict:
195-
if key in state_dict_other:
196-
assert (
197-
state_dict[key].shape == state_dict_other[key].shape
198-
), f"Shape mismatch for key {key}: {state_dict[key].shape} != {state_dict_other[key].shape}"
199-
weight = jax.device_get(state_dict[key]).astype(jnp.bfloat16)
200-
weight_other = jax.device_get(state_dict_other[key]).astype(jnp.bfloat16)
201-
print(
202-
f"Weight {key}, max diff: {jnp.max(jnp.abs(weight - weight_other))}, \
203-
mean diff: {jnp.mean(jnp.abs(weight - weight_other))}"
204-
)
205-
206150
def tokenize_prompt(self, prompt: str, choice: Choice | None = None, system_prompt: str | None = None) -> np.ndarray:
207151
"""Tokenize the prompt with the choice's prompt token IDs.
208152

0 commit comments

Comments
 (0)