66import time
77from contextlib import contextmanager
88from copy import deepcopy
9+ from dataclasses import dataclass , replace
910from functools import cached_property
1011from http .server import BaseHTTPRequestHandler , HTTPServer
1112from pathlib import Path
1718# -------------------------------------- dumper core ------------------------------------------
1819
1920
@dataclass(frozen=True)
class _DumperConfig:
    """Immutable bundle of settings for ``_Dumper``.

    The dataclass is frozen, so a setting is changed by building a modified
    copy with ``dataclasses.replace`` rather than by assigning to a field.
    """

    # Master switch: nothing is dumped while this is False.
    enable: bool = False
    # Regex applied with ``re.search`` to dump names; ``None`` disables filtering.
    filter: Optional[str] = None
    # Directory under which the ``sglang_dump_<partial_name>`` folder is created.
    base_dir: Path = Path("/tmp")
    # Whether dumps are written to files / announced on the console.
    enable_output_file: bool = True
    enable_output_console: bool = True
    # Value / gradient toggles used by ``dump``.
    enable_value: bool = True
    enable_grad: bool = False
    # Value / gradient toggles used by ``dump_model``.
    enable_model_value: bool = True
    enable_model_grad: bool = True
    # Suffix of the dump directory name; ``None`` until one is chosen.
    partial_name: Optional[str] = None
    # Whether the dumper should start its HTTP server.
    enable_http_server: bool = True
    # Whether old dumps under ``base_dir`` are cleaned up before the first save.
    cleanup_previous: bool = False
    # Passed as ``timeout_seconds`` to the collective helpers.
    collective_timeout: int = 60

    @classmethod
    def from_env(cls) -> "_DumperConfig":
        """Build a config from the ``SGLANG_*`` environment variables.

        ``collective_timeout`` is not read from the environment; it is
        always set to 60 here.
        """
        read_flag = get_bool_env_var
        read_text = _get_str_env_var

        # Environment variables are read in the same order as the fields
        # are passed to the constructor.
        settings = {
            "enable": read_flag("SGLANG_DUMPER_ENABLE", "0"),
            "filter": read_text("SGLANG_DUMPER_FILTER"),
            "base_dir": Path(read_text("SGLANG_DUMPER_DIR", "/tmp")),
            "enable_output_file": read_flag("SGLANG_DUMPER_OUTPUT_FILE", "1"),
            "enable_output_console": read_flag("SGLANG_DUMPER_OUTPUT_CONSOLE", "1"),
            "enable_value": read_flag("SGLANG_DUMPER_ENABLE_VALUE", "1"),
            "enable_grad": read_flag("SGLANG_DUMPER_ENABLE_GRAD", "0"),
            "enable_model_value": read_flag("SGLANG_DUMPER_ENABLE_MODEL_VALUE", "1"),
            "enable_model_grad": read_flag("SGLANG_DUMPER_ENABLE_MODEL_GRAD", "1"),
            "partial_name": read_text("SGLANG_DUMPER_PARTIAL_NAME"),
            "enable_http_server": read_flag("SGLANG_ENABLE_DUMPER_HTTP_SERVER", "1"),
            "cleanup_previous": read_flag("SGLANG_DUMPER_CLEANUP_PREVIOUS", "0"),
            "collective_timeout": 60,
        }
        return cls(**settings)
60+
61+
2062class _Dumper :
2163 """Utility to dump tensors, which can be useful when comparison checking models.
2264
@@ -44,75 +86,30 @@ class _Dumper:
4486 Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
4587 """
4688
47- def __init__ (
48- self ,
49- * ,
50- enable : bool ,
51- base_dir : Path ,
52- filter : Optional [str ] = None ,
53- enable_output_file : bool = True ,
54- enable_output_console : bool = True ,
55- enable_value : bool = True ,
56- enable_grad : bool = False ,
57- enable_model_value : bool = True ,
58- enable_model_grad : bool = True ,
59- partial_name : Optional [str ] = None ,
60- enable_http_server : bool = True ,
61- cleanup_previous : bool = False ,
62- collective_timeout : int = 60 ,
63- ):
64- # Config
65- self ._enable = enable
89+ def __init__ (self , * , config : _DumperConfig ):
6690 # TODO (1) support filtering kv instead of name only (2) allow HTTP req change it
67- self ._filter = filter
68- self ._base_dir = base_dir
69- self ._enable_output_file = enable_output_file
70- self ._enable_output_console = enable_output_console
71- self ._enable_value = enable_value
72- self ._enable_grad = enable_grad
73- self ._enable_model_value = enable_model_value
74- self ._enable_model_grad = enable_model_grad
75- self ._collective_timeout = collective_timeout
76-
77- # States
78- self ._partial_name = partial_name
91+ self ._config = config
92+
93+ self ._http_server_handled = not config .enable_http_server
94+ self ._cleanup_previous_handled = not config .cleanup_previous
95+
7996 self ._dump_index = 0
8097 self ._forward_pass_id = 0
81- self ._global_ctx = {}
82- self ._override_enable = None
98+ self ._global_ctx : dict = {}
99+ self ._override_enable : Optional [ bool ] = None
83100 self ._captured_output_data : Optional [dict ] = None
84- self ._http_server_handled = not enable_http_server
85- self ._pending_cleanup = cleanup_previous
86101
87102 @classmethod
88103 def from_env (cls ) -> "_Dumper" :
89- return cls (
90- enable = get_bool_env_var ("SGLANG_DUMPER_ENABLE" , "0" ),
91- base_dir = Path (_get_str_env_var ("SGLANG_DUMPER_DIR" , "/tmp" )),
92- filter = _get_str_env_var ("SGLANG_DUMPER_FILTER" ),
93- enable_output_file = get_bool_env_var ("SGLANG_DUMPER_OUTPUT_FILE" , "1" ),
94- enable_output_console = get_bool_env_var ("SGLANG_DUMPER_OUTPUT_CONSOLE" , "1" ),
95- enable_value = get_bool_env_var ("SGLANG_DUMPER_ENABLE_VALUE" , "1" ),
96- enable_grad = get_bool_env_var ("SGLANG_DUMPER_ENABLE_GRAD" , "0" ),
97- enable_model_value = get_bool_env_var (
98- "SGLANG_DUMPER_ENABLE_MODEL_VALUE" , "1"
99- ),
100- enable_model_grad = get_bool_env_var ("SGLANG_DUMPER_ENABLE_MODEL_GRAD" , "1" ),
101- partial_name = _get_str_env_var ("SGLANG_DUMPER_PARTIAL_NAME" ),
102- enable_http_server = get_bool_env_var (
103- "SGLANG_ENABLE_DUMPER_HTTP_SERVER" , "1"
104- ),
105- cleanup_previous = get_bool_env_var ("SGLANG_DUMPER_CLEANUP_PREVIOUS" , "0" ),
106- collective_timeout = 60 ,
107- )
104+ return cls (config = _DumperConfig .from_env ())
108105
109106 def on_forward_pass_start (self ):
110107 """This should be called on all ranks."""
111108
112109 # Even if SGLANG_DUMPER_ENABLE=0, users may want to use HTTP endpoint to enable it
113110 self ._ensure_http_server ()
114111
115- if not self ._enable :
112+ if not self ._config . enable :
116113 return
117114
118115 # Users may want to `dump` only on some ranks, thus determine name here
@@ -127,14 +124,15 @@ def _ensure_http_server(self):
127124 if self ._http_server_handled :
128125 return
129126 self ._http_server_handled = True
130- _start_maybe_http_server (self , timeout_seconds = self ._collective_timeout )
127+ _start_maybe_http_server (self , timeout_seconds = self ._config . collective_timeout )
131128
132129 def _ensure_partial_name (self ):
133- if self ._partial_name is None :
134- self . _partial_name = _get_partial_name (
135- timeout_seconds = self ._collective_timeout
130+ if self ._config . partial_name is None :
131+ name = _get_partial_name (
132+ timeout_seconds = self ._config . collective_timeout
136133 )
137- print (f"[Dumper] Choose partial_name={ self ._partial_name } " )
134+ self ._config = replace (self ._config , partial_name = name )
135+ print (f"[Dumper] Choose partial_name={ name } " )
138136
139137 def set_ctx (self , ** kwargs ):
140138 """
@@ -172,9 +170,9 @@ def dump(self, name: str, value, save: bool = True, **kwargs) -> None:
172170 value = value ,
173171 extra_kwargs = kwargs ,
174172 save = save ,
175- enable_value = self ._enable_value ,
173+ enable_value = self ._config . enable_value ,
176174 enable_curr_grad = False ,
177- enable_future_grad = self ._enable_grad ,
175+ enable_future_grad = self ._config . enable_grad ,
178176 value_tag = "Dumper.Value" ,
179177 grad_tag = "Dumper.Grad" ,
180178 )
@@ -192,8 +190,8 @@ def dump_model(
192190 value = param ,
193191 extra_kwargs = kwargs ,
194192 save = save ,
195- enable_value = self ._enable_model_value ,
196- enable_curr_grad = self ._enable_model_grad ,
193+ enable_value = self ._config . enable_model_value ,
194+ enable_curr_grad = self ._config . enable_model_grad ,
197195 enable_future_grad = False ,
198196 value_tag = "Dumper.ParamValue" ,
199197 grad_tag = "Dumper.ParamGrad" ,
@@ -214,9 +212,9 @@ def _dump_inner(
214212 ) -> None :
215213 self ._ensure_http_server ()
216214
217- if not (self ._enable and (self ._override_enable is not False )):
215+ if not (self ._config . enable and (self ._override_enable is not False )):
218216 return
219- if (f := self ._filter ) is not None and re .search (f , name ) is None :
217+ if (f := self ._config . filter ) is not None and re .search (f , name ) is None :
220218 return
221219 if not (enable_value or enable_curr_grad or enable_future_grad ):
222220 return
@@ -306,9 +304,9 @@ def _dump_single(
306304 ** self ._global_ctx ,
307305 )
308306 full_filename = "___" .join (f"{ k } ={ v } " for k , v in full_kwargs .items ()) + ".pt"
309- path = self ._base_dir / f"sglang_dump_{ self ._partial_name } " / full_filename
307+ path = self ._config . base_dir / f"sglang_dump_{ self ._config . partial_name } " / full_filename
310308
311- if self ._enable_output_console :
309+ if self ._config . enable_output_console :
312310 print (
313311 f"[{ tag } ] [{ rank } , { time .time ()} ] { path } "
314312 f"type={ type (value )} "
@@ -320,7 +318,7 @@ def _dump_single(
320318 )
321319
322320 capturing = self ._captured_output_data is not None
323- if save and (self ._enable_output_file or capturing ):
321+ if save and (self ._config . enable_output_file or capturing ):
324322 output_data = {
325323 "value" : value .data if isinstance (value , torch .nn .Parameter ) else value ,
326324 "meta" : dict (** full_kwargs , ** self ._static_meta ),
@@ -330,9 +328,9 @@ def _dump_single(
330328 output_data ["value" ] = _deepcopy_or_clone (output_data ["value" ])
331329 self ._captured_output_data [name ] = output_data
332330 else :
333- if self ._pending_cleanup :
334- self ._pending_cleanup = False
335- _cleanup_old_dumps (self ._base_dir )
331+ if not self ._cleanup_previous_handled :
332+ self ._cleanup_previous_handled = True
333+ _cleanup_old_dumps (self ._config . base_dir )
336334
337335 path .parent .mkdir (parents = True , exist_ok = True )
338336 _torch_save (output_data , str (path ))
@@ -603,7 +601,7 @@ def __init__(self, dumper):
603601
604602 def set_enable (self , enable : bool ):
605603 print (f"[DumperRpcHandler] set_enable { enable = } " )
606- self ._dumper ._enable = enable
604+ self ._dumper ._config = replace ( self . _dumper . _config , enable = enable )
607605
608606
609607# -------------------------------------- zmq rpc ------------------------------------------
0 commit comments