Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/llmcompressor/pipelines/sequential/ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
:param module: module whose forward method should be replaced
:param ignore: explicit list of function names to wrap
"""
# check forward method is implemented
if module.forward.__name__ == "_forward_unimplemented":
raise ValueError(
"Cannot calibrate model which does not implement `forward` method. Please "
"either implement a forward method on the model, or pass a submodule to "
"`oneshot`. For example, `oneshot(model.thinker, ...)`"
)

# get source code of module forward
source = inspect.getsource(module.forward)
source = textwrap.dedent(source)
Expand All @@ -64,7 +72,8 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
# compile new forward function from autowrapped code
filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
code = compile(source, filename=filename, mode="exec")
exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap
with append_autowrap_source_on_fail():
exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap

# enable better tracebacks if autowrapped code fails
linecache.cache[filename] = (
Expand Down Expand Up @@ -99,9 +108,9 @@ def append_autowrap_source_on_fail():
for i, line in enumerate(source_lines)
]

message = f"{exception}\n\n"
message += f"\n--- {frame.filename}:{lineno} ---\n"
message = f"--- {frame.filename}:{lineno} ---\n"
message += "".join(source_lines)
message += f"\n\n{exception}"
raise RuntimeError(message) from exception

raise exception