|
| 1 | +"""Functions for sorting DiffSync model lists ensuring they are sorted to prevent false actions.""" |
| 2 | + |
| 3 | +import sys |
| 4 | + |
| 5 | +from diffsync import Adapter, DiffSyncModel |
| 6 | +from typing_extensions import TypedDict, get_type_hints |
| 7 | + |
| 8 | +from nautobot_ssot.contrib.typeddicts import SortKey |
| 9 | +from nautobot_ssot.contrib.types import SortType |
| 10 | + |
| 11 | + |
| 12 | +def _is_sortable_field(attribute_type_hints) -> bool: |
| 13 | + """Check if a DiffSync attribute is a sortable field.""" |
| 14 | + minor_ver = sys.version_info[1] |
| 15 | + try: |
| 16 | + # For Python 3.9 and older |
| 17 | + if minor_ver <= 9: |
| 18 | + attr_name = attribute_type_hints._name # pylint: disable=protected-access |
| 19 | + else: |
| 20 | + attr_name = attribute_type_hints.__name__ |
| 21 | + except AttributeError: |
| 22 | + return False |
| 23 | + |
| 24 | + return str(attr_name) in [ |
| 25 | + "list", |
| 26 | + "List", |
| 27 | + ] |
| 28 | + |
| 29 | + |
| 30 | +def _get_sort_key_from_typed_dict(sortable_content_type) -> str: |
| 31 | + """Get the dictionary key from a TypedDict if found.""" |
| 32 | + for key, value in sortable_content_type.__annotations__.items(): |
| 33 | + try: |
| 34 | + metadata = value.__metadata__ |
| 35 | + except AttributeError: |
| 36 | + continue |
| 37 | + for entry in metadata: |
| 38 | + if entry == SortKey: |
| 39 | + return key |
| 40 | + return None |
| 41 | + |
| 42 | + |
| 43 | +def get_sortable_fields_from_model(model: DiffSyncModel) -> dict: |
| 44 | + """Get a list of sortable fields and their sort key from a DiffSync model.""" |
| 45 | + sortable_fields = {} |
| 46 | + model_type_hints = get_type_hints(model, include_extras=True) |
| 47 | + |
| 48 | + for model_attribute_name in model._attributes: # pylint: disable=protected-access |
| 49 | + attribute_type_hints = model_type_hints.get(model_attribute_name) |
| 50 | + |
| 51 | + if not _is_sortable_field(attribute_type_hints): |
| 52 | + continue |
| 53 | + |
| 54 | + sortable_content_type = attribute_type_hints.__args__[0] |
| 55 | + |
| 56 | + if issubclass(sortable_content_type, dict) or issubclass(sortable_content_type, TypedDict): |
| 57 | + sort_key = _get_sort_key_from_typed_dict(sortable_content_type) |
| 58 | + if not sort_key: |
| 59 | + continue |
| 60 | + sortable_fields[model_attribute_name] = { |
| 61 | + "sort_type": SortType.DICT, |
| 62 | + "sort_key": sort_key, |
| 63 | + } |
| 64 | + # Add additional items here |
| 65 | + |
| 66 | + return sortable_fields |
| 67 | + |
| 68 | + |
| 69 | +def _sort_dict_attr(obj, attribute, key): |
| 70 | + """Update the sortable attribute in a DiffSync object.""" |
| 71 | + sorted_data = None |
| 72 | + if key: |
| 73 | + sorted_data = sorted( |
| 74 | + getattr(obj, attribute), |
| 75 | + key=lambda x: x[key], |
| 76 | + ) |
| 77 | + else: |
| 78 | + sorted_data = sorted(getattr(obj, attribute)) |
| 79 | + |
| 80 | + if sorted_data: |
| 81 | + setattr(obj, attribute, sorted_data) |
| 82 | + return obj |
| 83 | + |
| 84 | + |
| 85 | +def sort_relationships(source: Adapter, target: Adapter): |
| 86 | + """Sort relationships based on the metadata defined in the DiffSync model.""" |
| 87 | + if not source or not target: |
| 88 | + return |
| 89 | + |
| 90 | + models_to_sort = {} |
| 91 | + # Loop through target's top_level attribute to determine models with sortable attributes |
| 92 | + for model_name in getattr(target, "top_level", []): |
| 93 | + # Get the DiffSync Model |
| 94 | + diffsync_model = getattr(target, model_name) |
| 95 | + if not diffsync_model: |
| 96 | + continue |
| 97 | + |
| 98 | + # Get sortable fields current model |
| 99 | + model_sortable_fields = get_sortable_fields_from_model(diffsync_model) |
| 100 | + if not model_sortable_fields: |
| 101 | + continue |
| 102 | + models_to_sort[model_name] = model_sortable_fields |
| 103 | + |
| 104 | + # Loop through adapaters to sort models |
| 105 | + for adapter in (source, target): |
| 106 | + for model_name, attrs_to_sort in models_to_sort.items(): |
| 107 | + for diffsync_obj in adapter.get_all(model_name): |
| 108 | + for attr_name, sort_data in attrs_to_sort.items(): |
| 109 | + sort_type = sort_data["sort_type"] |
| 110 | + # Sort the data based on its sort type |
| 111 | + if sort_type == SortType.DICT: |
| 112 | + diffsync_obj = _sort_dict_attr(diffsync_obj, attr_name, sort_data["sort_key"]) |
| 113 | + adapter.update(diffsync_obj) |
0 commit comments