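I tried the following configuration, using delayed scaling for the weight cast:

```python
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# Use delayed scaling (instead of the default dynamic scaling) for weights
weight_config = CastConfig(
    scaling_type=ScalingType.DELAYED,
)
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
    force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather,  # same value as enable_fsdp_float8_all_gather
    cast_config_weight=weight_config,
)
# Swap nn.Linear modules in `model` for float8 linears, in place
convert_to_float8_training(model, config=config)
```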
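and then wrapped the converted model with FSDP:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(model, **kwargs)
```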
PyTorch throws:

```
AttributeError: 'WeightWithDelayedFloat8CastTensor' object has no attribute '_tensor'. Did you mean: 'new_tensor'?
```