Float8 with FSDP and DelayedScaling: 'WeightWithDelayedFloat8CastTensor' object has no attribute '_tensor'. #1605

Open
@fmo-mt

Description

I tried:

from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# Use delayed scaling for the weight cast.
weight_config = CastConfig(scaling_type=ScalingType.DELAYED)
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
    force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather,  # same as enable_fsdp_float8_all_gather
    cast_config_weight=weight_config,
)
convert_to_float8_training(model, config=config)

and then:

model = FSDP(model, **kwargs)

PyTorch then throws: AttributeError: 'WeightWithDelayedFloat8CastTensor' object has no attribute '_tensor'. Did you mean: 'new_tensor'?
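To narrow down where the attribute lookup fails, it may help to check which parameters actually carry a `_tensor` attribute after the float8 conversion and before the FSDP wrap. The helper below is only a diagnostic sketch: `report_missing_attr` is a hypothetical name, not part of any library, and it assumes nothing about the model beyond a `named_parameters()` method like the one `torch.nn.Module` provides.

```python
def report_missing_attr(model, attr="_tensor"):
    """Return (name, type name, has_attr) for every parameter of `model`.

    Hypothetical diagnostic helper. `model` only needs a
    `named_parameters()` method that yields (name, parameter) pairs.
    """
    rows = []
    for name, param in model.named_parameters():
        # Record the runtime type name and whether the attribute is present.
        rows.append((name, type(param).__name__, hasattr(param, attr)))
    return rows
```

Calling `report_missing_attr(model)` right after `convert_to_float8_training(model, config=config)` and printing the rows where the last field is `False` would show which weights lack the attribute named in the error.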
