Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 89 additions & 7 deletions crytic_compile/utils/natspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,70 @@
"""


class DevStateVariable:
"""
Model the dev state variable
"""

def __init__(self, variable: dict) -> None:
"""Init the object

Args:
method (Dict): Method infos (details, params, returns, custom:*)
"""
self._details: str | None = variable.get("details", None)
if "returns" in variable:
self._returns: dict[str, str] = variable["returns"]
elif "return" in variable:
self._returns: dict[str, str] = {"_0": variable["return"]}
else:
self._returns: dict[str, str] = {}
# Extract custom fields (keys starting with "custom:")
self._custom: dict[str, str] = {
k: v for k, v in variable.items() if k.startswith("custom:")
}

@property
def details(self) -> str | None:
"""Return the state variable details

Returns:
Optional[str]: state variable details
"""
return self._details

@property
def variable_returns(self) -> dict[str, str]:
"""Return the state variable returns

Returns:
dict[str, str]: state variable returns
"""
return self._returns

@property
def custom(self) -> dict[str, str]:
"""Return the state variable custom fields

Returns:
Dict[str, str]: custom field name => value (e.g. "custom:security" => "value")
"""
return self._custom

def export(self) -> dict:
"""Export to a python dict

Returns:
Dict: Exported dev state variable
"""
result = {
"details": self.details,
"returns": self.variable_returns,
"custom": self.custom,
}
return result


class UserMethod:
"""
Model the user method
Expand Down Expand Up @@ -47,12 +111,17 @@ def __init__(self, method: dict) -> None:
"""Init the object

Args:
method (Dict): Method infos (author, details, params, return, custom:*)
method (Dict): Method infos (author, details, params, returns, custom:*)
"""
self._author: str | None = method.get("author", None)
self._details: str | None = method.get("details", None)
self._params: dict[str, str] = method.get("params", {})
self._return: str | None = method.get("return", None)
if "returns" in method:
self._returns: dict[str, str] = method["returns"]
elif "return" in method:
self._returns: dict[str, str] = {"_0": method["return"]}
else:
self._returns: dict[str, str] = {}
# Extract custom fields (keys starting with "custom:")
self._custom: dict[str, str] = {k: v for k, v in method.items() if k.startswith("custom:")}

Expand All @@ -75,13 +144,13 @@ def details(self) -> str | None:
return self._details

@property
def method_return(self) -> str | None:
"""Return the method return
def method_returns(self) -> dict[str, str]:
"""Return the method returns

Returns:
Optional[str]: method return
dict[str, str]: method returns
"""
return self._return
return self._returns

@property
def params(self) -> dict[str, str]:
Expand Down Expand Up @@ -111,7 +180,7 @@ def export(self) -> dict:
"author": self.author,
"details": self.details,
"params": self.params,
"return": self.method_return,
"returns": self.method_returns,
}
# Include custom fields if present
result.update(self.custom)
Expand Down Expand Up @@ -180,6 +249,9 @@ def __init__(self, devdoc: dict):
self._methods: dict[str, DevMethod] = {
k: DevMethod(item) for k, item in devdoc.get("methods", {}).items()
}
self._state_variables: dict[str, DevStateVariable] = {
k: DevStateVariable(item) for k, item in devdoc.get("stateVariables", {}).items()
}
self._title: str | None = devdoc.get("title", None)
# Extract contract-level custom fields (keys starting with "custom:")
self._custom: dict[str, str] = {k: v for k, v in devdoc.items() if k.startswith("custom:")}
Expand Down Expand Up @@ -211,6 +283,15 @@ def methods(self) -> dict[str, DevMethod]:
"""
return self._methods

@property
def state_variables(self) -> dict[str, DevStateVariable]:
"""Return the dev state variables

Returns:
Dict[str, DevStateVariable]: state_variable_name => DevStateVariable
"""
return self._state_variables

@property
def title(self) -> str | None:
"""Return the dev title
Expand Down Expand Up @@ -240,6 +321,7 @@ def export(self) -> dict:
"author": self.author,
"details": self.details,
"title": self.title,
"state_variables": self.state_variables,
}
# Include custom fields if present
result.update(self.custom)
Expand Down
98 changes: 94 additions & 4 deletions tests/test_natspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
Test NatSpec parsing, including custom fields (@custom:*)
"""

from crytic_compile.utils.natspec import DevDoc, DevMethod, Natspec, UserDoc, UserMethod
from crytic_compile.utils.natspec import (
DevDoc,
DevMethod,
DevStateVariable,
Natspec,
UserDoc,
UserMethod,
)


class TestUserMethod:
Expand Down Expand Up @@ -40,7 +47,7 @@ def test_devmethod_basic_fields(self) -> None:
assert method.author == "Test Author"
assert method.details == "Method details"
assert method.params == {"a": "first param", "b": "second param"}
assert method.method_return == "return value description"
assert method.method_returns == {"_0": "return value description"}

def test_devmethod_custom_fields_parsing(self) -> None:
"""Test DevMethod extracts custom fields"""
Expand Down Expand Up @@ -80,7 +87,7 @@ def test_devmethod_export_includes_custom(self) -> None:
assert exported["author"] == "Test Author"
assert exported["details"] == "Details"
assert exported["params"] == {"x": "param x"}
assert exported["return"] == "returns something"
assert exported["returns"] == {"_0": "returns something"}
assert exported["custom:security"] == "critical"
assert exported["custom:audit"] == "passed"

Expand All @@ -90,9 +97,92 @@ def test_devmethod_empty_method(self) -> None:
assert method.author is None
assert method.details is None
assert method.params == {}
assert method.method_return is None
assert method.method_returns == {}
assert method.custom == {}

def test_devmethod_returns_dict(self) -> None:
"""Test DevMethod with 'returns' dict field (multiple return values)"""
method_data = {
"details": "Method with multiple returns",
"returns": {"_0": "first value", "_1": "second value"},
}
method = DevMethod(method_data)
assert method.method_returns == {"_0": "first value", "_1": "second value"}

def test_devmethod_returns_takes_precedence(self) -> None:
"""Test DevMethod prefers 'returns' over 'return' when both present"""
method_data = {
"returns": {"_0": "from returns"},
"return": "from return",
}
method = DevMethod(method_data)
assert method.method_returns == {"_0": "from returns"}


class TestDevStateVariable:
"""Tests for DevStateVariable class"""

def test_state_variable_with_returns_dict(self) -> None:
"""Test DevStateVariable with 'returns' dict field"""
var_data = {
"details": "A state variable",
"returns": {"_0": "the stored value"},
}
var = DevStateVariable(var_data)
assert var.details == "A state variable"
assert var.variable_returns == {"_0": "the stored value"}

def test_state_variable_with_return_string(self) -> None:
"""Test DevStateVariable falls back to 'return' string field"""
var_data = {
"details": "A state variable",
"return": "the stored value",
}
var = DevStateVariable(var_data)
assert var.variable_returns == {"_0": "the stored value"}

def test_state_variable_returns_takes_precedence(self) -> None:
"""Test DevStateVariable prefers 'returns' over 'return' when both present"""
var_data = {
"returns": {"_0": "from returns"},
"return": "from return",
}
var = DevStateVariable(var_data)
assert var.variable_returns == {"_0": "from returns"}

def test_state_variable_empty(self) -> None:
"""Test DevStateVariable with empty dict"""
var = DevStateVariable({})
assert var.details is None
assert var.variable_returns == {}
assert var.custom == {}

def test_state_variable_custom_fields(self) -> None:
"""Test DevStateVariable extracts custom fields"""
var_data = {
"details": "A variable",
"custom:security": "sensitive",
"custom:deprecated": "true",
}
var = DevStateVariable(var_data)
assert var.custom == {
"custom:security": "sensitive",
"custom:deprecated": "true",
}

def test_state_variable_export(self) -> None:
"""Test DevStateVariable export"""
var_data = {
"details": "A state variable",
"returns": {"_0": "the value"},
"custom:audit": "verified",
}
var = DevStateVariable(var_data)
exported = var.export()
assert exported["details"] == "A state variable"
assert exported["returns"] == {"_0": "the value"}
assert exported["custom"] == {"custom:audit": "verified"}


class TestUserDoc:
"""Tests for UserDoc class"""
Expand Down