11"""
2- Patch for dynamo.trtllm to add torch profiler support.
2+ Patch for dynamo.trtllm v0.8.1 to add torch profiler support.
33
4- Wraps inference with torch.profiler.profile() and writes chrome trace
5- JSON to SGLANG_TORCH_PROFILER_DIR.
4+ Wraps HandlerBase.generate_locally() with torch.profiler.profile() and
5+ writes chrome trace JSON to SGLANG_TORCH_PROFILER_DIR.
6+
7+ Target: dynamo.trtllm.request_handlers.handler_base.HandlerBase.generate_locally
8+ - Async generator: async def generate_locally(self, request, context, embeddings=None)
9+ - Called by PrefillHandler.generate(), DecodeHandler.generate(), AggregatedHandler.generate()
10+ - Each call = one request; "step" here counts requests (one per generate_locally call), not engine iterations
611
712Environment variables:
813 PROFILING_MODE - "prefill" or "decode" (from srtctl)
1116 SGLANG_TORCH_PROFILER_DIR - output dir for traces
1217
1318Applied by appending import+call to handler_base.py via setup script.
14- Falls back to a post-import hook if used standalone.
1519"""
1620
17- import importlib
18- import logging
1921import os
2022import sys
2123
22- logger = logging .getLogger ("trtllm-profiling-patch" )
23-
2424
2525def _apply_patch ():
26- """Actually monkey-patch HandlerBase.generate with profiler wrapping."""
26+ """Monkey-patch HandlerBase.generate_locally with profiler wrapping."""
2727 mode = os .environ .get ("PROFILING_MODE" , "" )
2828 start_step = int (os .environ .get (f"PROFILE_{ mode .upper ()} _START_STEP" ,
2929 os .environ .get ("PROFILE_START_STEP" , "5" )))
@@ -34,9 +34,16 @@ def _apply_patch():
3434 import torch .profiler
3535 from dynamo .trtllm .request_handlers .handler_base import HandlerBase
3636
37+ if not hasattr (HandlerBase , "generate_locally" ):
38+ methods = [m for m in dir (HandlerBase )
39+ if not m .startswith ('_' ) and callable (getattr (HandlerBase , m , None ))]
40+ print (f"[trtllm-patch] ERROR: HandlerBase has no generate_locally. "
41+ f"Available: { methods } " , file = sys .stderr , flush = True )
42+ return
43+
3744 os .makedirs (output_dir , exist_ok = True )
3845
39- _orig_generate = HandlerBase .generate
46+ _orig_generate_locally = HandlerBase .generate_locally
4047 _state = {"step" : 0 , "started" : False , "stopped" : False }
4148
4249 _state ["profiler" ] = torch .profiler .profile (
@@ -49,76 +56,45 @@ def _apply_patch():
4956 on_trace_ready = torch .profiler .tensorboard_trace_handler (output_dir ),
5057 )
5158
52- async def _patched_generate (self , request , * args , ** kwargs ):
59+ async def _patched_generate_locally (self , request , context , embeddings = None ):
5360 _state ["step" ] += 1
5461 step = _state ["step" ]
5562
5663 if step == start_step and not _state ["started" ]:
5764 _state ["profiler" ].__enter__ ()
5865 _state ["started" ] = True
59- print (f"[trtllm-patch] Step { step } : profiler started" , file = sys .stderr , flush = True )
66+ print (f"[trtllm-patch] Step { step } : profiler started" ,
67+ file = sys .stderr , flush = True )
6068
61- result = _orig_generate (self , request , * args , ** kwargs )
62- if hasattr (result , '__aiter__' ):
63- async for chunk in result :
64- yield chunk
65- else :
66- yield await result
69+ async for chunk in _orig_generate_locally (self , request , context , embeddings ):
70+ yield chunk
6771
6872 if step == stop_step and not _state ["stopped" ]:
6973 import torch .cuda
7074 torch .cuda .synchronize ()
7175 _state ["profiler" ].__exit__ (None , None , None )
7276 _state ["stopped" ] = True
73- print (f"[trtllm-patch] Step { step } : profiler stopped, traces in { output_dir } " , file = sys .stderr , flush = True )
77+ print (f"[trtllm-patch] Step { step } : profiler stopped, "
78+ f"traces in { output_dir } " , file = sys .stderr , flush = True )
7479
75- HandlerBase .generate = _patched_generate
76- print (f"[trtllm-patch] Patched HandlerBase.generate (steps { start_step } -{ stop_step } , output={ output_dir } )" , file = sys .stderr , flush = True )
80+ HandlerBase .generate_locally = _patched_generate_locally
81+ print (f"[trtllm-patch] Patched HandlerBase.generate_locally "
82+ f"(steps { start_step } -{ stop_step } , output={ output_dir } )" ,
83+ file = sys .stderr , flush = True )
7784
7885
7986def patch ():
80- """Patch HandlerBase.generate with profiler wrapping .
87+ """Apply the profiling patch to HandlerBase .
8188
82- When this file is appended to handler_base.py, HandlerBase is already
83- defined in the current module scope — so we apply the patch directly.
84- If called standalone before the module is imported, we install a
85- post-import hook as a fallback.
89+ When appended to handler_base.py, HandlerBase is already defined,
90+ so the import succeeds and we patch immediately.
8691 """
8792 mode = os .environ .get ("PROFILING_MODE" , "" )
8893 if not mode :
8994 return
9095
91- # If HandlerBase is already importable (e.g. this code was appended to
92- # handler_base.py, or the module was already imported), patch immediately.
9396 try :
94- from dynamo .trtllm .request_handlers .handler_base import HandlerBase # noqa: F401
9597 _apply_patch ()
96- return
97- except ImportError :
98- pass
9998 except Exception as e :
100- print (f"[trtllm-patch] Direct patch failed: { e } " , file = sys .stderr , flush = True )
101- return
102-
103- # Fallback: install a meta path finder for deferred patching
104- class _PatchFinder :
105- _patched = False
106-
107- def find_module (self , fullname , path = None ):
108- if fullname == "dynamo.trtllm.request_handlers.handler_base" and not self ._patched :
109- return self
110- return None
111-
112- def load_module (self , fullname ):
113- self ._patched = True
114- if self in sys .meta_path :
115- sys .meta_path .remove (self )
116- mod = importlib .import_module (fullname )
117- try :
118- _apply_patch ()
119- except Exception as e :
120- print (f"[trtllm-patch] Failed: { e } " , file = sys .stderr , flush = True )
121- return mod
122-
123- sys .meta_path .insert (0 , _PatchFinder ())
124- print (f"[trtllm-patch] Installed post-import hook for HandlerBase" , file = sys .stderr , flush = True )
99+ print (f"[trtllm-patch] Failed to apply patch: { e } " ,
100+ file = sys .stderr , flush = True )
0 commit comments