1313import itertools as it
1414import xml .etree .ElementTree as ET
1515
16- from .imports import xmlUtils , Template
16+ from .imports import Template
1717from .heron_types import HeronCase , Component , Source , ValuedParam
1818from .naming_utils import get_result_stats , get_component_activity_vars , get_opt_objective , get_statistics , Statistic
1919from .xml_utils import add_node_to_tree , stringify_node_values
2626from .snippets .models import GaussianProcessRegressor , PickledROM , EnsembleModel
2727from .snippets .distributions import Distribution , Uniform
2828from .snippets .outstreams import PrintOutStream
29- from .snippets .dataobjects import DataObject , PointSet , DataSet
29+ from .snippets .dataobjects import PointSet , DataSet
3030from .snippets .variablegroups import VariableGroup
3131from .snippets .files import File
3232from .snippets .factory import factory as snippet_factory
@@ -118,27 +118,26 @@ def createWorkflow(self, **kwargs) -> None:
118118 # Universal workflow settings
119119 self ._set_verbosity (kwargs ["case" ].get_verbosity ())
120120
121- def writeWorkflow (self , template : ET . Element , destination : str , run : bool = False ) -> None :
121+ def writeWorkflow (self , dest_dir : str ) -> None :
122122 """
123123 Writes a template to file.
124- @ In, template, xml.etree.ElementTree.Element, file to write
125- @ In, destination, str, path and filename to write to
126- @ In, run, bool, optional, if True then run the workflow after writing? good idea?
127- @ Out, errors, int, 0 if successfully wrote [and run] and nonzero if there was a problem
124+ @ In, dest_dir, str, path to the directory to which to write template workflows
125+ @ Out, None
128126 """
129127 # Ensure all node attribute values and text are expressed as strings. Errors are thrown if any of these aren't
130128 # strings. Enforcing this here allows flexibility with how node values are stored and manipulated before write
131129 # time, such as storing values as lists or numeric types. For example, text fields which are a comma-separated
132130 # list of values can be stored in the RavenSnippet object as a list, and new items can be inserted into that
133131 # list as needed, then the list can be converted to a string only now at write time.
134- stringify_node_values (template )
132+ stringify_node_values (self . _template )
135133
136134 # Remove any unused top-level nodes (Models, Samplers, etc.) to keep things looking clean
137- for node in template :
135+ for node in self . _template :
138136 if len (node ) == 0 :
139- template .remove (node )
137+ self . _template .remove (node )
140138
141- super ().writeWorkflow (template , destination , run )
139+ destination = self .get_write_path (dest_dir )
140+ super ().writeWorkflow (self ._template , destination )
142141 print (f"Wrote '{ self .write_name } ' to '{ destination } '" )
143142
144143 @property
@@ -221,7 +220,7 @@ def _set_case_name(self, name: str) -> None:
221220 @ In, name, str, case name to use
222221 @ Out, None
223222 """
224- run_info = self ._template .find ("RunInfo" ) # type: RunInfo
223+ run_info : RunInfo = self ._template .find ("RunInfo" )
225224 run_info .job_name = name
226225 run_info .working_dir = name
227226
@@ -232,7 +231,7 @@ def _add_step_to_sequence(self, step: Step, index: int | None = None) -> None:
232231 @ In, index, int, optional, the index to add the step at
233232 @ Out, None
234233 """
235- run_info = self ._template .find ("RunInfo" ) # type: RunInfo
234+ run_info : RunInfo = self ._template .find ("RunInfo" )
236235 idx = index if index is not None else len (run_info .sequence )
237236 run_info .sequence .insert (idx , step )
238237
@@ -264,7 +263,7 @@ def _load_file_to_object(self, source: Source, target: RavenSnippet) -> IOStep:
264263 @ Out, step, IOStep, the step used to do the loading
265264 """
266265 # Get the file to load. Might already exist in the template XML
267- file = self ._template .find ("Files/Input[@name='{source.name}']" ) # type: File
266+ file : File | None = self ._template .find ("Files/Input[@name='{source.name}']" )
268267 if file is None :
269268 file = File (source .name )
270269 file .path = source ._target_file
@@ -416,14 +415,14 @@ def _add_time_series_roms(self, ensemble_model: EnsembleModel, case: HeronCase,
416415 @ In, sources, list[Source], case sources
417416 @ Out, None
418417 """
419- dispatch_eval = self ._template .find ("DataObjects/DataSet[@name='dispatch_eval']" ) # type: DataSet
418+ dispatch_eval : DataSet = self ._template .find ("DataObjects/DataSet[@name='dispatch_eval']" )
420419
421420 # Gather any ARMA sources from the list of sources
422421 arma_sources = [s for s in sources if s .is_type ("ARMA" )]
423422
424423 # Add cluster index info to dispatch variable groups and data objects
425424 if any (source .eval_mode == "clustered" for source in arma_sources ):
426- vg_dispatch = self ._template .find ("VariableGroups/Group[@name='GRO_dispatch']" ) # type: VariableGroup
425+ vg_dispatch : VariableGroup = self ._template .find ("VariableGroups/Group[@name='GRO_dispatch']" )
427426 vg_dispatch .variables .append (self .namingTemplates ["cluster_index" ])
428427 dispatch_eval .add_index (self .namingTemplates ["cluster_index" ], "GRO_dispatch_in_Time" )
429428
@@ -555,7 +554,7 @@ def _get_uncertain_cashflow_params(self,
555554 dist_name = self .namingTemplates ["distribution" ].format (variable = feat_name )
556555
557556 # Reconstruct distribution XML node from valuedParam definition
558- dist_node = vp ._vp .get_distribution () # type: ET.Element
557+ dist_node : ET . Element = vp ._vp .get_distribution ()
559558 dist_node .set ("name" , dist_name )
560559 dist_snippet = snippet_factory .from_xml (dist_node )
561560 distributions .append (dist_snippet )
@@ -597,7 +596,7 @@ def _create_sampler_variables(self,
597596 interaction = component .get_interaction ()
598597 name = component .name
599598 var_name = self .namingTemplates ["variable" ].format (unit = name , feature = "capacity" )
600- cap = interaction .get_capacity (None , raw = True ) # type: ValuedParam
599+ cap : ValuedParam = interaction .get_capacity (None , raw = True )
601600
602601 if not cap .is_parametric (): # we already know the value
603602 continue
@@ -640,7 +639,7 @@ def _configure_static_history_sampler(self,
640639 if case .debug ["enabled" ]:
641640 indices .append (cluster_index )
642641
643- time_series_vargroup = self ._template .find ("VariableGroups/Group[@name='GRO_timeseries']" ) # type: VariableGroup
642+ time_series_vargroup : VariableGroup = self ._template .find ("VariableGroups/Group[@name='GRO_timeseries']" )
644643
645644 for source in filter (lambda x : x .is_type ("CSV" ), sources ):
646645 # Add the source variables to the GRO_timeseries_in variable group
0 commit comments