diff --git a/src/dremioai/api/dremio/catalog.py b/src/dremioai/api/dremio/catalog.py index 68b5d82..11af23e 100644 --- a/src/dremioai/api/dremio/catalog.py +++ b/src/dremioai/api/dremio/catalog.py @@ -48,10 +48,10 @@ class DatasetSubType(UStrEnum): DIRECT = auto() -def subset_validator(elem: UStrEnum, values: List[UStrEnum]) -> UStrEnum: +def subset_validator(elem: UStrEnum, values: UStrEnum | List[UStrEnum]) -> UStrEnum: if elem in values: return elem - raise ValidationError(f"{elem} not in {values}") + raise ValueError(f"{elem} not in {values}") class LineageBase(BaseModel): @@ -64,7 +64,7 @@ class LineageBase(BaseModel): class LineageSource(LineageBase): type: Annotated[ CatalogItemType, - AfterValidator(partial(subset_validator, values=[CatalogItemType.CONTAINER])), + AfterValidator(partial(subset_validator, values=CatalogItemType)), ] container_type: Annotated[ ContainerSubType, @@ -80,15 +80,12 @@ class LineageSource(LineageBase): class LineageParents(LineageBase): type: Annotated[ CatalogItemType, - AfterValidator(partial(subset_validator, values=[CatalogItemType.DATASET])), + AfterValidator(partial(subset_validator, values=CatalogItemType)), ] dataset_type: Annotated[ DatasetSubType, AfterValidator( - partial( - subset_validator, - values=[DatasetSubType.PROMOTED, DatasetSubType.VIRTUAL], - ) + partial(subset_validator, values=DatasetSubType), ), ] = Field(..., alias="datasetType") @@ -96,7 +93,7 @@ class LineageParents(LineageBase): class LineageChildren(LineageBase): type: Annotated[ CatalogItemType, - AfterValidator(partial(subset_validator, values=[CatalogItemType.DATASET])), + AfterValidator(partial(subset_validator, values=CatalogItemType)), ] dataset_type: Annotated[ DatasetSubType,