Skip to content

Commit 5e2d3e9

Browse files
committed
simplify parser to only take SIS
1 parent ef2e0a9 commit 5e2d3e9

File tree

1 file changed

+14
-29
lines changed

1 file changed

+14
-29
lines changed

src/vivarium_profiling/plugins/parser.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
SusceptibleState,
1111
)
1212

13+
CAUSE_KEY = "causes"
1314

1415
class ScalingParsingErrors(ParsingError):
1516
"""Error raised when there are any errors parsing a scaling configuration."""
@@ -30,11 +31,10 @@ class ScalingComponentParser(ComponentConfigurationParser):
3031
.. code-block:: yaml
3132
3233
components:
33-
diseases:
34-
type: SIS_fixed_duration
35-
cause: lower_respiratory_infections
34+
causes:
3635
number: 5
37-
duration: "28"
36+
cause: lower_respiratory_infections
37+
duration: 28
3838
3939
This will create 5 disease components named:
4040
- lower_respiratory_infections_1
@@ -66,14 +66,14 @@ def parse_component_config(self, component_config: LayeredConfigTree) -> list[Co
6666
"""
6767
components = []
6868

69-
if "diseases" in component_config:
70-
diseases_config = component_config["diseases"]
69+
if CAUSE_KEY in component_config:
70+
diseases_config = component_config[CAUSE_KEY]
7171
self._validate_diseases_config(diseases_config)
7272
components += self._get_scaled_disease_components(diseases_config)
7373

7474
# Parse standard components (i.e. not scaled components)
7575
standard_component_config = component_config.to_dict()
76-
standard_component_config.pop("diseases", None)
76+
standard_component_config.pop(CAUSE_KEY, None)
7777
standard_components = (
7878
self.process_level(standard_component_config, [])
7979
if standard_component_config
@@ -98,20 +98,16 @@ def _get_scaled_disease_components(
9898
"""
9999
components = []
100100

101-
component_type = diseases_config.get("type")
102101
base_cause = diseases_config.get("cause")
103102
number = diseases_config.get("number", 1)
104103
duration = diseases_config.get("duration", "28")
105104

106-
if component_type == "SIS_fixed_duration":
107-
for i in range(number):
108-
cause_name = f"{base_cause}_{i+1}"
109-
disease_component = self._create_sis_fixed_duration(
110-
cause_name, duration, base_cause
111-
)
112-
components.append(disease_component)
113-
else:
114-
raise ValueError(f"Unsupported disease type: {component_type}")
105+
for i in range(number):
106+
cause_name = f"{base_cause}_{i+1}"
107+
disease_component = self._create_sis_fixed_duration(
108+
cause_name, duration, base_cause
109+
)
110+
components.append(disease_component)
115111

116112
return components
117113

@@ -175,22 +171,11 @@ def _validate_diseases_config(self, diseases_config: LayeredConfigTree) -> None:
175171
error_messages = []
176172

177173
# Check required fields
178-
required_fields = ["type", "cause", "number"]
174+
required_fields = [CAUSE_KEY, "number"]
179175
for field in required_fields:
180176
if field not in diseases_config_dict:
181177
error_messages.append(f"Missing required field: {field}")
182178

183-
# Validate type
184-
supported_types = ["SIS_fixed_duration"]
185-
if (
186-
"type" in diseases_config_dict
187-
and diseases_config_dict["type"] not in supported_types
188-
):
189-
error_messages.append(
190-
f"Unsupported disease type: {diseases_config_dict['type']}. "
191-
f"Supported types: {supported_types}"
192-
)
193-
194179
# Validate number
195180
if "number" in diseases_config_dict:
196181
try:

0 commit comments

Comments
 (0)