Skip to content

Commit ff729d1

Browse files
committed
refactor(dumper): extract _DumperConfig frozen dataclass
Move all 13 _Dumper init parameters into a @DataClass(frozen=True) _DumperConfig. _Dumper.__init__ now takes a single config parameter. - Runtime mutations (enable via HTTP, lazy partial_name) use dataclasses.replace() to swap the frozen config - _DumperConfig.from_env() centralizes env var parsing with defaults matching field defaults (verified by new UT) - Rename _pending_cleanup to _cleanup_previous_handled for consistency with _http_server_handled
1 parent e829bfe commit ff729d1

File tree

3 files changed

+95
-96
lines changed

3 files changed

+95
-96
lines changed

python/sglang/srt/debug_utils/dumper.py

Lines changed: 71 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
from contextlib import contextmanager
88
from copy import deepcopy
9+
from dataclasses import dataclass, replace
910
from functools import cached_property
1011
from http.server import BaseHTTPRequestHandler, HTTPServer
1112
from pathlib import Path
@@ -17,6 +18,47 @@
1718
# -------------------------------------- dumper core ------------------------------------------
1819

1920

21+
@dataclass(frozen=True)
22+
class _DumperConfig:
23+
enable: bool = False
24+
filter: Optional[str] = None
25+
base_dir: Path = Path("/tmp")
26+
enable_output_file: bool = True
27+
enable_output_console: bool = True
28+
enable_value: bool = True
29+
enable_grad: bool = False
30+
enable_model_value: bool = True
31+
enable_model_grad: bool = True
32+
partial_name: Optional[str] = None
33+
enable_http_server: bool = True
34+
cleanup_previous: bool = False
35+
collective_timeout: int = 60
36+
37+
@classmethod
38+
def from_env(cls) -> "_DumperConfig":
39+
return cls(
40+
enable=get_bool_env_var("SGLANG_DUMPER_ENABLE", "0"),
41+
filter=_get_str_env_var("SGLANG_DUMPER_FILTER"),
42+
base_dir=Path(_get_str_env_var("SGLANG_DUMPER_DIR", "/tmp")),
43+
enable_output_file=get_bool_env_var("SGLANG_DUMPER_OUTPUT_FILE", "1"),
44+
enable_output_console=get_bool_env_var(
45+
"SGLANG_DUMPER_OUTPUT_CONSOLE", "1"
46+
),
47+
enable_value=get_bool_env_var("SGLANG_DUMPER_ENABLE_VALUE", "1"),
48+
enable_grad=get_bool_env_var("SGLANG_DUMPER_ENABLE_GRAD", "0"),
49+
enable_model_value=get_bool_env_var(
50+
"SGLANG_DUMPER_ENABLE_MODEL_VALUE", "1"
51+
),
52+
enable_model_grad=get_bool_env_var("SGLANG_DUMPER_ENABLE_MODEL_GRAD", "1"),
53+
partial_name=_get_str_env_var("SGLANG_DUMPER_PARTIAL_NAME"),
54+
enable_http_server=get_bool_env_var(
55+
"SGLANG_ENABLE_DUMPER_HTTP_SERVER", "1"
56+
),
57+
cleanup_previous=get_bool_env_var("SGLANG_DUMPER_CLEANUP_PREVIOUS", "0"),
58+
collective_timeout=60,
59+
)
60+
61+
2062
class _Dumper:
2163
"""Utility to dump tensors, which can be useful when comparison checking models.
2264
@@ -44,75 +86,30 @@ class _Dumper:
4486
Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
4587
"""
4688

47-
def __init__(
48-
self,
49-
*,
50-
enable: bool,
51-
base_dir: Path,
52-
filter: Optional[str] = None,
53-
enable_output_file: bool = True,
54-
enable_output_console: bool = True,
55-
enable_value: bool = True,
56-
enable_grad: bool = False,
57-
enable_model_value: bool = True,
58-
enable_model_grad: bool = True,
59-
partial_name: Optional[str] = None,
60-
enable_http_server: bool = True,
61-
cleanup_previous: bool = False,
62-
collective_timeout: int = 60,
63-
):
64-
# Config
65-
self._enable = enable
89+
def __init__(self, *, config: _DumperConfig):
6690
# TODO (1) support filtering kv instead of name only (2) allow HTTP req change it
67-
self._filter = filter
68-
self._base_dir = base_dir
69-
self._enable_output_file = enable_output_file
70-
self._enable_output_console = enable_output_console
71-
self._enable_value = enable_value
72-
self._enable_grad = enable_grad
73-
self._enable_model_value = enable_model_value
74-
self._enable_model_grad = enable_model_grad
75-
self._collective_timeout = collective_timeout
76-
77-
# States
78-
self._partial_name = partial_name
91+
self._config = config
92+
93+
self._http_server_handled = not config.enable_http_server
94+
self._cleanup_previous_handled = not config.cleanup_previous
95+
7996
self._dump_index = 0
8097
self._forward_pass_id = 0
81-
self._global_ctx = {}
82-
self._override_enable = None
98+
self._global_ctx: dict = {}
99+
self._override_enable: Optional[bool] = None
83100
self._captured_output_data: Optional[dict] = None
84-
self._http_server_handled = not enable_http_server
85-
self._pending_cleanup = cleanup_previous
86101

87102
@classmethod
88103
def from_env(cls) -> "_Dumper":
89-
return cls(
90-
enable=get_bool_env_var("SGLANG_DUMPER_ENABLE", "0"),
91-
base_dir=Path(_get_str_env_var("SGLANG_DUMPER_DIR", "/tmp")),
92-
filter=_get_str_env_var("SGLANG_DUMPER_FILTER"),
93-
enable_output_file=get_bool_env_var("SGLANG_DUMPER_OUTPUT_FILE", "1"),
94-
enable_output_console=get_bool_env_var("SGLANG_DUMPER_OUTPUT_CONSOLE", "1"),
95-
enable_value=get_bool_env_var("SGLANG_DUMPER_ENABLE_VALUE", "1"),
96-
enable_grad=get_bool_env_var("SGLANG_DUMPER_ENABLE_GRAD", "0"),
97-
enable_model_value=get_bool_env_var(
98-
"SGLANG_DUMPER_ENABLE_MODEL_VALUE", "1"
99-
),
100-
enable_model_grad=get_bool_env_var("SGLANG_DUMPER_ENABLE_MODEL_GRAD", "1"),
101-
partial_name=_get_str_env_var("SGLANG_DUMPER_PARTIAL_NAME"),
102-
enable_http_server=get_bool_env_var(
103-
"SGLANG_ENABLE_DUMPER_HTTP_SERVER", "1"
104-
),
105-
cleanup_previous=get_bool_env_var("SGLANG_DUMPER_CLEANUP_PREVIOUS", "0"),
106-
collective_timeout=60,
107-
)
104+
return cls(config=_DumperConfig.from_env())
108105

109106
def on_forward_pass_start(self):
110107
"""This should be called on all ranks."""
111108

112109
# Even if SGLANG_DUMPER_ENABLE=0, users may want to use HTTP endpoint to enable it
113110
self._ensure_http_server()
114111

115-
if not self._enable:
112+
if not self._config.enable:
116113
return
117114

118115
# Users may want to `dump` only on some ranks, thus determine name here
@@ -127,14 +124,15 @@ def _ensure_http_server(self):
127124
if self._http_server_handled:
128125
return
129126
self._http_server_handled = True
130-
_start_maybe_http_server(self, timeout_seconds=self._collective_timeout)
127+
_start_maybe_http_server(self, timeout_seconds=self._config.collective_timeout)
131128

132129
def _ensure_partial_name(self):
133-
if self._partial_name is None:
134-
self._partial_name = _get_partial_name(
135-
timeout_seconds=self._collective_timeout
130+
if self._config.partial_name is None:
131+
name = _get_partial_name(
132+
timeout_seconds=self._config.collective_timeout
136133
)
137-
print(f"[Dumper] Choose partial_name={self._partial_name}")
134+
self._config = replace(self._config, partial_name=name)
135+
print(f"[Dumper] Choose partial_name={name}")
138136

139137
def set_ctx(self, **kwargs):
140138
"""
@@ -172,9 +170,9 @@ def dump(self, name: str, value, save: bool = True, **kwargs) -> None:
172170
value=value,
173171
extra_kwargs=kwargs,
174172
save=save,
175-
enable_value=self._enable_value,
173+
enable_value=self._config.enable_value,
176174
enable_curr_grad=False,
177-
enable_future_grad=self._enable_grad,
175+
enable_future_grad=self._config.enable_grad,
178176
value_tag="Dumper.Value",
179177
grad_tag="Dumper.Grad",
180178
)
@@ -192,8 +190,8 @@ def dump_model(
192190
value=param,
193191
extra_kwargs=kwargs,
194192
save=save,
195-
enable_value=self._enable_model_value,
196-
enable_curr_grad=self._enable_model_grad,
193+
enable_value=self._config.enable_model_value,
194+
enable_curr_grad=self._config.enable_model_grad,
197195
enable_future_grad=False,
198196
value_tag="Dumper.ParamValue",
199197
grad_tag="Dumper.ParamGrad",
@@ -214,9 +212,9 @@ def _dump_inner(
214212
) -> None:
215213
self._ensure_http_server()
216214

217-
if not (self._enable and (self._override_enable is not False)):
215+
if not (self._config.enable and (self._override_enable is not False)):
218216
return
219-
if (f := self._filter) is not None and re.search(f, name) is None:
217+
if (f := self._config.filter) is not None and re.search(f, name) is None:
220218
return
221219
if not (enable_value or enable_curr_grad or enable_future_grad):
222220
return
@@ -306,9 +304,9 @@ def _dump_single(
306304
**self._global_ctx,
307305
)
308306
full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
309-
path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename
307+
path = self._config.base_dir / f"sglang_dump_{self._config.partial_name}" / full_filename
310308

311-
if self._enable_output_console:
309+
if self._config.enable_output_console:
312310
print(
313311
f"[{tag}] [{rank}, {time.time()}] {path} "
314312
f"type={type(value)} "
@@ -320,7 +318,7 @@ def _dump_single(
320318
)
321319

322320
capturing = self._captured_output_data is not None
323-
if save and (self._enable_output_file or capturing):
321+
if save and (self._config.enable_output_file or capturing):
324322
output_data = {
325323
"value": value.data if isinstance(value, torch.nn.Parameter) else value,
326324
"meta": dict(**full_kwargs, **self._static_meta),
@@ -330,9 +328,9 @@ def _dump_single(
330328
output_data["value"] = _deepcopy_or_clone(output_data["value"])
331329
self._captured_output_data[name] = output_data
332330
else:
333-
if self._pending_cleanup:
334-
self._pending_cleanup = False
335-
_cleanup_old_dumps(self._base_dir)
331+
if not self._cleanup_previous_handled:
332+
self._cleanup_previous_handled = True
333+
_cleanup_old_dumps(self._config.base_dir)
336334

337335
path.parent.mkdir(parents=True, exist_ok=True)
338336
_torch_save(output_data, str(path))
@@ -603,7 +601,7 @@ def __init__(self, dumper):
603601

604602
def set_enable(self, enable: bool):
605603
print(f"[DumperRpcHandler] set_enable {enable=}")
606-
self._dumper._enable = enable
604+
self._dumper._config = replace(self._dumper._config, enable=enable)
607605

608606

609607
# -------------------------------------- zmq rpc ------------------------------------------

test/registered/debug_utils/test_dump_comparator.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,22 @@ def test_main(self):
8989
from argparse import Namespace
9090

9191
from sglang.srt.debug_utils.dump_comparator import main
92-
from sglang.srt.debug_utils.dumper import _Dumper
92+
from sglang.srt.debug_utils.dumper import _Dumper, _DumperConfig
9393

9494
with tempfile.TemporaryDirectory() as d1, tempfile.TemporaryDirectory() as d2:
9595
baseline_tensor = torch.randn(10, 10)
9696
target_tensor = baseline_tensor + torch.randn(10, 10) * 0.01
9797

9898
dump_dirs = []
9999
for d, tensor in [(d1, baseline_tensor), (d2, target_tensor)]:
100-
with _with_env("SGLANG_DUMPER_DIR", d), _with_env(
101-
"SGLANG_DUMPER_SERVER_PORT", "-1"
102-
):
103-
dumper = _Dumper()
104-
dumper.on_forward_pass_start()
105-
dumper.dump("tensor_a", tensor)
106-
dumper.on_forward_pass_start()
107-
dumper.dump("tensor_b", tensor * 2)
108-
dump_dirs.append(Path(d) / f"sglang_dump_{dumper._partial_name}")
100+
dumper = _Dumper(config=_DumperConfig(
101+
enable=True, base_dir=Path(d), enable_http_server=False,
102+
))
103+
dumper.on_forward_pass_start()
104+
dumper.dump("tensor_a", tensor)
105+
dumper.on_forward_pass_start()
106+
dumper.dump("tensor_b", tensor * 2)
107+
dump_dirs.append(Path(d) / f"sglang_dump_{dumper._config.partial_name}")
109108

110109
args = Namespace(
111110
baseline_path=str(dump_dirs[0]),

test/registered/debug_utils/test_dumper.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_collect_sglang_parallel_info,
1616
_collective_with_timeout,
1717
_Dumper,
18+
_DumperConfig,
1819
_materialize_value,
1920
_obj_to_dict,
2021
_torch_save,
@@ -40,6 +41,11 @@ def _capture_stdout():
4041
sys.stdout = old_stdout
4142

4243

44+
class TestDumperConfig:
45+
def test_from_env_defaults_match_dataclass_defaults(self):
46+
assert _DumperConfig.from_env() == _DumperConfig()
47+
48+
4349
class TestDumperPureFunctions:
4450
def test_get_truncated_value(self):
4551
assert get_truncated_value(None) is None
@@ -175,11 +181,9 @@ def test_collective_timeout(self):
175181
@staticmethod
176182
def _test_collective_timeout_func(rank):
177183
dumper = _Dumper(
178-
enable=True,
179-
base_dir=Path("/tmp"),
180-
partial_name=None,
181-
enable_http_server=False,
182-
collective_timeout=3,
184+
config=_DumperConfig(
185+
enable=True, collective_timeout=3, enable_http_server=False,
186+
),
183187
)
184188

185189
with _capture_stdout() as captured:
@@ -202,7 +206,7 @@ def test_http_enable(self):
202206
def _test_http_func(rank):
203207
from sglang.srt.debug_utils.dumper import dumper
204208

205-
assert not dumper._enable
209+
assert not dumper._config.enable
206210
dumper.on_forward_pass_start()
207211

208212
for enable in [True, False]:
@@ -213,7 +217,7 @@ def _test_http_func(rank):
213217
"http://localhost:40000/dumper", json={"enable": enable}
214218
).raise_for_status()
215219
dist.barrier()
216-
assert dumper._enable == enable
220+
assert dumper._config.enable == enable
217221

218222
def test_file_content_correctness(self, tmp_path):
219223
with temp_set_env(
@@ -407,13 +411,11 @@ def test_dict_format_with_context(self, tmp_path):
407411

408412
def _make_test_dumper(tmp_path: Path, **overrides) -> _Dumper:
409413
"""Create a _Dumper for CPU testing without HTTP server or distributed."""
410-
defaults: dict = dict(
411-
enable=True,
412-
base_dir=tmp_path,
413-
partial_name="test",
414-
enable_http_server=False,
414+
config = _DumperConfig(
415+
enable=True, base_dir=tmp_path, partial_name="test",
416+
enable_http_server=False, **overrides,
415417
)
416-
d = _Dumper(**{**defaults, **overrides})
418+
d = _Dumper(config=config)
417419
d.on_forward_pass_start()
418420
return d
419421

0 commit comments

Comments
 (0)