
EfficientNet weight_class_name_fn() gives Efficientnet_B0_Weights instead of EfficientNet_B0_Weights and breaks test #293

@kmabeeTT

Description


This bug was introduced by #272 today: the function does not behave as its own comment describes.

    def weight_class_name_fn(name: str) -> str:
        # Handle efficientnet_b0 -> EfficientNet_B0_Weights
        # Split by underscore and capitalize each part, then join with underscore
        parts = name.split("_")
        # Capitalize first letter of each part and join with underscore
        capitalized_parts = [p.capitalize() for p in parts]
        return "_".join(capitalized_parts) + "_Weights"

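It should return EfficientNet_B0_Weights, not Efficientnet_B0_Weights. The cause is that str.capitalize() lower-cases every character after the first, so "efficientnet" becomes "Efficientnet" and the capital N that torchvision uses is lost.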
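A minimal check of the standard-library behavior behind the bug: because `str.capitalize()` only upper-cases the first character, no per-part capitalization can recover an interior capital such as the `N` in `EfficientNet`.

```python
# Reproduce the per-part capitalization step from weight_class_name_fn.
parts = "efficientnet_b0".split("_")
capitalized = [p.capitalize() for p in parts]
print(capitalized)  # ['Efficientnet', 'B0']

# capitalize() also lower-cases characters that were already upper-case.
print("EfficientNet".capitalize())  # Efficientnet
```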

I don't see a reasonable workaround other than special-casing this name as a quick fix for the tt-xla main breakage. I'm going to use the following as a temporary fix for the regression and leave this ticket open for a better fix, if we want one.

    # Workaround because this logic does not work as advertised.
    capitalized_parts = [p.replace("Efficientnet", "EfficientNet") for p in capitalized_parts]
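For reference, here is a self-contained version of the function with the temporary special case applied. It fixes `efficientnet_b0`, but other mixed-case class names from the error's list, such as `DenseNet121_Weights` and `ConvNeXt_Tiny_Weights`, would still be mangled if those models go through the same path.

```python
def weight_class_name_fn(name: str) -> str:
    # Same split/capitalize/join logic as the function quoted in this issue.
    parts = name.split("_")
    capitalized_parts = [p.capitalize() for p in parts]
    # Temporary special case: restore torchvision's "EfficientNet" spelling.
    capitalized_parts = [p.replace("Efficientnet", "EfficientNet") for p in capitalized_parts]
    return "_".join(capitalized_parts) + "_Weights"


print(weight_class_name_fn("efficientnet_b0"))  # EfficientNet_B0_Weights
# Not covered by the special case (torchvision spells these DenseNet / ConvNeXt):
print(weight_class_name_fn("densenet121"))      # Densenet121_Weights
print(weight_class_name_fn("convnext_tiny"))    # Convnext_Tiny_Weights
```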

The error was:

tests/runner/test_models.py::test_all_models_torch[efficientnet/pytorch-efficientnet_b0-single_device-full-inference]

>           raise AttributeError(
                f"Weight class '{weight_class_name}' not found in torchvision.models. "
                f"Model name: {self.model_name}. "
                f"Available weight classes: {sorted(available_weights)[:10]}..."
                f"(showing first 10 of {len(available_weights)}). "
                f"Please check your weight_class_name_fn implementation."
            )
E           AttributeError: Weight class 'Efficientnet_B0_Weights' not found in torchvision.models. Model name: efficientnet_b0. Available weight classes: ['AlexNet_Weights', 'ConvNeXt_Base_Weights', 'ConvNeXt_Large_Weights', 'ConvNeXt_Small_Weights', 'ConvNeXt_Tiny_Weights', 'DenseNet121_Weights', 'DenseNet161_Weights', 'DenseNet169_Weights', 'DenseNet201_Weights', 'EfficientNet_B0_Weights']...(showing first 10 of 80). Please check your weight_class_name_fn implementation.
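One possible longer-term fix, sketched as a hypothetical helper (`resolve_weight_class_name` is not part of the codebase): instead of reconstructing the casing, match `<model_name>_Weights` case-insensitively against the weight class names that actually exist. In the real code, `available` would presumably come from something like `[n for n in dir(torchvision.models) if n.endswith("_Weights")]`; that source is an assumption. If the pinned torchvision is 0.14 or newer, `torchvision.models.get_model_weights("efficientnet_b0")` may be simpler still, since it returns the weights enum class for a registered model name.

```python
from typing import Iterable


def resolve_weight_class_name(model_name: str, available: Iterable[str]) -> str:
    """Hypothetical helper: find '<model_name>_Weights' in `available`, ignoring case.

    `available` is assumed to be the weight class names exposed by torchvision.models.
    """
    wanted = f"{model_name}_weights".lower()
    matches = [n for n in available if n.lower() == wanted]
    if len(matches) != 1:
        # Zero or several matches: fail loudly rather than guess.
        raise AttributeError(
            f"Expected exactly one weight class matching {wanted!r}, found {matches}"
        )
    return matches[0]


# A subset of the names listed in the error message above.
AVAILABLE = [
    "AlexNet_Weights",
    "ConvNeXt_Tiny_Weights",
    "DenseNet121_Weights",
    "EfficientNet_B0_Weights",
]

print(resolve_weight_class_name("efficientnet_b0", AVAILABLE))  # EfficientNet_B0_Weights
print(resolve_weight_class_name("densenet121", AVAILABLE))      # DenseNet121_Weights
```

Raising when zero or several names match keeps an unexpected naming collision from silently selecting the wrong class.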
