Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/llmcompressor/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler:


def data_collator_with_truncation(
features: list[dict[str, Any]], return_tensors: str = "pt"
) -> dict[str, Any]:
features: list[dict], return_tensors: str = "pt"
) -> dict:
for key in ("input_ids", "labels", "attention_mask"):
if any(key not in feature for feature in features):
continue
Expand Down
31 changes: 20 additions & 11 deletions src/llmcompressor/pipelines/sequential/ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
:param module: module whose forward method should be replaced
:param ignore: explicit list of function names to wrap
"""
# check forward method is implemented
if module.forward.__name__ == "_forward_unimplemented":
raise ValueError(
"Cannot calibrate model which does not implement `forward` method. Please "
"either implement a forward method on the model, or pass a submodule to "
"`oneshot`. For example, `oneshot(model.thinker, ...)`"
)

# get source code of module forward
source = inspect.getsource(module.forward)
source = textwrap.dedent(source)
Expand All @@ -64,15 +72,16 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
# compile new forward function from autowrapped code
filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
code = compile(source, filename=filename, mode="exec")
exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap
with append_autowrap_source_on_fail():
exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap

# enable better tracebacks if autowrapped code fails
linecache.cache[filename] = (
len(source),
None,
[line + "\n" for line in source.splitlines()],
filename,
)
# enable better tracebacks if autowrapped code fails
linecache.cache[filename] = (
len(source),
None,
[line + "\n" for line in source.splitlines()],
filename,
)

# patch forward with autowrapped forward
new_forward = namespace["forward"].__get__(module)
Expand All @@ -99,9 +108,9 @@ def append_autowrap_source_on_fail():
for i, line in enumerate(source_lines)
]

message = f"{exception}\n\n"
message += f"\n--- {frame.filename}:{lineno} ---\n"
message = f"--- {frame.filename}:{lineno} ---\n"
message += "".join(source_lines)
raise RuntimeError(message) from exception
message += f"\n\n{exception}"
raise RuntimeError(message) from exc_tb

raise exception
Loading