[diffusion] refactor: separate runtime metadata from arch config#22678
This reverts commit 61c2845.
Code Review
This pull request refactors the multimodal model configuration system by migrating architectural fields from ArchConfig to ModelConfig and implementing a structured refresh mechanism for derived fields. It also introduces support for the LTX-2.3 model, including a new materialization overlay and specialized sampling parameter logic for two-stage pipelines. Review feedback highlights the need for a more robust check when detecting legacy post_init methods, recommends moving model-specific logic out of generic sampling utilities to prevent side effects, and suggests explicitly specifying UTF-8 encoding for file I/O in the new materialization script.
legacy_post_init = type(arch_config).__dict__.get("__post_init__")
if (
    legacy_post_init is not None
    and legacy_post_init is not ArchConfig.__post_init__
):
    legacy_post_init(arch_config)
The check type(arch_config).__dict__.get("__post_init__") is fragile because it only looks at the immediate class's dictionary. If a custom ArchConfig inherits its __post_init__ from a parent class (other than the base ArchConfig), it won't be detected here. A more robust check would be to compare the bound method's underlying function to the base implementation.
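The failure mode described in this comment can be reproduced with a small stand-alone sketch. The three classes below are hypothetical stand-ins (only the name `ArchConfig` and the `__post_init__` hook come from the PR), and the two helper functions are illustrative, not code from the repository. The MRO-based helper reads the hook from the class rather than the instance, which is one way to express the same comparison.

```python
from dataclasses import dataclass


@dataclass
class ArchConfig:
    """Stand-in for the real base class; only the hook matters here."""

    hidden_size: int = 8

    def __post_init__(self) -> None:
        self.source = "base"


@dataclass
class FamilyArchConfig(ArchConfig):
    """Hypothetical intermediate class that overrides the hook."""

    def __post_init__(self) -> None:
        self.source = "family"


@dataclass
class VariantArchConfig(FamilyArchConfig):
    """Hypothetical leaf class that inherits the override unchanged."""


def overrides_via_class_dict(cfg: ArchConfig) -> bool:
    # Mirrors the check under review: inspects only the immediate class.
    hook = type(cfg).__dict__.get("__post_init__")
    return hook is not None and hook is not ArchConfig.__post_init__


def overrides_via_mro(cfg: ArchConfig) -> bool:
    # Attribute lookup on the class walks the MRO, so inherited
    # overrides are found as well.
    return type(cfg).__post_init__ is not ArchConfig.__post_init__


variant = VariantArchConfig()
print(overrides_via_class_dict(variant))  # False: the override is missed
print(overrides_via_mro(variant))         # True: the override is detected
```

Both helpers agree for `ArchConfig` itself and for `FamilyArchConfig`; they differ only for classes that inherit an override from an intermediate parent.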
if arch_config.__post_init__.__func__ is not ArchConfig.__post_init__:
    arch_config.__post_init__()
else:
    arch_config.refresh_derived_fields()

if (
    getattr(server_args, "pipeline_class_name", None) == "LTX2TwoStagePipeline"
    and sampling_params.__class__.__name__ == "LTX23SamplingParams"
):
    if "height" not in user_kwargs and sampling_params.height is not None:
        sampling_params.height *= 2
    if "width" not in user_kwargs and sampling_params.width is not None:
        sampling_params.width *= 2
This block introduces model-specific logic into a generic utility method. It also modifies the sampling_params argument in-place, which is unexpected for a factory method and could lead to bugs if the default sampling_params instance is shared across requests. Consider moving this logic into a specialized method within LTX23SamplingParams or handling it during pipeline initialization.
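One possible shape for that refactor is sketched below, assuming the sampling parameters are dataclasses. The method name `with_pipeline_defaults` and the simplified field layout are hypothetical; only `LTX23SamplingParams`, `LTX2TwoStagePipeline`, and the `height`/`width` doubling come from the diff. The method returns a new object via `dataclasses.replace`, so a shared default instance is never modified.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional


@dataclass
class SamplingParams:
    """Simplified stand-in for the generic sampling parameters."""

    height: Optional[int] = None
    width: Optional[int] = None

    def with_pipeline_defaults(
        self, pipeline_class_name: Optional[str], user_kwargs: Dict[str, Any]
    ) -> "SamplingParams":
        # Hypothetical hook: the generic class applies no adjustments.
        return self


@dataclass
class LTX23SamplingParams(SamplingParams):
    """Model-specific subclass that owns the two-stage resolution rule."""

    def with_pipeline_defaults(
        self, pipeline_class_name: Optional[str], user_kwargs: Dict[str, Any]
    ) -> "LTX23SamplingParams":
        if pipeline_class_name != "LTX2TwoStagePipeline":
            return self
        updates: Dict[str, Any] = {}
        # Double only the dimensions the user did not set explicitly.
        if "height" not in user_kwargs and self.height is not None:
            updates["height"] = self.height * 2
        if "width" not in user_kwargs and self.width is not None:
            updates["width"] = self.width * 2
        # replace() builds a copy, leaving the original instance untouched.
        return replace(self, **updates)


shared_defaults = LTX23SamplingParams(height=512, width=768)
per_request = shared_defaults.with_pipeline_defaults("LTX2TwoStagePipeline", {})
print(per_request.height, per_request.width)
print(shared_defaults.height, shared_defaults.width)
```

With this layout the generic factory only calls the hook and needs no knowledge of class or pipeline names, and repeated requests cannot compound the doubling on a shared instance.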
def _load_json(path: str) -> dict:
    with open(path) as f:
        return json.load(f)


def _write_json(path: str, payload: dict) -> None:
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)
        f.write("\n")
Specify encoding="utf-8" when opening files for reading or writing to ensure consistent behavior across different operating systems and environments.
Suggested change:

def _load_json(path: str) -> dict:
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _write_json(path: str, payload: dict) -> None:
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
        f.write("\n")
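The `encoding="utf-8"` recommendation matters most on the read path. With its default `ensure_ascii=True`, `json.dump` escapes non-ASCII characters, so files written by `_write_json` are plain ASCII. A JSON file produced by another tool may contain raw UTF-8 bytes, though, and reading it with a locale-dependent default encoding can fail or mis-decode. The sketch below illustrates this; `round_trip_raw_utf8` is a hypothetical test helper, not part of the PR.

```python
import json
import os
import tempfile


def _load_json(path: str) -> dict:
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def round_trip_raw_utf8(payload: dict) -> dict:
    """Write payload as raw (unescaped) UTF-8 bytes, then load it back."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "config.json")
        with open(path, "wb") as f:
            # ensure_ascii=False keeps non-ASCII characters unescaped,
            # as an external tool or a text editor might.
            f.write(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
        return _load_json(path)


payload = {"name": "mod\u00e8le"}
# The default serialization is ASCII-only: non-ASCII becomes \uXXXX escapes.
print(json.dumps(payload))
# Reading raw UTF-8 with an explicit encoding is independent of the locale.
print(round_trip_raw_utf8(payload) == payload)
```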