diff --git a/payu/models/staged_cable.py b/payu/models/staged_cable.py index 9b6b8158..77dfa95c 100644 --- a/payu/models/staged_cable.py +++ b/payu/models/staged_cable.py @@ -11,6 +11,7 @@ import os import shutil import itertools +import glob # Extensions import f90nml @@ -21,23 +22,6 @@ from payu.fsops import mkdir_p -def deep_update(d_1, d_2): - """Deep update of namelists.""" - for key, value in d_2.items(): - if isinstance(value, dict): - # Nested struct - if key in d_1: - # If the master namelist contains the key, then recursively - # apply - deep_update(d_1[key], d_2[key]) - else: - # Otherwise just set the value from the patch dict - d_1[key] = value - else: - # Is value, just override - d_1[key] = value - - class StagedCable(Model): """A driver for running staged CABLE spin-up configurations.""" @@ -48,29 +32,40 @@ def __init__(self, expt, name, config): self.default_exec = 'cable' self.config_files = ['stage_config.yaml'] - self.optional_config_files = ['cable.nml', 'cru.nml', - 'luc.nml', 'met_names.nml'] + # To support different branches of cable, which may have different + # namelists, add all found namelists to the optional_config_files + self.optional_config_files = glob.glob('*.nml') + + def setup(self): + super(StagedCable, self).setup() + + # Initialise the configuration log self.configuration_log = {} - if not os.path.isfile('configuration_log.yaml'): + conf_log_p = os.path.join(self.control_path, 'configuration_log.yaml') + if not os.path.isfile(conf_log_p): # Build a new configuration log self._build_new_configuration_log() else: # Read the current configuration log self._read_configuration_log() - # Now set the number of runs using the configuration_log - remaining_stages = len(self.configuration_log['queued_stages']) - print("Overriding the remaining number of runs according to the " + - "number of queued stages in the configuration log.") - os.environ['PAYU_N_RUNS'] = str(remaining_stages) + # Prepare the namelists for the stage + stage_name = self._get_stage_name() + self._apply_stage_namelists(stage_name) + + # Make the logging directory + mkdir_p(os.path.join(self.work_path, "logs")) + + self._set_current_stage() def _build_new_configuration_log(self): """Build a new configuration log for the first stage of the run.""" + stage_conf_p = os.path.join(self.control_path, 'stage_config.yaml') # Read the stage_config.yaml file - with open('stage_config.yaml', 'r') as stage_conf_f: + with open(stage_conf_p, 'r') as stage_conf_f: self.stage_config = yaml.safe_load(stage_conf_f) # On the first run, we need to read the 'stage_config.yaml' file. @@ -85,9 +80,12 @@ def _build_new_configuration_log(self): def _read_configuration_log(self): """Read the existing configuration log.""" - with open('configuration_log.yaml') as conf_log_file: + conf_log_p = os.path.join(self.control_path, 'configuration_log.yaml') + with open(conf_log_p, 'r') as conf_log_file: self.configuration_log = yaml.safe_load(conf_log_file) + print(f"After reading configuration_log: {self.configuration_log}") + def _prepare_configuration(self): """Prepare the stages in the CABLE configuration.""" @@ -127,55 +125,6 @@ def _prepare_configuration(self): # Finish handling of single step stage return cable_stages - def setup(self): - super(StagedCable, self).setup() - - # Prepare the namelists for the stage - stage_name = self._get_stage_name() - self._apply_stage_namelists(stage_name) - - # Make the logging directory - mkdir_p(os.path.join(self.work_path, "logs")) - - # Get the additional restarts from older restart dirs - self._get_further_restarts() - - # Make necessary adjustments to the configuration log - self._handle_configuration_log_setup() - - def _get_further_restarts(self): - """Get the restarts from stages further in the past where necessary.""" - - # Often we take restarts from runs which are not the most recent run as - # inputs for particular science modules, which means we have to extend - # the existing functionality around retrieving restarts. - - # We can't supercede the parent get_prior_restart_files, since the - # files returned by said function are prepended by - # self.prior_restart_path, which is not desirable in this instance. - - num_completed_stages = len(self.configuration_log['completed_stages']) - - for stage_number in reversed(range(num_completed_stages - 1)): - respath = os.path.join( - self.expt.archive_path, - f'restart{stage_number:03d}' - ) - for f_name in os.listdir(respath): - if os.path.isfile(os.path.join(respath, f_name)): - f_orig = os.path.join(respath, f_name) - f_link = os.path.join(self.work_init_path_local, f_name) - # Check whether a given link already exists in the - # manifest, so we don't write over a newer version of a - # restart - if f_link not in self.expt.manifest.manifests['restart']: - self.expt.manifest.add_filepath( - 'restart', - f_link, - f_orig, - self.copy_restarts - ) - def set_model_pathnames(self): super(StagedCable, self).set_model_pathnames() @@ -226,64 +175,97 @@ def _apply_stage_namelists(self, stage_name): for namelist in namelists: write_target = os.path.join(self.work_input_path, namelist) stage_nml = os.path.join(self.control_path, stage_name, namelist) + master_nml = os.path.join(self.control_path, namelist) - if os.path.isfile(os.path.join(self.control_path, namelist)): + if os.path.isfile(master_nml): # Instance where there is a master and stage namelist with open(stage_nml) as stage_nml_f: stage_namelist = f90nml.read(stage_nml_f) - master_nml = os.path.join(self.control_path, namelist) f90nml.patch(master_nml, stage_namelist, write_target) else: # Instance where there is only a stage namelist shutil.copy(stage_nml, write_target) - def _handle_configuration_log_setup(self): - """Make appropriate adjustments to the configuration log to reflect - that the setup of the stage is complete.""" + def _set_current_stage(self): + """Move the stage at the front of the queue into the current stage + slot, then copy the configuration log to the working directory.""" - if self.configuration_log['current_stage'] != '': - # If the current stage is a non-empty string, it means we exited - # during the running of the previous stage- leave as is - stage_name = self.configuration_log['current_stage'] - else: - # Normal case where we just archived a successful stage. - self.configuration_log['current_stage'] = \ - self.configuration_log['queued_stages'].pop(0) + self.configuration_log['current_stage'] = \ + self.configuration_log['queued_stages'].pop(0) self._save_configuration_log() - - # Copy the log to the work directory - shutil.copy('configuration_log.yaml', self.work_input_path) + conf_log_p = os.path.join(self.control_path, 'configuration_log.yaml') + shutil.copy(conf_log_p, self.work_path) def archive(self): """Store model output to laboratory archive and update the configuration log.""" - # Move files from the restart directory within work to the archive - # restart directory. + # Retrieve all the restarts required for the next stage + self._collect_restarts() + + # Update the configuration log and save it to the working directory + self._read_configuration_log() + self._archive_current_stage() + + # Now set the number of runs using the configuration_log + remaining_stages = len(self.configuration_log['queued_stages']) + print("Overriding the remaining number of runs according to the " + + "number of queued stages in the configuration log.") + self.expt.n_runs = remaining_stages + + conf_log_p = os.path.join(self.control_path, 'configuration_log.yaml') + if self.expt.n_runs == 0: + # Configuration successfully completed + os.remove(conf_log_p) + + super(StagedCable, self).archive() + + def _collect_restarts(self): + """Collect all restart files required for the next stage. This is a + merge of the files in work_path/restart and in prior_restart_path, with + the files in work_path/restart taking precedence.""" + + # First, collect restarts which do not have a newer version (when the + # counter is greater than 0) + if self.expt.counter > 0: + prior_restart_dir = 'restart{0:03}'.format(self.expt.counter - 1) + prior_restart_path = os.path.join(self.expt.archive_path, + prior_restart_dir) + + # For each restart, check if newer version was created. If not, + # copy into the work restart path. + generated_restarts = os.listdir(self.work_restart_path) + + for f in os.listdir(prior_restart_path): + if f not in generated_restarts: + shutil.copy(os.path.join(prior_restart_path, f), + self.work_restart_path) + + # Move the files in work_path/restart first for f in os.listdir(self.work_restart_path): shutil.move(os.path.join(self.work_restart_path, f), self.restart_path) os.rmdir(self.work_restart_path) - # Update the configuration log and save it to the working directory - completed_stage = self.configuration_log['current_stage'] - self.configuration_log['current_stage'] = '' - self.configuration_log['completed_stages'].append(completed_stage) + def _archive_current_stage(self): + """Move the current stage to the list of completed stages.""" + self.configuration_log['completed_stages'].append( + self.configuration_log['current_stage']) + self.configuration_log['current_stage'] = '' self._save_configuration_log() - if len(self.configuration_log["queued_stages"]) == 0: - # Configuration successfully completed - os.remove('configuration_log.yaml') - - super(StagedCable, self).archive() + # Copy the configuration log to the restart directory for shareability + conf_log_p = os.path.join(self.control_path, 'configuration_log.yaml') + shutil.copy(conf_log_p, self.restart_path) def collate(self): pass def _save_configuration_log(self): """Write the updated configuration log back to the staging area.""" - with open('configuration_log.yaml', 'w+') as config_log_f: + conf_log_p = os.path.join(self.control_path, 'configuration_log.yaml') + with open(conf_log_p, 'w+') as config_log_f: yaml.dump(self.configuration_log, config_log_f) diff --git a/test/models/test_staged_cable.py b/test/models/test_staged_cable.py index ed7b56c6..61f5bd97 100644 --- a/test/models/test_staged_cable.py +++ b/test/models/test_staged_cable.py @@ -33,6 +33,8 @@ def setup_module(module): ctrldir.mkdir() expt_workdir.mkdir(parents=True) archive_dir.mkdir() + restart_dir = archive_dir / 'ctrl' / 'restart000' + restart_dir.mkdir(parents=True) except Exception as e: print(e) @@ -107,10 +109,10 @@ def teardown_module(module): if verbose: print("teardown_module module:%s" % module.__name__) - try: - shutil.rmtree(tmpdir) - except Exception as e: - print(e) + # try: + # shutil.rmtree(tmpdir) + # except Exception as e: + # print(e) def test_staged_cable(): @@ -123,50 +125,47 @@ def test_staged_cable(): expt = payu.experiment.Experiment(lab, reproduce=False) model = expt.models[0] - # Since we've called the initialiser, we should be able to inspect the - # stages immediately (through the configuration log) - expected_queued_stages = [ - 'stage_1', - 'stage_2', - 'stage_3', - 'stage_4', - 'stage_3', - 'stage_3', - 'stage_5', - 'stage_6', - 'stage_6', - 'stage_7'] - assert model.configuration_log['queued_stages'] == expected_queued_stages - - # Now prepare for a stage- should see changes in the configuration log - # and the patched namelist in the workdir - model.setup() - expected_current_stage = expected_queued_stages.pop(0) - assert model.configuration_log['current_stage'] == expected_current_stage - assert model.configuration_log['queued_stages'] == expected_queued_stages - - # Now check the namelist - expected_namelist = { - 'cablenml': { - 'option1': 10, - 'struct1': { - 'option2': 20, - 'option3': 3, - 'option5': 50 - }, - 'option4': 4, - 'option6': 60 + # Now prepare for a stage- should see changes in the configuration log + # and the patched namelist in the workdir + model.setup() + + # Since we've called the initialiser, we should be able to inspect the + # stages immediately (through the configuration log) + expected_q_stages = [ + 'stage_2', + 'stage_3', + 'stage_4', + 'stage_3', + 'stage_3', + 'stage_5', + 'stage_6', + 'stage_6', + 'stage_7'] + assert model.configuration_log['queued_stages'] == expected_q_stages + assert model.configuration_log['current_stage'] == 'stage_1' + + # Now check the namelist + expected_namelist = { + 'cablenml': { + 'option1': 10, + 'struct1': { + 'option2': 20, + 'option3': 3, + 'option5': 50 + }, + 'option4': 4, + 'option6': 60 + } } - } - with open(expt_workdir / 'cable.nml') as stage_nml_f: - stage_nml = f90nml.read(stage_nml_f) + with open(expt_workdir / 'cable.nml') as stage_nml_f: + stage_nml = f90nml.read(stage_nml_f) - assert stage_nml == expected_namelist + assert stage_nml == expected_namelist - # Archive the stage and make sure the configuration log is correct - model.archive() - expected_comp_stages = [expected_current_stage] - expected_current_stage = '' - assert model.configuration_log['completed_stages'] == expected_comp_stages - assert model.configuration_log['current_stage'] == expected_current_stage + # Archive the stage and make sure the configuration log is correct + model.archive() + ex_comp_stages = ['stage_1'] + ex_curr_stage = '' + assert model.configuration_log['completed_stages'] == ex_comp_stages + assert model.configuration_log['current_stage'] == ex_curr_stage