diff --git a/changes/1169.added b/changes/1169.added new file mode 100644 index 000000000..b3ee33f68 --- /dev/null +++ b/changes/1169.added @@ -0,0 +1,2 @@ +Added `AttributeType` enum class. +Added class method to `NautobotModel` to get enum type of given attribute. \ No newline at end of file diff --git a/changes/1169.changed b/changes/1169.changed new file mode 100644 index 000000000..c216d88e5 --- /dev/null +++ b/changes/1169.changed @@ -0,0 +1 @@ +Changed `_handle_single_paramater` in `NautobotAdapter` to utilize `AttributeType` enum class. \ No newline at end of file diff --git a/changes/1169.fixed b/changes/1169.fixed new file mode 100644 index 000000000..448e891f8 --- /dev/null +++ b/changes/1169.fixed @@ -0,0 +1,2 @@ +Fixed bug in `DiffSyncModelUtilityMixin.get_attr_args()` for Annoted type hints wrapped in `Optional[]` tag. +Fix hashing issue on Custom Annotations. \ No newline at end of file diff --git a/nautobot_ssot/contrib/adapter.py b/nautobot_ssot/contrib/adapter.py index d8cba1633..d0598d13d 100644 --- a/nautobot_ssot/contrib/adapter.py +++ b/nautobot_ssot/contrib/adapter.py @@ -16,8 +16,8 @@ from nautobot.extras.models.metadata import MetadataType from nautobot_ssot.contrib.base import BaseNautobotAdapter, BaseNautobotModel +from nautobot_ssot.contrib.enums import AttributeType from nautobot_ssot.contrib.types import ( - CustomFieldAnnotation, CustomRelationshipAnnotation, RelationshipSideEnum, ) @@ -31,8 +31,7 @@ class NautobotAdapter(Adapter, BaseNautobotAdapter): - """ - Adapter for loading data from Nautobot through the ORM. + """Adapter for loading data from Nautobot through the ORM. This adapter is able to infer how to load data from Nautobot based on how the models attached to it are defined. """ @@ -76,52 +75,36 @@ def _load_objects(self, diffsync_model: BaseNautobotModel): self._load_single_object(database_object, diffsync_model, parameter_names) def _handle_single_parameter(self, parameters, parameter_name, database_object, diffsync_model): - # Handle custom fields and custom relationships. See CustomFieldAnnotation and CustomRelationshipAnnotation - # docstrings for more details. - annotation = diffsync_model.get_attr_annotation(parameter_name) - if isinstance(annotation, CustomFieldAnnotation): - field_key = annotation.key or annotation.name - if field_key in database_object.cf: - parameters[parameter_name] = database_object.cf[field_key] + # Handle parameter overrides + # TODO: Overrides would be better suited in the associated Model. + if hasattr(self, f"load_param_{parameter_name}"): + parameters[parameter_name] = getattr(self, f"load_param_{parameter_name}")(parameter_name, database_object) return - is_custom_relationship = isinstance(annotation, CustomRelationshipAnnotation) - - # Handling of foreign keys where the local side is the many and the remote side the one. - # Note: This includes the side of a generic foreign key that has the foreign key, i.e. - # the 'many' side. - if "__" in parameter_name: - if is_custom_relationship: + match diffsync_model.get_attr_enum(parameter_name): + case AttributeType.STANDARD: + parameters[parameter_name] = getattr(database_object, parameter_name) + case AttributeType.FOREIGN_KEY: + parameters[parameter_name] = orm_attribute_lookup(database_object, parameter_name) + case AttributeType.N_TO_MANY_RELATIONSHIP: + parameters[parameter_name] = self._handle_to_many_relationship( + database_object, diffsync_model, parameter_name + ) + case AttributeType.CUSTOM_FIELD: + annotation = diffsync_model.get_attr_annotation(parameter_name) + field_key = annotation.key or annotation.name + if field_key in database_object.cf: + parameters[parameter_name] = database_object.cf[field_key] + case AttributeType.CUSTOM_FOREIGN_KEY: parameters[parameter_name] = self._handle_custom_relationship_foreign_key( - database_object, parameter_name, annotation + database_object, + parameter_name, + diffsync_model.get_attr_annotation(parameter_name), + ) + case AttributeType.CUSTOM_N_TO_MANY_RELATIONSHIP: + parameters[parameter_name] = self._handle_custom_relationship_to_many_relationship( + database_object, diffsync_model, parameter_name, diffsync_model.get_attr_annotation(parameter_name) ) - else: - parameters[parameter_name] = orm_attribute_lookup(database_object, parameter_name) - return - - # Handling of one- and many-to custom relationship fields: - if annotation: - parameters[parameter_name] = self._handle_custom_relationship_to_many_relationship( - database_object, diffsync_model, parameter_name, annotation - ) - return - - database_field = diffsync_model._model._meta.get_field(parameter_name) - - # Handling of one- and many-to-many non-custom relationship fields. - # Note: This includes the side of a generic foreign key that constitutes the foreign key, - # i.e. the 'one' side. - if database_field.many_to_many or database_field.one_to_many: - parameters[parameter_name] = self._handle_to_many_relationship( - database_object, diffsync_model, parameter_name - ) - return - - # Handling of normal fields - as this is the default case, set the attribute directly. - if hasattr(self, f"load_param_{parameter_name}"): - parameters[parameter_name] = getattr(self, f"load_param_{parameter_name}")(parameter_name, database_object) - else: - parameters[parameter_name] = getattr(database_object, parameter_name) def _load_single_object(self, database_object, diffsync_model, parameter_names): """Load a single diffsync object from a single database object.""" diff --git a/nautobot_ssot/contrib/enums.py b/nautobot_ssot/contrib/enums.py index 7a99f5472..a3525e347 100644 --- a/nautobot_ssot/contrib/enums.py +++ b/nautobot_ssot/contrib/enums.py @@ -1,6 +1,6 @@ """Enums used in SSoT contrib processes.""" -from enum import Enum +from enum import Enum, auto class SortType(Enum): @@ -17,3 +17,14 @@ class RelationshipSideEnum(Enum): SOURCE = "SOURCE" DESTINATION = "DESTINATION" + + +class AttributeType(Enum): + """Enum for identifying DiffSync model attribute types as used in contrib.""" + + STANDARD = auto() + FOREIGN_KEY = auto() + N_TO_MANY_RELATIONSHIP = auto() + CUSTOM_FIELD = auto() + CUSTOM_FOREIGN_KEY = auto() + CUSTOM_N_TO_MANY_RELATIONSHIP = auto() diff --git a/nautobot_ssot/contrib/model.py b/nautobot_ssot/contrib/model.py index 29e22431a..2529f23b1 100644 --- a/nautobot_ssot/contrib/model.py +++ b/nautobot_ssot/contrib/model.py @@ -5,6 +5,7 @@ from collections import defaultdict from datetime import datetime +from functools import lru_cache from diffsync import DiffSyncModel from diffsync.exceptions import ObjectCrudException, ObjectNotCreated, ObjectNotDeleted, ObjectNotUpdated @@ -16,9 +17,9 @@ from nautobot.extras.models.metadata import ObjectMetadata from nautobot_ssot.contrib.base import BaseNautobotModel +from nautobot_ssot.contrib.enums import AttributeType from nautobot_ssot.contrib.types import ( CustomFieldAnnotation, - CustomRelationshipAnnotation, RelationshipSideEnum, ) from nautobot_ssot.utils.diffsync import DiffSyncModelUtilityMixin @@ -115,6 +116,24 @@ def create(cls, adapter, ids, attrs): return super().create(adapter, ids, attrs) + @classmethod + @lru_cache + def get_attr_enum(cls, attr_name: str) -> AttributeType: + """Return `AttributeType` enum value for type hinted attribute.""" + annotation = cls.get_attr_annotation(attr_name) + if isinstance(annotation, CustomFieldAnnotation): + return AttributeType.CUSTOM_FIELD + if "__" in attr_name: + if annotation: + return AttributeType.CUSTOM_FOREIGN_KEY + return AttributeType.FOREIGN_KEY + if annotation: + return AttributeType.CUSTOM_N_TO_MANY_RELATIONSHIP + django_field = cls._model._meta.get_field(attr_name) + if django_field.many_to_many or django_field.one_to_many: + return AttributeType.N_TO_MANY_RELATIONSHIP + return AttributeType.STANDARD + @classmethod def _handle_single_field(cls, field, obj, value, relationship_fields, adapter): # pylint: disable=too-many-arguments,too-many-locals, too-many-branches """Set a single field on a Django object to a given value, or, for relationship fields, prepare setting. @@ -128,88 +147,63 @@ def _handle_single_field(cls, field, obj, value, relationship_fields, adapter): """ cls._check_field(field) - # Handle custom fields. See CustomFieldAnnotation docstring for more details. annotation = cls.get_attr_annotation(field) - if isinstance(annotation, CustomFieldAnnotation): - obj.cf[annotation.key] = value - return - - custom_relationship_annotation = annotation if isinstance(annotation, CustomRelationshipAnnotation) else None - - # Prepare handling of foreign keys and custom relationship foreign keys. - # Example: If field is `tenant__group__name`, then - # `foreign_keys["tenant"]["group__name"] = value` or - # `custom_relationship_foreign_keys["tenant"]["group__name"] = value` - # Also, the model class will be added to the dictionary for normal foreign keys, so we can later use it - # for querying: - # `foreign_keys["tenant"]["_model_class"] = nautobot.tenancy.models.Tenant - # For custom relationship foreign keys, we add the annotation instead: - # `custom_relationship_foreign_keys["tenant"]["_annotation"] = CustomRelationshipAnnotation(...) - if "__" in field: - related_model, lookup = field.split("__", maxsplit=1) - # Custom relationship foreign keys - if custom_relationship_annotation: - relationship_fields["custom_relationship_foreign_keys"][related_model][lookup] = value - relationship_fields["custom_relationship_foreign_keys"][related_model]["_annotation"] = ( - custom_relationship_annotation - ) - # Normal foreign keys - else: + match cls.get_attr_enum(field): + case AttributeType.STANDARD: + setattr(obj, field, value) + case AttributeType.FOREIGN_KEY: + related_model, lookup = field.split("__", maxsplit=1) django_field = cls._model._meta.get_field(related_model) relationship_fields["foreign_keys"][related_model][lookup] = value # Add a special key to the dictionary to point to the related model's class relationship_fields["foreign_keys"][related_model]["_model_class"] = django_field.related_model - return - - # Prepare handling of custom relationship many-to-many fields. - if custom_relationship_annotation: - relationship = adapter.get_from_orm_cache({"label": custom_relationship_annotation.name}, Relationship) - if custom_relationship_annotation.side == RelationshipSideEnum.DESTINATION: - related_object_content_type = relationship.source_type - else: - related_object_content_type = relationship.destination_type - related_model_class = related_object_content_type.model_class() - if ( - relationship.type == RelationshipTypeChoices.TYPE_ONE_TO_MANY - and custom_relationship_annotation.side == RelationshipSideEnum.DESTINATION - ): - relationship_fields["custom_relationship_foreign_keys"][field] = { - **value, - "_annotation": custom_relationship_annotation, - } - else: - relationship_fields["custom_relationship_many_to_many_fields"][field] = { - "annotation": custom_relationship_annotation, - "objects": [adapter.get_from_orm_cache(parameters, related_model_class) for parameters in value], - } - - return - - django_field = cls._model._meta.get_field(field) - - # Prepare handling of many-to-many fields. If we are dealing with a many-to-many field, - # we get all the related objects here to later set them once the object has been saved. - if django_field.many_to_many or django_field.one_to_many: - try: - relationship_fields["many_to_many_fields"][field] = [ - adapter.get_from_orm_cache(parameters, django_field.related_model) for parameters in value - ] - except django_field.related_model.DoesNotExist as error: - raise ObjectCrudException( - f"Unable to populate many to many relationship '{django_field.name}' with parameters {value}, at least one related object not found." - ) from error - except MultipleObjectsReturned as error: - raise ObjectCrudException( - f"Unable to populate many to many relationship '{django_field.name}' with parameters {value}, at least one related object found twice." - ) from error - return - - # As the default case, just set the attribute directly - setattr(obj, field, value) + case AttributeType.N_TO_MANY_RELATIONSHIP: + django_field = cls._model._meta.get_field(field) + try: + relationship_fields["many_to_many_fields"][field] = [ + adapter.get_from_orm_cache(parameters, django_field.related_model) for parameters in value + ] + except django_field.related_model.DoesNotExist as error: + raise ObjectCrudException( + f"Unable to populate many to many relationship '{django_field.name}' with parameters {value}, at least one related object not found." + ) from error + except MultipleObjectsReturned as error: + raise ObjectCrudException( + f"Unable to populate many to many relationship '{django_field.name}' with parameters {value}, at least one related object found twice." + ) from error + case AttributeType.CUSTOM_FIELD: + obj.cf[annotation.key] = value + case AttributeType.CUSTOM_FOREIGN_KEY: + related_model, lookup = field.split("__", maxsplit=1) + relationship_fields["custom_relationship_foreign_keys"][related_model][lookup] = value + relationship_fields["custom_relationship_foreign_keys"][related_model]["_annotation"] = annotation + case AttributeType.CUSTOM_N_TO_MANY_RELATIONSHIP: + relationship = adapter.get_from_orm_cache({"label": annotation.name}, Relationship) + if annotation.side == RelationshipSideEnum.DESTINATION: + related_object_content_type = relationship.source_type + else: + related_object_content_type = relationship.destination_type + related_model_class = related_object_content_type.model_class() + if ( + relationship.type == RelationshipTypeChoices.TYPE_ONE_TO_MANY + and annotation.side == RelationshipSideEnum.DESTINATION + ): + relationship_fields["custom_relationship_foreign_keys"][field] = { + **value, + "_annotation": annotation, + } + else: + relationship_fields["custom_relationship_many_to_many_fields"][field] = { + "annotation": annotation, + "objects": [ + adapter.get_from_orm_cache(parameters, related_model_class) for parameters in value + ], + } @classmethod def _update_obj_with_parameters(cls, obj, parameters, adapter): """Update a given Nautobot ORM object with the given parameters.""" + # TODO: Use Dataclasses instead of dictionaries for structured data storage and tracking. relationship_fields = { # Example: {"group": {"name": "Group Name", "_model_class": TenantGroup}} "foreign_keys": defaultdict(dict), diff --git a/nautobot_ssot/contrib/types.py b/nautobot_ssot/contrib/types.py index cacae0650..049b6c257 100644 --- a/nautobot_ssot/contrib/types.py +++ b/nautobot_ssot/contrib/types.py @@ -12,6 +12,10 @@ class CustomAnnotation: """Base class used to identify custom annotations in SSoT operations.""" + def __hash__(self): + """Return a hash of the class instance.""" + return hash(frozenset({"class": self.__class__} | self.__dict__)) + @dataclass class CustomRelationshipAnnotation(CustomAnnotation): diff --git a/nautobot_ssot/jobs/examples.py b/nautobot_ssot/jobs/examples.py index 198f88684..ecf0d7da8 100644 --- a/nautobot_ssot/jobs/examples.py +++ b/nautobot_ssot/jobs/examples.py @@ -530,7 +530,7 @@ def load(self): self.load_device_types() self.load_platforms() self.load_devices() - # self.load_interfaces() + self.load_interfaces() def load_location_types(self): """Load LocationType data from the remote Nautobot instance. diff --git a/nautobot_ssot/tests/contrib/models.py b/nautobot_ssot/tests/contrib/models.py new file mode 100644 index 000000000..e62751361 --- /dev/null +++ b/nautobot_ssot/tests/contrib/models.py @@ -0,0 +1,62 @@ +"""Example NautobotModel instances for unittests.""" + +from typing import Annotated, List, Optional + +from nautobot.dcim.models import Device, LocationType + +from nautobot_ssot.contrib.enums import RelationshipSideEnum +from nautobot_ssot.contrib.model import NautobotModel +from nautobot_ssot.contrib.types import ( + CustomFieldAnnotation, + CustomRelationshipAnnotation, +) +from nautobot_ssot.tests.contrib.typeddicts import DeviceDict, SoftwareImageFileDict, TagDict + + +class LocationTypeModel(NautobotModel): + """Example model for LocationType in unittests.""" + + _modelname = "location_type" + _model = LocationType + + name: str + nestable: Optional[bool] + + +class DeviceModel(NautobotModel): + """Example model for unittests. + + NOTE: We only need the typehints for this set of unittests. + """ + + _modelname = "device" + _model = Device + + # Standard Attributes + name: str + vc_position: Optional[int] + + # Foreign Keys + status__name: str + tenant__name: Optional[str] + + # N to many Relationships + tags: List[TagDict] = [] + software_image_files: Optional[List[SoftwareImageFileDict]] + + # Custom Fields + custom_str: Annotated[str, CustomFieldAnnotation(name="custom_str")] + custom_int: Annotated[int, CustomFieldAnnotation(name="custom_int")] + custom_bool: Optional[Annotated[bool, CustomFieldAnnotation(name="custom_bool")]] + + # Custom Foreign Keys + parent__name: Annotated[str, CustomRelationshipAnnotation(name="device_parent", side=RelationshipSideEnum.SOURCE)] + + # Custom N to Many Relationships + children: Annotated[ + List[DeviceDict], + CustomRelationshipAnnotation(name="device_children", side=RelationshipSideEnum.DESTINATION), + ] + + # Invalid Fields + invalid_field: str diff --git a/nautobot_ssot/tests/contrib/nautobot_model/test_method_get_synced_parameters.py b/nautobot_ssot/tests/contrib/nautobot_model/test_method_get_synced_parameters.py deleted file mode 100644 index 1e79a4c52..000000000 --- a/nautobot_ssot/tests/contrib/nautobot_model/test_method_get_synced_parameters.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Tests for contrib.NautobotModel.""" - -from nautobot.core.testing import TestCase - -from nautobot_ssot.contrib.model import NautobotModel - - -class TestMethodGetSyncedParameters(TestCase): - """Tests for manipulating custom relationships through the shared base model code.""" - - def test_single_identifer(self): - """Test a single identifier.""" - - class LocalModel(NautobotModel): - _identifiers = ("name",) - _attributes = () - - name: str - - result = LocalModel.get_synced_attributes() - self.assertEqual(len(result), 1) - self.assertIn("name", result) - - def test_multiple_identifiers(self): - """Test multiple identifiers, including a related field.""" - - class LocalModel(NautobotModel): - _identifiers = ( - "name", - "parent__name", - ) - _attributes = () - - name: str - parent__name: str - - result = LocalModel.get_synced_attributes() - self.assertEqual(len(result), 2) - self.assertIn("name", result) - self.assertIn("parent__name", result) - - def test_only_attributes(self): - """Test only attributes.""" - - class LocalModel(NautobotModel): - _identifiers = () - _attributes = ("description", "status") - - description: str - status: str - - result = LocalModel.get_synced_attributes() - self.assertEqual(len(result), 2) - self.assertIn("description", result) - self.assertIn("status", result) - - def test_identifiers_and_attributes(self): - """Test both identifiers and attributes.""" - - class LocalModel(NautobotModel): - _identifiers = ("name",) - _attributes = ("description", "status") - - name: str - description: str - status: str - - result = LocalModel.get_synced_attributes() - self.assertEqual(len(result), 3) - self.assertIn("name", result) - self.assertIn("description", result) - self.assertIn("status", result) - - def test_empty_identifiers_and_attributes(self): - """Test empty identifiers and attributes.""" - - class LocalModel(NautobotModel): - _identifiers = () - _attributes = () - - result = LocalModel.get_synced_attributes() - self.assertEqual(len(result), 0) - self.assertEqual(LocalModel.get_synced_attributes(), []) diff --git a/nautobot_ssot/tests/contrib/test_nautobot_model.py b/nautobot_ssot/tests/contrib/test_nautobot_model.py new file mode 100644 index 000000000..c5fbe9bd0 --- /dev/null +++ b/nautobot_ssot/tests/contrib/test_nautobot_model.py @@ -0,0 +1,192 @@ +"""Tests for contrib.NautobotModel.""" + +from typing import TypedDict + +from django.core.exceptions import FieldDoesNotExist +from nautobot.core.testing import TestCase + +from nautobot_ssot.contrib.enums import AttributeType +from nautobot_ssot.contrib.model import NautobotModel +from nautobot_ssot.tests.contrib.models import DeviceModel, LocationTypeModel + + +class SoftwareImageFileDict(TypedDict): + """Example software image file dict.""" + + image_file_name: str + + +class TagDict(TypedDict): + """Exampe tag Dict.""" + + name: str + + +class DeviceDict(TypedDict): + """Example device dict.""" + + name: str + + +class TestGetAttrEnum(TestCase): + """Unittests for the `get_attr_enum` class method.""" + + # Standard Attributes + # =================== + def test_get_string_attribute(self): + """Test that 'DeviceModel.name' is detected as a standard attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("name"), AttributeType.STANDARD) + + def test_get_optional_integer_attribute(self): + """Test that 'DeviceModel.vc_position' is detected as a standard attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("vc_position"), AttributeType.STANDARD) + + def test_get_bool_attribute(self): + """Test that 'LocationType.nestable' is detected as a standard attribute.""" + self.assertEqual(LocationTypeModel.get_attr_enum("nestable"), AttributeType.STANDARD) + + # Foreign Keys + # ============ + def test_get_foreign_key_attribute(self): + """Test that 'DeviceModel.status__name' is detected as a foreign key attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("status__name"), AttributeType.FOREIGN_KEY) + + def test_get_optional_foreign_key_attribute(self): + """Test that 'DeviceModel.tenant__name' is detected as a foreign key attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("tenant__name"), AttributeType.FOREIGN_KEY) + + # N to Many Relationships + # ======================= + def test_get_n_to_many_attribute(self): + """Test that 'DeviceModel.tags' is detected as a N-to-many relationship attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("tags"), AttributeType.N_TO_MANY_RELATIONSHIP) + + def test_get_optional_n_to_many_attribute(self): + """Test that 'DeviceModel.software_image_files' is detected as a N-to-many relationship attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("software_image_files"), AttributeType.N_TO_MANY_RELATIONSHIP) + + # Custom Fields + # ============= + def test_get_custom_string(self): + """Test that 'DeviceModel.custom_str' is detected as a custom field attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("custom_str"), AttributeType.CUSTOM_FIELD) + + def test_get_custom_int(self): + """Test that 'DeviceModel.custom_int' is detected as a custom field attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("custom_int"), AttributeType.CUSTOM_FIELD) + + def test_get_custom_bool(self): + """Test that 'DeviceModel.custom_bool' is detected as a custom field attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("custom_bool"), AttributeType.CUSTOM_FIELD) + + # Custom Foreign Keys + # =================== + def test_get_custom_foreign_key_attribute(self): + """Test that 'DeviceModel.parent__name' is detected as a custom foreign key attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("parent__name"), AttributeType.CUSTOM_FOREIGN_KEY) + + # Custom N to Many Relationships + # ============================== + def test_get_custom_n_to_many_attribute(self): + """Test that 'DeviceModel.children' is detected as a custom N-to-many relationship attribute.""" + self.assertEqual(DeviceModel.get_attr_enum("children"), AttributeType.CUSTOM_N_TO_MANY_RELATIONSHIP) + + # Invalid attributes + # ================== + def test_non_existant_attribute(self): + """Test that an invalid field raises FieldDoesNotExist.""" + with self.assertRaises(FieldDoesNotExist): + DeviceModel.get_attr_enum("invalid_field") + + def test_undefined_attribute(self): + """Test that an undefined attribute raises KeyError.""" + with self.assertRaises(KeyError): + DeviceModel.get_attr_enum("undefined_attr") + + +class TestMethodGetSyncedParameters(TestCase): + """Tests for manipulating custom relationships through the shared base model code.""" + + def test_single_identifer(self): + """Test a single identifier.""" + + class LocalModel(NautobotModel): + """Example class for testing.""" + + _identifiers = ("name",) + _attributes = () + + name: str + + result = LocalModel.get_synced_attributes() + self.assertEqual(len(result), 1) + self.assertIn("name", result) + + def test_multiple_identifiers(self): + """Test multiple identifiers, including a related field.""" + + class LocalModel(NautobotModel): + """Example class for testing.""" + + _identifiers = ( + "name", + "parent__name", + ) + _attributes = () + + name: str + parent__name: str + + result = LocalModel.get_synced_attributes() + self.assertEqual(len(result), 2) + self.assertIn("name", result) + self.assertIn("parent__name", result) + + def test_only_attributes(self): + """Test only attributes.""" + + class LocalModel(NautobotModel): + """Example class for testing.""" + + _identifiers = () + _attributes = ("description", "status") + + description: str + status: str + + result = LocalModel.get_synced_attributes() + self.assertEqual(len(result), 2) + self.assertIn("description", result) + self.assertIn("status", result) + + def test_identifiers_and_attributes(self): + """Test both identifiers and attributes.""" + + class LocalModel(NautobotModel): + """Example class for testing.""" + + _identifiers = ("name",) + _attributes = ("description", "status") + + name: str + description: str + status: str + + result = LocalModel.get_synced_attributes() + self.assertEqual(len(result), 3) + self.assertIn("name", result) + self.assertIn("description", result) + self.assertIn("status", result) + + def test_empty_identifiers_and_attributes(self): + """Test empty identifiers and attributes.""" + + class LocalModel(NautobotModel): + """Example class for testing.""" + + _identifiers = () + _attributes = () + + result = LocalModel.get_synced_attributes() + self.assertEqual(len(result), 0) + self.assertEqual(LocalModel.get_synced_attributes(), []) diff --git a/nautobot_ssot/tests/contrib/typeddicts.py b/nautobot_ssot/tests/contrib/typeddicts.py new file mode 100644 index 000000000..ad6ca6d7e --- /dev/null +++ b/nautobot_ssot/tests/contrib/typeddicts.py @@ -0,0 +1,21 @@ +"""Tests for contrib.NautobotModel.""" + +from typing import TypedDict + + +class SoftwareImageFileDict(TypedDict): + """Example software image file dict.""" + + image_file_name: str + + +class TagDict(TypedDict): + """Exampe tag Dict.""" + + name: str + + +class DeviceDict(TypedDict): + """Example device dict.""" + + name: str diff --git a/nautobot_ssot/utils/diffsync.py b/nautobot_ssot/utils/diffsync.py index 00fa3ef41..3e1454d67 100644 --- a/nautobot_ssot/utils/diffsync.py +++ b/nautobot_ssot/utils/diffsync.py @@ -48,7 +48,13 @@ def get_attr_args(cls, attr_name: str) -> tuple: Returns: tuple: Type arguments for the attribute's type hint. """ - return get_args(cls.get_type_hints()[attr_name]) + attr_hints = cls.get_type_hints()[attr_name] + args = get_args(attr_hints) + # We don't care about Optional hints here, it's only relavent for data validation. + # Instead, we need the attribute args from inside the Optional tag. + if attr_hints.__name__ == "Optional": + return get_args(args[0]) + return args @classmethod @lru_cache