From efc53b6cb656b63e73dd37c307bef513f8907a46 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Sat, 5 Aug 2023 22:52:39 +1200 Subject: [PATCH 1/6] Update model.py Signed-off-by: Nathan McDougall --- pandera/api/pandas/model.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pandera/api/pandas/model.py b/pandera/api/pandas/model.py index ab22f6893..96b9b7755 100644 --- a/pandera/api/pandas/model.py +++ b/pandera/api/pandas/model.py @@ -124,8 +124,18 @@ def _convert_extras_to_checks(extras: Dict[str, Any]) -> List[Check]: return checks +class MetaDataFrameModel(MetaModel): + """A metaclass for DataFrameModel to provide iter support.""" + + def to_schema(cls) -> DataFrameSchema: + """Create :class:`~pandera.DataFrameSchema` from the class.""" + raise NotImplementedError + + def __iter__(cls) -> Iterable[str]: + """Iterate over the fields of the schema""" + return iter(cls.to_schema().columns) -class DataFrameModel(BaseModel): +class DataFrameModel(BaseModel, metaclass=MetaDataFrameModel): """Definition of a :class:`~pandera.api.pandas.container.DataFrameSchema`. *new in 0.5.0* From 0e36018ab439545f000b85a3d29f9d9043bd9298 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Sat, 5 Aug 2023 23:28:43 +1200 Subject: [PATCH 2/6] Add missing import, and a docstring Signed-off-by: Nathan McDougall --- pandera/api/pandas/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pandera/api/pandas/model.py b/pandera/api/pandas/model.py index 96b9b7755..64f189f7f 100644 --- a/pandera/api/pandas/model.py +++ b/pandera/api/pandas/model.py @@ -23,7 +23,7 @@ import pandas as pd -from pandera.api.base.model import BaseModel +from pandera.api.base.model import BaseModel, MetaModel from pandera.api.checks import Check from pandera.api.pandas.components import Column, Index, MultiIndex from pandera.api.pandas.container import DataFrameSchema @@ -126,14 +126,16 @@ def _convert_extras_to_checks(extras: Dict[str, Any]) -> List[Check]: class MetaDataFrameModel(MetaModel): """A metaclass for DataFrameModel to provide iter support.""" - + def to_schema(cls) -> DataFrameSchema: - """Create :class:`~pandera.DataFrameSchema` from the class.""" + """Create :class:`~pandera.DataFrameSchema` from the class.""" raise NotImplementedError def __iter__(cls) -> Iterable[str]: """Iterate over the fields of the schema""" - return iter(cls.to_schema().columns) + # False positive in metaclass context; pylint: disable=no-value-for-parameter + schema = cls.to_schema() + return iter(schema.columns) class DataFrameModel(BaseModel, metaclass=MetaDataFrameModel): """Definition of a :class:`~pandera.api.pandas.container.DataFrameSchema`. From a05d155749c19470bb088de86d24c5bdf213d610 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Sat, 5 Aug 2023 23:29:08 +1200 Subject: [PATCH 3/6] Add tests Signed-off-by: Nathan McDougall --- tests/core/test_model.py | 70 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 9b086915a..c7e4ef9e2 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1426,3 +1426,73 @@ class Config: } } assert PanderaSchema.get_metadata() == expected + +def test_iter_fieldnames(): + """ + Test we can iterate over the `DataFrameModel` to get the field names. + """ + + class PanderaSchema(pa.DataFrameModel): + id: Series[int] + product_name: Series[str] + price: Series[float] + + class Config: + name = "product_info" + strict = True + coerce = True + + expected = ["id", "product_name", "price"] + assert list(PanderaSchema) == expected + + +def test_iter_fieldnames_inheritance(): + """ + Test iterating over the fieldnames respects the order of inheritance. + """ + + class PanderaSchema1(pa.DataFrameModel): + id: Series[int] + product_name: Series[str] + price: Series[float] + + # Note: order of definition differs from order of inheritance + class PanderaSchema3(pa.DataFrameModel): + quality: Series[str] + + class PanderaSchema2(pa.DataFrameModel): + quantity: Series[int] + + class CombinedSchema(PanderaSchema1, PanderaSchema2, PanderaSchema3): + pass + + expected = ["id", "product_name", "price", "quantity", "quality"] + assert list(CombinedSchema) == expected + +def test_iter_fieldnames_df_index(): + """ + Test iterating over the fieldnames as a way to index all columns of a dataframe. + """ + + class PanderaSchema(pa.DataFrameModel): + id: Series[int] + product_name: Series[str] + price: Series[float] + + class Config: + order = True + + df = pd.DataFrame( + { + PanderaSchema.price: [1.0, 2.0, 3.0], + PanderaSchema.id: [1, 2, 3], + PanderaSchema.product_name: ["A", "B", "C"], + } + ) + assert df.columns == [ + PanderaSchema.price, PanderaSchema.id, PanderaSchema.product_name + ] + df = df[list(PanderaSchema)].copy() + assert df.columns == [ + PanderaSchema.id, PanderaSchema.product_name, PanderaSchema.price + ] From b117492e5345cd1d7c627128ec4ad34930e057a4 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Sat, 5 Aug 2023 23:50:56 +1200 Subject: [PATCH 4/6] Black formatting use class instead of hard-coded strings Signed-off-by: Nathan McDougall --- pandera/api/pandas/model.py | 2 ++ tests/core/test_model.py | 29 ++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/pandera/api/pandas/model.py b/pandera/api/pandas/model.py index 64f189f7f..d5ff0e6b9 100644 --- a/pandera/api/pandas/model.py +++ b/pandera/api/pandas/model.py @@ -124,6 +124,7 @@ def _convert_extras_to_checks(extras: Dict[str, Any]) -> List[Check]: return checks + class MetaDataFrameModel(MetaModel): """A metaclass for DataFrameModel to provide iter support.""" @@ -137,6 +138,7 @@ def __iter__(cls) -> Iterable[str]: schema = cls.to_schema() return iter(schema.columns) + class DataFrameModel(BaseModel, metaclass=MetaDataFrameModel): """Definition of a :class:`~pandera.api.pandas.container.DataFrameSchema`. diff --git a/tests/core/test_model.py b/tests/core/test_model.py index c7e4ef9e2..737a74eeb 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1427,6 +1427,7 @@ class Config: } assert PanderaSchema.get_metadata() == expected + def test_iter_fieldnames(): """ Test we can iterate over the `DataFrameModel` to get the field names. @@ -1437,12 +1438,11 @@ class PanderaSchema(pa.DataFrameModel): product_name: Series[str] price: Series[float] - class Config: - name = "product_info" - strict = True - coerce = True - - expected = ["id", "product_name", "price"] + expected = [ + PanderaSchema.id, + PanderaSchema.product_name, + PanderaSchema.price, + ] assert list(PanderaSchema) == expected @@ -1466,9 +1466,16 @@ class PanderaSchema2(pa.DataFrameModel): class CombinedSchema(PanderaSchema1, PanderaSchema2, PanderaSchema3): pass - expected = ["id", "product_name", "price", "quantity", "quality"] + expected = [ + PanderaSchema1.id, + PanderaSchema1.product_name, + PanderaSchema1.price, + PanderaSchema2.quantity, + PanderaSchema3.quality, + ] assert list(CombinedSchema) == expected + def test_iter_fieldnames_df_index(): """ Test iterating over the fieldnames as a way to index all columns of a dataframe. @@ -1490,9 +1497,13 @@ class Config: } ) assert df.columns == [ - PanderaSchema.price, PanderaSchema.id, PanderaSchema.product_name + PanderaSchema.price, + PanderaSchema.id, + PanderaSchema.product_name, ] df = df[list(PanderaSchema)].copy() assert df.columns == [ - PanderaSchema.id, PanderaSchema.product_name, PanderaSchema.price + PanderaSchema.id, + PanderaSchema.product_name, + PanderaSchema.price, ] From 98702dd4b03f91e117293d1b61ec20df106d05ff Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Sun, 6 Aug 2023 01:09:53 +1200 Subject: [PATCH 5/6] Add validation to test Signed-off-by: Nathan McDougall --- tests/core/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 737a74eeb..b0fa2df93 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1507,3 +1507,4 @@ class Config: PanderaSchema.product_name, PanderaSchema.price, ] + PanderaSchema.validate(df) From 9036b9743a7d0d37bb23c9f61779ee430948a276 Mon Sep 17 00:00:00 2001 From: Nathan McDougall Date: Thu, 17 Aug 2023 08:54:45 +1200 Subject: [PATCH 6/6] Maintain order of fields in order of inheritance Signed-off-by: Nathan McDougall --- pandera/api/pandas/model.py | 46 +++++++++++-------------------------- 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/pandera/api/pandas/model.py b/pandera/api/pandas/model.py index d5ff0e6b9..9d5ce33ed 100644 --- a/pandera/api/pandas/model.py +++ b/pandera/api/pandas/model.py @@ -165,17 +165,13 @@ class DataFrameModel(BaseModel, metaclass=MetaDataFrameModel): @docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__) def __new__(cls, *args, **kwargs) -> DataFrameBase[TDataFrameModel]: # type: ignore [misc] """%(validate_doc)s""" - return cast( - DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs) - ) + return cast(DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs)) def __init_subclass__(cls, **kwargs): """Ensure :class:`~pandera.api.pandas.model_components.FieldInfo` instances.""" if "Config" in cls.__dict__: cls.Config.name = ( - cls.Config.name - if hasattr(cls.Config, "name") - else cls.__name__ + cls.Config.name if hasattr(cls.Config, "name") else cls.__name__ ) else: cls.Config = type("Config", (BaseConfig,), {"name": cls.__name__}) @@ -215,9 +211,7 @@ def __class_getitem__( Type[TDataFrameModel], GENERIC_SCHEMA_CACHE[(cls, params)] ) - param_dict: Dict[TypeVar, Type[Any]] = dict( - zip(__parameters__, params) - ) + param_dict: Dict[TypeVar, Type[Any]] = dict(zip(__parameters__, params)) extra: Dict[str, Any] = {"__annotations__": {}} for field, (annot_info, field_info) in cls._collect_fields().items(): if isinstance(annot_info.arg, TypeVar): @@ -228,9 +222,7 @@ def __class_getitem__( extra["__annotations__"][field] = raw_annot extra[field] = copy.deepcopy(field_info) - parameterized_name = ( - f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]" - ) + parameterized_name = f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]" parameterized_cls = type(parameterized_name, (cls,), extra) GENERIC_SCHEMA_CACHE[(cls, params)] = parameterized_cls return parameterized_cls @@ -337,9 +329,7 @@ def example( **kwargs, ) -> DataFrameBase[TDataFrameModel]: """%(example_doc)s""" - return cast( - DataFrameBase[TDataFrameModel], cls.to_schema().example(**kwargs) - ) + return cast(DataFrameBase[TDataFrameModel], cls.to_schema().example(**kwargs)) @classmethod def _build_columns_index( # pylint:disable=too-many-locals @@ -349,8 +339,7 @@ def _build_columns_index( # pylint:disable=too-many-locals **multiindex_kwargs: Any, ) -> Tuple[Dict[str, Column], Optional[Union[Index, MultiIndex]],]: index_count = sum( - annotation.origin in INDEX_TYPES - for annotation, _ in fields.values() + annotation.origin in INDEX_TYPES for annotation, _ in fields.values() ) columns: Dict[str, Column] = {} @@ -399,9 +388,7 @@ def _build_columns_index( # pylint:disable=too-many-locals or annotation.raw_annotation in INDEX_TYPES ): if annotation.optional: - raise SchemaInitError( - f"Index '{field_name}' cannot be Optional." - ) + raise SchemaInitError(f"Index '{field_name}' cannot be Optional.") if check_name is False or ( # default single index @@ -456,7 +443,7 @@ def _collect_fields(cls) -> Dict[str, Tuple[AnnotationInfo, FieldInfo]]: raise SchemaInitError(f"Found missing annotations: {missing}") fields = {} - for field_name, annotation in annotations.items(): + for field_name, annotation in reversed(annotations.items()): field = attrs[field_name] # __init_subclass__ guarantees existence if not isinstance(field, FieldInfo): raise SchemaInitError( @@ -464,7 +451,7 @@ def _collect_fields(cls) -> Dict[str, Tuple[AnnotationInfo, FieldInfo]]: + f"not a '{type(field)}.'" ) fields[field.name] = (AnnotationInfo(annotation), field) - return fields + return dict(reversed(fields.items())) @classmethod def _collect_config_and_extras( @@ -472,18 +459,14 @@ def _collect_config_and_extras( ) -> Tuple[Type[BaseConfig], Dict[str, Any]]: """Collect config options from bases, splitting off unknown options.""" bases = inspect.getmro(cls)[:-1] - bases = tuple( - base for base in bases if issubclass(base, DataFrameModel) - ) + bases = tuple(base for base in bases if issubclass(base, DataFrameModel)) root_model, *models = reversed(bases) options, extras = _extract_config_options_and_extras(root_model.Config) for model in models: config = getattr(model, _CONFIG_KEY, {}) - base_options, base_extras = _extract_config_options_and_extras( - config - ) + base_options, base_extras = _extract_config_options_and_extras(config) options.update(base_options) extras.update(base_extras) @@ -496,9 +479,7 @@ def _collect_check_infos(cls, key: str) -> List[CheckInfo]: walk the inheritance tree. """ bases = inspect.getmro(cls)[:-2] # bases -> DataFrameModel -> object - bases = tuple( - base for base in bases if issubclass(base, DataFrameModel) - ) + bases = tuple(base for base in bases if issubclass(base, DataFrameModel)) method_names = set() check_infos = [] @@ -662,7 +643,6 @@ def _field_json_schema(field): "title": dataframe_schema.name or "pandera.DataFrameSchema", "type": "object", "properties": { - field["name"]: _field_json_schema(field) - for field in table_schema["fields"] + field["name"]: _field_json_schema(field) for field in table_schema["fields"] }, }