Skip to content

Commit 8b6ff48

Browse files
committed
Improve tests compliance
1 parent 55090d0 commit 8b6ff48

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

api/src/backend/conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ def _install_compliance_catalog_test_cache() -> None:
8686
test-only equivalent of an ``lru_cache`` on the SDK functions, without
8787
changing SDK behavior in production.
8888
89+
A second, lower-level cache memoizes ``load_compliance_framework_universal``
90+
**per file path**. ``get_bulk_compliance_frameworks_universal`` parses *every*
91+
compliance JSON and only then filters by provider, so a per-provider cache
92+
still re-parses all ~100 files on the first load of each provider. The
93+
per-path cache makes the first provider parse the files once and every other
94+
provider/test reuse the already-parsed ``ComplianceFramework`` objects (only
95+
the cheap ``listdir`` + filtering re-runs). ``_load_jsons_from_dir`` calls
96+
``load_compliance_framework_universal`` as a module global, so patching the
97+
attribute is picked up without touching the SDK.
98+
8999
Installed at conftest import time (before test modules are collected) so that
90100
even ``from ... import get_bulk_compliance_frameworks_universal`` bindings in
91101
the test modules resolve to the cached wrapper.
@@ -95,11 +105,13 @@ def _install_compliance_catalog_test_cache() -> None:
95105

96106
framework_cache: dict[str, dict] = {}
97107
checks_cache: dict[str, dict] = {}
108+
path_cache: dict[str, object] = {}
98109

99110
original_bulk_frameworks = (
100111
compliance_models.get_bulk_compliance_frameworks_universal
101112
)
102113
original_get_bulk = CheckMetadata.get_bulk
114+
original_load = compliance_models.load_compliance_framework_universal
103115

104116
def cached_bulk_frameworks(provider):
105117
if provider not in framework_cache:
@@ -111,7 +123,13 @@ def cached_get_bulk(provider):
111123
checks_cache[provider] = original_get_bulk(provider)
112124
return checks_cache[provider]
113125

126+
def cached_load(path):
127+
if path not in path_cache:
128+
path_cache[path] = original_load(path)
129+
return path_cache[path]
130+
114131
compliance_models.get_bulk_compliance_frameworks_universal = cached_bulk_frameworks
132+
compliance_models.load_compliance_framework_universal = cached_load
115133
CheckMetadata.get_bulk = staticmethod(cached_get_bulk)
116134

117135
# ``api.compliance`` does ``from ... import get_bulk_compliance_frameworks_universal``

0 commit comments

Comments
 (0)