Skip to content

Commit 8b28d8b

Browse files
committed
Improve handling of circular schema references
1 parent 4afec88 commit 8b28d8b

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

src/openapi_test_client/libraries/api/api_spec.py

+32-17
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from openapi_test_client.clients import APIClientType
1717
from openapi_test_client.libraries.api import Endpoint
1818

19-
2019
logger = get_logger(__name__)
2120

2221

@@ -57,7 +56,7 @@ def get_api_spec(self, url: str | None = None) -> dict[str, Any] | None:
5756
elif not (open_api_version := api_spec["openapi"]).startswith("3."):
5857
raise NotImplementedError(f"Unsupported OpenAPI version: {open_api_version}")
5958
except Exception as e:
60-
logger.error(f"Unable to get API specs from {url}\n{type(e).__name__}: {e}")
59+
logger.exception(f"Unable to get API specs from {url}\n{type(e).__name__}", exc_info=e)
6160
else:
6261
self._spec = OpenAPISpec.parse(api_spec)
6362
return self._spec
@@ -113,7 +112,17 @@ def _resolve_schemas(api_spec: dict[str, Any]) -> dict[str, Any]:
113112
ref_pattern = re.compile(r"#?/([^/]+)")
114113

115114
def has_reference(obj: Any) -> bool:
116-
return "'$ref':" in str(obj)
115+
if isinstance(obj, dict):
116+
for key, value in obj.items():
117+
if key == "$ref":
118+
return True
119+
if has_reference(value):
120+
return True
121+
elif isinstance(obj, list):
122+
for item in obj:
123+
if has_reference(item):
124+
return True
125+
return False
117126

118127
def resolve_recursive(reference: Any, schemas_seen: list[str] | None = None):
119128
if schemas_seen is None:
@@ -122,32 +131,38 @@ def resolve_recursive(reference: Any, schemas_seen: list[str] | None = None):
122131
for k, v in copy.deepcopy(reference).items():
123132
new_reference = reference[k]
124133
if k == "$ref":
134+
del reference[k]
125135
ref_keys = re.findall(ref_pattern, new_reference)
126136
assert ref_keys
127137
schema = "/".join(ref_keys)
128138
if schema in schemas_seen:
139+
# Detected a circular reference
129140
logger.warning(
130-
f"WARNING: Detected recursive schema definition. This is not supported: {schema}"
141+
f"WARNING: Detected a circular schema reference. This is not supported: {schema}"
131142
)
143+
reference.clear()
144+
reference.update(type="object")
132145
else:
133146
schemas_seen.append(schema)
134-
try:
135-
resolved_value = reduce(lambda d, k: d[k], ref_keys, api_spec)
136-
del reference[k]
137-
except KeyError as e:
138-
logger.warning(f"SKIPPED: Unable to resolve '$ref' for '{new_reference}' (KeyError: {e})")
139-
else:
140-
if has_reference(resolved_value):
141-
resolved_value = resolve_recursive(resolved_value, schemas_seen=schemas_seen)
142-
if isinstance(resolved_value, dict):
143-
reference.update(resolved_value)
147+
try:
148+
resolved_value = reduce(lambda d, k: d[k], ref_keys, api_spec)
149+
except KeyError as e:
150+
logger.warning(
151+
f"SKIPPED: Unable to resolve '$ref' for '{new_reference}' (KeyError: {e})"
152+
)
144153
else:
145-
reference = resolved_value
154+
if has_reference(resolved_value):
155+
resolved_value = resolve_recursive(resolved_value, schemas_seen=schemas_seen)
156+
schemas_seen.remove(schema)
157+
if isinstance(resolved_value, dict):
158+
reference.update(resolved_value)
159+
else:
160+
reference = resolved_value
146161
else:
147-
resolve_recursive(new_reference)
162+
resolve_recursive(new_reference, schemas_seen=schemas_seen)
148163
elif isinstance(reference, list):
149164
for item in reference:
150-
resolve_recursive(item)
165+
resolve_recursive(item, schemas_seen=schemas_seen)
151166
return reference
152167

153168
if has_reference(api_spec):

0 commit comments

Comments
 (0)