Skip to content

Commit fdb04a3

Browse files
committed
fix
1 parent 49ccd96 commit fdb04a3

File tree

2 files changed

+39
-58
lines changed

2 files changed

+39
-58
lines changed

benchmarks/benchmark_lib.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,13 @@ fi
481481
482482
echo "[profile-setup] Found handler at: $HANDLER"
483483
484+
# Guard: only append once (multiple MPI ranks run this script)
485+
if grep -q "InferenceX profiling patch" "$HANDLER" 2>/dev/null; then
486+
echo "[profile-setup] Patch already applied, skipping"
487+
exit 0
488+
fi
489+
484490
# Append the patch import and call to the end of handler_base.py
485-
# The patch installs a post-import hook that wraps generate() with torch.profiler
486491
cat >> "$HANDLER" <<'PATCH_APPEND'
487492
488493
# --- InferenceX profiling patch (auto-appended) ---

patches/trtllm-profiling.py

Lines changed: 33 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
"""
2-
Patch for dynamo.trtllm to add torch profiler support.
2+
Patch for dynamo.trtllm v0.8.1 to add torch profiler support.
33
4-
Wraps inference with torch.profiler.profile() and writes chrome trace
5-
JSON to SGLANG_TORCH_PROFILER_DIR.
4+
Wraps HandlerBase.generate_locally() with torch.profiler.profile() and
5+
writes chrome trace JSON to SGLANG_TORCH_PROFILER_DIR.
6+
7+
Target: dynamo.trtllm.request_handlers.handler_base.HandlerBase.generate_locally
8+
- Async generator: async def generate_locally(self, request, context, embeddings=None)
9+
- Called by PrefillHandler.generate(), DecodeHandler.generate(), AggregatedHandler.generate()
10+
- Each call = one request; "step" here counts requests processed
611
712
Environment variables:
813
PROFILING_MODE - "prefill" or "decode" (from srtctl)
@@ -11,19 +16,14 @@
1116
SGLANG_TORCH_PROFILER_DIR - output dir for traces
1217
1318
Applied by appending import+call to handler_base.py via setup script.
14-
Falls back to a post-import hook if used standalone.
1519
"""
1620

17-
import importlib
18-
import logging
1921
import os
2022
import sys
2123

22-
logger = logging.getLogger("trtllm-profiling-patch")
23-
2424

2525
def _apply_patch():
26-
"""Actually monkey-patch HandlerBase.generate with profiler wrapping."""
26+
"""Monkey-patch HandlerBase.generate_locally with profiler wrapping."""
2727
mode = os.environ.get("PROFILING_MODE", "")
2828
start_step = int(os.environ.get(f"PROFILE_{mode.upper()}_START_STEP",
2929
os.environ.get("PROFILE_START_STEP", "5")))
@@ -34,9 +34,16 @@ def _apply_patch():
3434
import torch.profiler
3535
from dynamo.trtllm.request_handlers.handler_base import HandlerBase
3636

37+
if not hasattr(HandlerBase, "generate_locally"):
38+
methods = [m for m in dir(HandlerBase)
39+
if not m.startswith('_') and callable(getattr(HandlerBase, m, None))]
40+
print(f"[trtllm-patch] ERROR: HandlerBase has no generate_locally. "
41+
f"Available: {methods}", file=sys.stderr, flush=True)
42+
return
43+
3744
os.makedirs(output_dir, exist_ok=True)
3845

39-
_orig_generate = HandlerBase.generate
46+
_orig_generate_locally = HandlerBase.generate_locally
4047
_state = {"step": 0, "started": False, "stopped": False}
4148

4249
_state["profiler"] = torch.profiler.profile(
@@ -49,76 +56,45 @@ def _apply_patch():
4956
on_trace_ready=torch.profiler.tensorboard_trace_handler(output_dir),
5057
)
5158

52-
async def _patched_generate(self, request, *args, **kwargs):
59+
async def _patched_generate_locally(self, request, context, embeddings=None):
5360
_state["step"] += 1
5461
step = _state["step"]
5562

5663
if step == start_step and not _state["started"]:
5764
_state["profiler"].__enter__()
5865
_state["started"] = True
59-
print(f"[trtllm-patch] Step {step}: profiler started", file=sys.stderr, flush=True)
66+
print(f"[trtllm-patch] Step {step}: profiler started",
67+
file=sys.stderr, flush=True)
6068

61-
result = _orig_generate(self, request, *args, **kwargs)
62-
if hasattr(result, '__aiter__'):
63-
async for chunk in result:
64-
yield chunk
65-
else:
66-
yield await result
69+
async for chunk in _orig_generate_locally(self, request, context, embeddings):
70+
yield chunk
6771

6872
if step == stop_step and not _state["stopped"]:
6973
import torch.cuda
7074
torch.cuda.synchronize()
7175
_state["profiler"].__exit__(None, None, None)
7276
_state["stopped"] = True
73-
print(f"[trtllm-patch] Step {step}: profiler stopped, traces in {output_dir}", file=sys.stderr, flush=True)
77+
print(f"[trtllm-patch] Step {step}: profiler stopped, "
78+
f"traces in {output_dir}", file=sys.stderr, flush=True)
7479

75-
HandlerBase.generate = _patched_generate
76-
print(f"[trtllm-patch] Patched HandlerBase.generate (steps {start_step}-{stop_step}, output={output_dir})", file=sys.stderr, flush=True)
80+
HandlerBase.generate_locally = _patched_generate_locally
81+
print(f"[trtllm-patch] Patched HandlerBase.generate_locally "
82+
f"(steps {start_step}-{stop_step}, output={output_dir})",
83+
file=sys.stderr, flush=True)
7784

7885

7986
def patch():
80-
"""Patch HandlerBase.generate with profiler wrapping.
87+
"""Apply the profiling patch to HandlerBase.
8188
82-
When this file is appended to handler_base.py, HandlerBase is already
83-
defined in the current module scope — so we apply the patch directly.
84-
If called standalone before the module is imported, we install a
85-
post-import hook as a fallback.
89+
When appended to handler_base.py, HandlerBase is already defined,
90+
so the import succeeds and we patch immediately.
8691
"""
8792
mode = os.environ.get("PROFILING_MODE", "")
8893
if not mode:
8994
return
9095

91-
# If HandlerBase is already importable (e.g. this code was appended to
92-
# handler_base.py, or the module was already imported), patch immediately.
9396
try:
94-
from dynamo.trtllm.request_handlers.handler_base import HandlerBase # noqa: F401
9597
_apply_patch()
96-
return
97-
except ImportError:
98-
pass
9998
except Exception as e:
100-
print(f"[trtllm-patch] Direct patch failed: {e}", file=sys.stderr, flush=True)
101-
return
102-
103-
# Fallback: install a meta path finder for deferred patching
104-
class _PatchFinder:
105-
_patched = False
106-
107-
def find_module(self, fullname, path=None):
108-
if fullname == "dynamo.trtllm.request_handlers.handler_base" and not self._patched:
109-
return self
110-
return None
111-
112-
def load_module(self, fullname):
113-
self._patched = True
114-
if self in sys.meta_path:
115-
sys.meta_path.remove(self)
116-
mod = importlib.import_module(fullname)
117-
try:
118-
_apply_patch()
119-
except Exception as e:
120-
print(f"[trtllm-patch] Failed: {e}", file=sys.stderr, flush=True)
121-
return mod
122-
123-
sys.meta_path.insert(0, _PatchFinder())
124-
print(f"[trtllm-patch] Installed post-import hook for HandlerBase", file=sys.stderr, flush=True)
99+
print(f"[trtllm-patch] Failed to apply patch: {e}",
100+
file=sys.stderr, flush=True)

0 commit comments

Comments
 (0)