diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py index 7bfe1ec8e67..c2ee9579f2e 100644 --- a/src/accelerate/utils/modeling.py +++ b/src/accelerate/utils/modeling.py @@ -1507,7 +1507,11 @@ def infer_auto_device_map( # -> split, we replace the module studied by its children + parameters if verbose: print(f"Splitting {name}.") - modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_children = ( + list(module.named_parameters(recurse=False)) + + list(module.named_buffers(recurse=False)) + + modules_children + ) modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat # Update the max layer size. max_layer_size, max_layer_names = get_max_layer_size(