
Ensuring correct device dispatch after module replacement #1308

@yiliu30

Description


When a model is loaded across multiple devices with `device_map="auto"`, the Accelerate library attaches hooks that automatically move each module's inputs to the device that module is assigned to.
For certain models (for example, gpt-oss), we replace the original MoE modules during initialization. The newly inserted modules, however, do not carry the required Accelerate hooks, so their input tensors may not be moved to the correct devices.
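To see which modules are affected, one option is to look for the `_hf_hook` attribute that `accelerate.hooks.add_hook_to_module` sets on every module it wraps. The helper below (`modules_without_hooks` is a hypothetical name) is a minimal sketch that lists modules owning parameters directly but carrying no hook of their own. Appearing on this list is not necessarily a bug, since a hooked parent may already move the inputs, but the modules of a freshly replaced MoE block would be expected to show up here.

```python
import torch.nn as nn


def modules_without_hooks(model: nn.Module) -> list:
    """Names of modules that directly own parameters but have no `_hf_hook`.

    Hypothetical diagnostic: `accelerate.hooks.add_hook_to_module` stores the
    hook as `module._hf_hook`, so a missing attribute means no hook on this
    module itself moves its inputs.
    """
    names = []
    for name, module in model.named_modules():
        # Only look at parameters registered on this module, not its children.
        owns_params = next(module.parameters(recurse=False), None) is not None
        if owns_params and getattr(module, "_hf_hook", None) is None:
            names.append(name)
    return names
```

On a model loaded with `device_map="auto"`, calling this before and after the MoE replacement would show which newly inserted modules lost their hooks.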
We need to determine at which point during model loading Accelerate attaches these hooks. Based on that, we should re-dispatch the replaced modules (or re-attach their hooks) where necessary to ensure correct multi-device execution.

If needed, we can re-dispatch the model after the replacement with `dispatch_model`:

```python
from accelerate.big_modeling import dispatch_model
```
