diff --git a/pandera/api/pandas/model.py b/pandera/api/pandas/model.py index ab22f6893..9d5ce33ed 100644 --- a/pandera/api/pandas/model.py +++ b/pandera/api/pandas/model.py @@ -23,7 +23,7 @@ import pandas as pd -from pandera.api.base.model import BaseModel +from pandera.api.base.model import BaseModel, MetaModel from pandera.api.checks import Check from pandera.api.pandas.components import Column, Index, MultiIndex from pandera.api.pandas.container import DataFrameSchema @@ -125,7 +125,21 @@ def _convert_extras_to_checks(extras: Dict[str, Any]) -> List[Check]: return checks -class DataFrameModel(BaseModel): +class MetaDataFrameModel(MetaModel): + """A metaclass for DataFrameModel to provide iter support.""" + + def to_schema(cls) -> DataFrameSchema: + """Create :class:`~pandera.DataFrameSchema` from the class.""" + raise NotImplementedError + + def __iter__(cls) -> Iterable[str]: + """Iterate over the fields of the schema""" + # False positive in metaclass context; pylint: disable=no-value-for-parameter + schema = cls.to_schema() + return iter(schema.columns) + + +class DataFrameModel(BaseModel, metaclass=MetaDataFrameModel): """Definition of a :class:`~pandera.api.pandas.container.DataFrameSchema`. *new in 0.5.0* @@ -151,17 +165,13 @@ class DataFrameModel(BaseModel): @docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__) def __new__(cls, *args, **kwargs) -> DataFrameBase[TDataFrameModel]: # type: ignore [misc] """%(validate_doc)s""" - return cast( - DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs) - ) + return cast(DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs)) def __init_subclass__(cls, **kwargs): """Ensure :class:`~pandera.api.pandas.model_components.FieldInfo` instances.""" if "Config" in cls.__dict__: cls.Config.name = ( - cls.Config.name - if hasattr(cls.Config, "name") - else cls.__name__ + cls.Config.name if hasattr(cls.Config, "name") else cls.__name__ ) else: cls.Config = type("Config", (BaseConfig,), {"name": cls.__name__}) @@ -201,9 +211,7 @@ def __class_getitem__( Type[TDataFrameModel], GENERIC_SCHEMA_CACHE[(cls, params)] ) - param_dict: Dict[TypeVar, Type[Any]] = dict( - zip(__parameters__, params) - ) + param_dict: Dict[TypeVar, Type[Any]] = dict(zip(__parameters__, params)) extra: Dict[str, Any] = {"__annotations__": {}} for field, (annot_info, field_info) in cls._collect_fields().items(): if isinstance(annot_info.arg, TypeVar): @@ -214,9 +222,7 @@ def __class_getitem__( extra["__annotations__"][field] = raw_annot extra[field] = copy.deepcopy(field_info) - parameterized_name = ( - f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]" - ) + parameterized_name = f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]" parameterized_cls = type(parameterized_name, (cls,), extra) GENERIC_SCHEMA_CACHE[(cls, params)] = parameterized_cls return parameterized_cls @@ -323,9 +329,7 @@ def example( **kwargs, ) -> DataFrameBase[TDataFrameModel]: """%(example_doc)s""" - return cast( - DataFrameBase[TDataFrameModel], cls.to_schema().example(**kwargs) - ) + return cast(DataFrameBase[TDataFrameModel], cls.to_schema().example(**kwargs)) @classmethod def _build_columns_index( # pylint:disable=too-many-locals @@ -335,8 +339,7 @@ def _build_columns_index( # pylint:disable=too-many-locals **multiindex_kwargs: Any, ) -> Tuple[Dict[str, Column], Optional[Union[Index, MultiIndex]],]: index_count = sum( - annotation.origin in INDEX_TYPES - for annotation, _ in fields.values() + annotation.origin in INDEX_TYPES for annotation, _ in fields.values() ) columns: Dict[str, Column] = {} @@ -385,9 +388,7 @@ def _build_columns_index( # pylint:disable=too-many-locals or annotation.raw_annotation in INDEX_TYPES ): if annotation.optional: - raise SchemaInitError( - f"Index '{field_name}' cannot be Optional." - ) + raise SchemaInitError(f"Index '{field_name}' cannot be Optional.") if check_name is False or ( # default single index @@ -442,7 +443,7 @@ def _collect_fields(cls) -> Dict[str, Tuple[AnnotationInfo, FieldInfo]]: raise SchemaInitError(f"Found missing annotations: {missing}") fields = {} - for field_name, annotation in annotations.items(): + for field_name, annotation in reversed(annotations.items()): field = attrs[field_name] # __init_subclass__ guarantees existence if not isinstance(field, FieldInfo): raise SchemaInitError( @@ -450,7 +451,7 @@ def _collect_fields(cls) -> Dict[str, Tuple[AnnotationInfo, FieldInfo]]: + f"not a '{type(field)}.'" ) fields[field.name] = (AnnotationInfo(annotation), field) - return fields + return dict(reversed(fields.items())) @classmethod def _collect_config_and_extras( @@ -458,18 +459,14 @@ def _collect_config_and_extras( ) -> Tuple[Type[BaseConfig], Dict[str, Any]]: """Collect config options from bases, splitting off unknown options.""" bases = inspect.getmro(cls)[:-1] - bases = tuple( - base for base in bases if issubclass(base, DataFrameModel) - ) + bases = tuple(base for base in bases if issubclass(base, DataFrameModel)) root_model, *models = reversed(bases) options, extras = _extract_config_options_and_extras(root_model.Config) for model in models: config = getattr(model, _CONFIG_KEY, {}) - base_options, base_extras = _extract_config_options_and_extras( - config - ) + base_options, base_extras = _extract_config_options_and_extras(config) options.update(base_options) extras.update(base_extras) @@ -482,9 +479,7 @@ def _collect_check_infos(cls, key: str) -> List[CheckInfo]: walk the inheritance tree. """ bases = inspect.getmro(cls)[:-2] # bases -> DataFrameModel -> object - bases = tuple( - base for base in bases if issubclass(base, DataFrameModel) - ) + bases = tuple(base for base in bases if issubclass(base, DataFrameModel)) method_names = set() check_infos = [] @@ -648,7 +643,6 @@ def _field_json_schema(field): "title": dataframe_schema.name or "pandera.DataFrameSchema", "type": "object", "properties": { - field["name"]: _field_json_schema(field) - for field in table_schema["fields"] + field["name"]: _field_json_schema(field) for field in table_schema["fields"] }, } diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 9b086915a..b0fa2df93 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1426,3 +1426,85 @@ class Config: } } assert PanderaSchema.get_metadata() == expected + + +def test_iter_fieldnames(): + """ + Test we can iterate over the `DataFrameModel` to get the field names. + """ + + class PanderaSchema(pa.DataFrameModel): + id: Series[int] + product_name: Series[str] + price: Series[float] + + expected = [ + PanderaSchema.id, + PanderaSchema.product_name, + PanderaSchema.price, + ] + assert list(PanderaSchema) == expected + + +def test_iter_fieldnames_inheritance(): + """ + Test iterating over the fieldnames respects the order of inheritance. + """ + + class PanderaSchema1(pa.DataFrameModel): + id: Series[int] + product_name: Series[str] + price: Series[float] + + # Note: order of definition differs from order of inheritance + class PanderaSchema3(pa.DataFrameModel): + quality: Series[str] + + class PanderaSchema2(pa.DataFrameModel): + quantity: Series[int] + + class CombinedSchema(PanderaSchema1, PanderaSchema2, PanderaSchema3): + pass + + expected = [ + PanderaSchema1.id, + PanderaSchema1.product_name, + PanderaSchema1.price, + PanderaSchema2.quantity, + PanderaSchema3.quality, + ] + assert list(CombinedSchema) == expected + + +def test_iter_fieldnames_df_index(): + """ + Test iterating over the fieldnames as a way to index all columns of a dataframe. + """ + + class PanderaSchema(pa.DataFrameModel): + id: Series[int] + product_name: Series[str] + price: Series[float] + + class Config: + order = True + + df = pd.DataFrame( + { + PanderaSchema.price: [1.0, 2.0, 3.0], + PanderaSchema.id: [1, 2, 3], + PanderaSchema.product_name: ["A", "B", "C"], + } + ) + assert df.columns == [ + PanderaSchema.price, + PanderaSchema.id, + PanderaSchema.product_name, + ] + df = df[list(PanderaSchema)].copy() + assert df.columns == [ + PanderaSchema.id, + PanderaSchema.product_name, + PanderaSchema.price, + ] + PanderaSchema.validate(df)