diff --git a/pynetbox/core/app.py b/pynetbox/core/app.py index 10cc6445..b173afe9 100644 --- a/pynetbox/core/app.py +++ b/pynetbox/core/app.py @@ -41,6 +41,7 @@ def __init__(self, api, name): self.api = api self.name = name self._setmodel() + self._cached_endpoints = {} models = { "dcim": dcim, @@ -63,7 +64,11 @@ def __setstate__(self, d): self._setmodel() def __getattr__(self, name): - return Endpoint(self.api, self, name, model=self.model) + if name not in self._cached_endpoints: + self._cached_endpoints[name] = Endpoint( + self.api, self, name, model=self.model + ) + return self._cached_endpoints[name] def config(self): """Returns config response from app @@ -103,6 +108,7 @@ class PluginsApp: def __init__(self, api): self.api = api + self._cached_apps = {} def __getstate__(self): return self.__dict__ @@ -111,7 +117,11 @@ def __setstate__(self, d): self.__dict__.update(d) def __getattr__(self, name): - return App(self.api, "plugins/{}".format(name.replace("_", "-"))) + if name not in self._cached_apps: + self._cached_apps[name] = App( + self.api, "plugins/{}".format(name.replace("_", "-")) + ) + return self._cached_apps[name] def installed_plugins(self): """Returns raw response with installed plugins diff --git a/pynetbox/core/endpoint.py b/pynetbox/core/endpoint.py index dfd989ec..3e42a812 100644 --- a/pynetbox/core/endpoint.py +++ b/pynetbox/core/endpoint.py @@ -20,6 +20,37 @@ RESERVED_KWARGS = () +class CachedRecordRegistry: + """ + A cache for Record objects. + """ + + def __init__(self): + self._cache = {} + self._hit = 0 + self._miss = 0 + + def get(self, object_type, key): + """ + Retrieves a record from the cache + """ + if not (object_cache := self._cache.get(object_type)): + return None + if object := object_cache.get(key, None): + self._hit += 1 + return object + self._miss += 1 + return None + + def set(self, object_type, key, value): + """ + Stores a record in the cache + """ + if object_type not in self._cache: + self._cache[object_type] = {} + self._cache[object_type][key] = value + + class Endpoint: """Represent actions available on endpoints in the Netbox API. @@ -42,8 +73,8 @@ class Endpoint: """ def __init__(self, api, app, name, model=None): - self.return_obj = self._lookup_ret_obj(name, model) self.name = name.replace("_", "-") + self.return_obj = self._lookup_ret_obj(model) self.api = api self.base_url = api.base_url self.token = api.token @@ -53,8 +84,12 @@ def __init__(self, api, app, name, model=None): endpoint=self.name, ) self._choices = None + self._init_cache() - def _lookup_ret_obj(self, name, model): + def _init_cache(self): + self._cache = CachedRecordRegistry() + + def _lookup_ret_obj(self, model): """Loads unique Response objects. This method loads a unique response object for an endpoint if @@ -67,7 +102,7 @@ def _lookup_ret_obj(self, name, model): :Returns: Record (obj) """ if model: - name = name.title().replace("_", "") + name = self.name.title().replace("-", "") ret = getattr(model, name, Record) else: ret = Record @@ -636,8 +671,8 @@ def count(self, *args, **kwargs): if any(i in RESERVED_KWARGS for i in kwargs): raise ValueError( - "A reserved {} kwarg was passed. Please remove it " - "try again.".format(RESERVED_KWARGS) + "A reserved kwarg was passed ({}). Please remove it " + "and try again.".format(RESERVED_KWARGS) ) ret = Request( diff --git a/pynetbox/core/response.py b/pynetbox/core/response.py index fd8cfaef..2262f50e 100644 --- a/pynetbox/core/response.py +++ b/pynetbox/core/response.py @@ -14,11 +14,9 @@ limitations under the License. """ -import copy +import marshal from collections import OrderedDict -from urllib.parse import urlsplit -import pynetbox.core.app from pynetbox.core.query import Request from pynetbox.core.util import Hashabledict @@ -26,37 +24,6 @@ LIST_AS_SET = ("tags", "tagged_vlans") -def get_return(lookup, return_fields=None): - """Returns simple representations for items passed to lookup. - - Used to return a "simple" representation of objects and collections - sent to it via lookup. Otherwise, we look to see if - lookup is a "choices" field (dict with only 'id' and 'value') - or a nested_return. Finally, we check if it's a Record, if - so simply return a string. Order is important due to nested_return - being self-referential. - - :arg list,optional return_fields: A list of fields to reference when - calling values on lookup. - """ - - for i in return_fields or ["id", "value", "nested_return"]: - if isinstance(lookup, dict) and lookup.get(i): - return lookup[i] - else: - if hasattr(lookup, i): - # check if this is a "choices" field record - # from a NetBox 2.7 server. - if sorted(dict(lookup)) == sorted(["id", "value", "label"]): - return getattr(lookup, "value") - return getattr(lookup, i) - - if isinstance(lookup, Record): - return str(lookup) - else: - return lookup - - def flatten_custom(custom_dict): ret = {} @@ -119,13 +86,21 @@ def __iter__(self): return self def __next__(self): - if self._response_cache: + try: + if self._response_cache: + return self.endpoint.return_obj( + self._response_cache.pop(), + self.endpoint.api, + self.endpoint, + ) return self.endpoint.return_obj( - self._response_cache.pop(), self.endpoint.api, self.endpoint + next(self.response), + self.endpoint.api, + self.endpoint, ) - return self.endpoint.return_obj( - next(self.response), self.endpoint.api, self.endpoint - ) + except StopIteration: + self.endpoint._init_cache() + raise def __len__(self): try: @@ -182,7 +157,56 @@ def delete(self): return self.endpoint.delete(self) -class Record: +class BaseRecord: + def __init__(self): + self._init_cache = [] + + def __getitem__(self, k): + return dict(self)[k] + + def __repr__(self): + return str(self) + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, d): + self.__dict__.update(d) + + +class ValueRecord(BaseRecord): + def __init__(self, values, *args, **kwargs): + super().__init__() + if values: + self._parse_values(values) + + def __iter__(self): + for k, _ in self._init_cache: + cur_attr = getattr(self, k) + yield k, cur_attr + + def __repr__(self): + return getattr(self, "label", "") + + @property + def _key(self): + return getattr(self, "value") + + def __eq__(self, other): + if isinstance(other, ValueRecord): + return self._foreign_key == other._foreign_key + return NotImplemented + + def _parse_values(self, values): + for k, v in values.items(): + self._init_cache.append((k, v)) + setattr(self, k, v) + + def serialize(self, nested=False): + return self._key if nested else dict(self) + + +class Record(BaseRecord): """Create Python objects from NetBox API responses. Creates an object from a NetBox response passed as ``values``. @@ -273,19 +297,13 @@ class Record: """ - url = None - def __init__(self, values, api, endpoint): self.has_details = False - self._full_cache = [] - self._init_cache = [] + super().__init__() self.api = api self.default_ret = Record - self.endpoint = ( - self._endpoint_from_url(values["url"]) - if values and "url" in values and values["url"] - else endpoint - ) + self.url = values.get("url", None) if values else None + self._endpoint = endpoint if values: self._parse_values(values) @@ -308,19 +326,16 @@ def __getattr__(self, k): raise AttributeError('object has no attribute "{}"'.format(k)) def __iter__(self): - for i in dict(self._init_cache): - cur_attr = getattr(self, i) - if isinstance(cur_attr, Record): - yield i, dict(cur_attr) + for k, _ in self._init_cache: + cur_attr = getattr(self, k) + if isinstance(cur_attr, BaseRecord): + yield k, dict(cur_attr) elif isinstance(cur_attr, list) and all( - isinstance(i, Record) for i in cur_attr + isinstance(i, (BaseRecord, GenericListObject)) for i in cur_attr ): - yield i, [dict(x) for x in cur_attr] + yield k, [dict(x) for x in cur_attr] else: - yield i, cur_attr - - def __getitem__(self, k): - return dict(self)[k] + yield k, cur_attr def __str__(self): return ( @@ -330,15 +345,6 @@ def __str__(self): or "" ) - def __repr__(self): - return str(self) - - def __getstate__(self): - return self.__dict__ - - def __setstate__(self, d): - self.__dict__.update(d) - def __key__(self): if hasattr(self, "id"): return (self.endpoint.name, self.id) @@ -353,9 +359,34 @@ def __eq__(self, other): return self.__key__() == other.__key__() return NotImplemented - def _add_cache(self, item): - key, value = item - self._init_cache.append((key, get_return(value))) + @property + def endpoint(self): + if self._endpoint is None: + self._endpoint = self._endpoint_from_url() + return self._endpoint + + def _endpoint_from_url(self): + url_path = self.url.replace(self.api.base_url, "").split("/") + is_plugin = url_path and url_path[1] == "plugins" + start = 2 if is_plugin else 1 + app, name = [i.replace("-", "_") for i in url_path[start : start + 2]] + if is_plugin: + return getattr(getattr(self.api.plugins, app), name) + else: + return getattr(getattr(self.api, app), name) + + def _get_or_init(self, object_type, key, value, model): + """ + Returns a record from the endpoint cache if it exists, otherwise + initializes a new record, store it in the cache, and return it. + """ + if self._endpoint: + if cached := self._endpoint._cache.get(object_type, key): + return cached + record = model(value, self.api, None) + if self._endpoint: + self._endpoint._cache.set(object_type, key, record) + return record def _parse_values(self, values): """Parses values init arg. @@ -364,81 +395,77 @@ def _parse_values(self, values): values within. """ - def generic_list_parser(key_name, list_item): - from pynetbox.models.mapper import CONTENT_TYPE_MAPPER + non_record_dict_fields = ["custom_fields", "local_context_data"] - if ( - isinstance(list_item, dict) - and "object_type" in list_item - and "object" in list_item - ): - lookup = list_item["object_type"] - model = None - model = CONTENT_TYPE_MAPPER.get(lookup) - if model: - return model(list_item["object"], self.api, self.endpoint) + def deep_copy(value): + return marshal.loads(marshal.dumps(value)) - return list_item + def dict_parser(key_name, value, model=None): + if key_name in non_record_dict_fields: + return value, deep_copy(value) - def list_parser(key_name, list_item): - if isinstance(list_item, dict): - lookup = getattr(self.__class__, key_name, None) - if not isinstance(lookup, list): - # This is *list_parser*, so if the custom model field is not - # a list (or is not defined), just return the default model - return self.default_ret(list_item, self.api, self.endpoint) - else: - model = lookup[0] - return model(list_item, self.api, self.endpoint) + if model is None: + model = getattr(self.__class__, key_name, None) - return list_item + if model and issubclass(model, JsonField): + return value, deep_copy(value) - for k, v in values.items(): - if isinstance(v, dict): - lookup = getattr(self.__class__, k, None) - if k in ["custom_fields", "local_context_data"] or hasattr( - lookup, "_json_field" - ): - self._add_cache((k, copy.deepcopy(v))) - setattr(self, k, v) - continue - if lookup: - v = lookup(v, self.api, self.endpoint) - else: - v = self.default_ret(v, self.api, self.endpoint) - self._add_cache((k, v)) - - elif isinstance(v, list): - # check if GFK - if len(v) and isinstance(v[0], dict) and "object_type" in v[0]: - v = [generic_list_parser(k, i) for i in v] - to_cache = list(v) - elif k == "constraints": - # Permissions constraints can be either dict or list - to_cache = copy.deepcopy(v) - else: - v = [list_parser(k, i) for i in v] - to_cache = list(v) - self._add_cache((k, to_cache)) + if (id := value.get("id", None)) and (url := value.get("url", None)): + model = model or Record + value = self._get_or_init(key_name, url, value, model) + return value, id + if record_value := value.get("value", None): + value = self._get_or_init(key_name, record_value, value, ValueRecord) + return value, record_value + + return value, deep_copy(value) + + def generic_list_parser(value): + from pynetbox.models.mapper import CONTENT_TYPE_MAPPER + + parsed_list = [] + for item in value: + object_type = item["object_type"] + if model := CONTENT_TYPE_MAPPER.get(object_type, None): + item = self._get_or_init( + object_type, item["object"]["url"], item["object"], model + ) + parsed_list.append(GenericListObject(item)) + return parsed_list + + def list_parser(key_name, value): + if not value: + return [], [] + + if key_name in ["constraints"]: + return value, deep_copy(value) + + sample_item = value[0] + if not isinstance(sample_item, dict): + return value, [*value] + + is_generic_list = "object_type" in sample_item and "object" in sample_item + if is_generic_list: + value = generic_list_parser(value) else: - self._add_cache((k, v)) - setattr(self, k, v) + lookup = getattr(self.__class__, key_name, None) + model = lookup[0] if isinstance(lookup, list) else self.default_ret + value = [dict_parser(key_name, i, model=model)[0] for i in value] - def _endpoint_from_url(self, url): - url_path = urlsplit(url).path - base_url_path_parts = urlsplit(self.api.base_url).path.split("/") - if len(base_url_path_parts) > 2: - # There are some extra directories in the path, remove them from url - extra_path = "/".join(base_url_path_parts[:-1]) - url_path = url_path[len(extra_path) :] - split_url_path = url_path.split("/") - if split_url_path[2] == "plugins": - app = "plugins/{}".format(split_url_path[3]) - name = split_url_path[4] - else: - app, name = split_url_path[2:4] - return getattr(pynetbox.core.app.App(self.api, app), name) + return value, [*value] + + def parse_value(key_name, value): + if isinstance(value, dict): + value, to_cache = dict_parser(key_name, value) + elif isinstance(value, list): + value, to_cache = list_parser(key_name, value) + else: + to_cache = value + setattr(self, key_name, value) + return to_cache + + self._init_cache = [(k, parse_value(k, v)) for k, v in values.items()] def full_details(self): """Queries the hyperlinked endpoint if 'url' is defined. @@ -470,6 +497,7 @@ def serialize(self, nested=False, init=False): If an attribute's value is a ``Record`` type it's replaced with the ``id`` field of that object. + .. note:: Using this to get a dictionary representation of the record @@ -479,30 +507,37 @@ def serialize(self, nested=False, init=False): :returns: dict. """ if nested: - return get_return(self) + return getattr(self, "id") if init: init_vals = dict(self._init_cache) ret = {} + for i in dict(self): current_val = getattr(self, i) if not init else init_vals.get(i) if i == "custom_fields": ret[i] = flatten_custom(current_val) else: - if isinstance(current_val, Record): + if isinstance(current_val, BaseRecord): current_val = getattr(current_val, "serialize")(nested=True) if isinstance(current_val, list): - current_val = [ - v.id if isinstance(v, Record) else v for v in current_val - ] + serialized_list = [] + for v in current_val: + if isinstance(v, BaseRecord): + v = v.id + elif isinstance(v, GenericListObject): + v = v.serialize() + serialized_list.append(v) + current_val = serialized_list if i in LIST_AS_SET and ( all([isinstance(v, str) for v in current_val]) or all([isinstance(v, int) for v in current_val]) ): current_val = list(OrderedDict.fromkeys(current_val)) ret[i] = current_val + return ret def _diff(self): @@ -615,3 +650,29 @@ def delete(self): http_session=self.api.http_session, ) return True if req.delete() else False + + +class GenericListObject: + def __init__(self, record): + from pynetbox.models.mapper import TYPE_CONTENT_MAPPER + + self.object = record + self.object_id = record.id + self.object_type = TYPE_CONTENT_MAPPER.get(record.__class__) + + def __repr__(self): + return str(self.object) + + def serialize(self): + ret = {k: getattr(self, k) for k in ["object_id", "object_type"]} + return ret + + def __getattr__(self, k): + return getattr(self.object, k) + + def __iter__(self): + for i in ["object_id", "object_type", "object"]: + cur_attr = getattr(self, i) + if isinstance(cur_attr, Record): + cur_attr = dict(cur_attr) + yield i, cur_attr diff --git a/pynetbox/models/dcim.py b/pynetbox/models/dcim.py index 02533c03..67bcd850 100644 --- a/pynetbox/models/dcim.py +++ b/pynetbox/models/dcim.py @@ -151,6 +151,7 @@ def __str__(self): class Interfaces(TraceableRecord): + device = Devices interface_connection = InterfaceConnection diff --git a/pynetbox/models/mapper.py b/pynetbox/models/mapper.py index 1dcfc998..9a468ee4 100644 --- a/pynetbox/models/mapper.py +++ b/pynetbox/models/mapper.py @@ -109,3 +109,5 @@ "wireless.WirelessLANGroup": None, "wireless.wirelesslink": None, } + +TYPE_CONTENT_MAPPER = {v: k for k, v in CONTENT_TYPE_MAPPER.items() if v is not None} diff --git a/tests/fixtures/users/permission.json b/tests/fixtures/users/permission.json index b33f7cbb..d807d9f4 100644 --- a/tests/fixtures/users/permission.json +++ b/tests/fixtures/users/permission.json @@ -3,7 +3,9 @@ "name": "permission1", "users": [ { - "username": "user1" + "id": 1, + "username": "user1", + "url": "http://localhost:8000/api/users/users/1/" } ], "constraints": [ diff --git a/tests/unit/test_response.py b/tests/unit/test_response.py index 1b4beb57..e27bd12f 100644 --- a/tests/unit/test_response.py +++ b/tests/unit/test_response.py @@ -26,8 +26,13 @@ def test_attribute_access(self): test_values = { "id": 123, "units": 12, - "nested_dict": {"id": 222, "name": "bar"}, + "nested_dict": { + "id": 222, + "name": "bar", + "url": "http://localhost:8000/api/test-app/test-endpoint/222/", + }, "int_list": [123, 321, 231], + "url": "http://localhost:8000/api/test-app/test-endpoint/123/", } test_obj = Record(test_values, None, None) self.assertEqual(test_obj.id, 123) @@ -219,11 +224,16 @@ def test_compare(self): self.assertEqual(test1, test2) def test_nested_write(self): - app = Mock() - app.token = "abc123" - app.base_url = "http://localhost:8080/api" + api = Mock() + api.token = "abc123" + api.base_url = "http://localhost:8080/api" endpoint = Mock() endpoint.name = "test-endpoint" + endpoint.url = "http://localhost:8080/api/test-app/test-endpoint/" + endpoint._cache = Mock() + endpoint._cache.get = Mock(return_value=None) + api.test_app = Mock() + api.test_app.test_endpoint = endpoint test = Record( { "id": 123, @@ -234,22 +244,29 @@ def test_nested_write(self): "url": "http://localhost:8080/api/test-app/test-endpoint/321/", }, }, - app, + api, endpoint, ) test.child.name = "test321" test.child.save() + print(api) + print(api.http_session) self.assertEqual( - app.http_session.patch.call_args[0][0], + api.http_session.patch.call_args[0][0], "http://localhost:8080/api/test-app/test-endpoint/321/", ) def test_nested_write_with_directory_in_base_url(self): - app = Mock() - app.token = "abc123" - app.base_url = "http://localhost:8080/testing/api" + api = Mock() + api.token = "abc123" + api.base_url = "http://localhost:8080/testing/api" endpoint = Mock() endpoint.name = "test-endpoint" + endpoint.url = "http://localhost:8080/testing/api/test-app/test-endpoint/" + endpoint._cache = Mock() + endpoint._cache.get = Mock(return_value=None) + api.test_app = Mock() + api.test_app.test_endpoint = endpoint test = Record( { "id": 123, @@ -260,19 +277,22 @@ def test_nested_write_with_directory_in_base_url(self): "url": "http://localhost:8080/testing/api/test-app/test-endpoint/321/", }, }, - app, + api, endpoint, ) test.child.name = "test321" test.child.save() self.assertEqual( - app.http_session.patch.call_args[0][0], + api.http_session.patch.call_args[0][0], "http://localhost:8080/testing/api/test-app/test-endpoint/321/", ) def test_endpoint_from_url(self): api = Mock() api.base_url = "http://localhost:8080/api" + api.test_app = Mock() + api.test_app.test_endpoint = Mock() + api.test_app.test_endpoint.name = "test-endpoint" test = Record( { "id": 123, @@ -282,12 +302,15 @@ def test_endpoint_from_url(self): api, None, ) - ret = test._endpoint_from_url(test.url) + ret = test._endpoint_from_url() self.assertEqual(ret.name, "test-endpoint") def test_endpoint_from_url_with_directory_in_base_url(self): api = Mock() api.base_url = "http://localhost:8080/testing/api" + api.test_app = Mock() + api.test_app.test_endpoint = Mock() + api.test_app.test_endpoint.name = "test-endpoint" test = Record( { "id": 123, @@ -297,12 +320,16 @@ def test_endpoint_from_url_with_directory_in_base_url(self): api, None, ) - ret = test._endpoint_from_url(test.url) + ret = test._endpoint_from_url() self.assertEqual(ret.name, "test-endpoint") def test_endpoint_from_url_with_plugins(self): api = Mock() api.base_url = "http://localhost:8080/api" + api.plugins = Mock() + api.plugins.test_app = Mock() + api.plugins.test_app.test_endpoint = Mock() + api.plugins.test_app.test_endpoint.name = "test-endpoint" test = Record( { "id": 123, @@ -312,12 +339,16 @@ def test_endpoint_from_url_with_plugins(self): api, None, ) - ret = test._endpoint_from_url(test.url) + ret = test._endpoint_from_url() self.assertEqual(ret.name, "test-endpoint") def test_endpoint_from_url_with_plugins_and_directory_in_base_url(self): api = Mock() api.base_url = "http://localhost:8080/testing/api" + api.plugins = Mock() + api.plugins.test_app = Mock() + api.plugins.test_app.test_endpoint = Mock() + api.plugins.test_app.test_endpoint.name = "test-endpoint" test = Record( { "id": 123, @@ -327,7 +358,7 @@ def test_endpoint_from_url_with_plugins_and_directory_in_base_url(self): api, None, ) - ret = test._endpoint_from_url(test.url) + ret = test._endpoint_from_url() self.assertEqual(ret.name, "test-endpoint") def test_serialize_tag_list_order(self):