1
1
from __future__ import annotations
2
2
3
3
import json
4
+ import shutil
4
5
from collections .abc import Generator
5
6
from datetime import datetime
6
7
from functools import cached_property
@@ -47,6 +48,7 @@ class LocalExperiment(BaseMode):
47
48
_parameter_file = Path ("parameter.json" )
48
49
_responses_file = Path ("responses.json" )
49
50
_metadata_file = Path ("metadata.json" )
51
+ _templates_file = Path ("templates.json" )
50
52
51
53
def __init__ (
52
54
self ,
@@ -86,6 +88,7 @@ def create(
86
88
observations : dict [str , pl .DataFrame ] | None = None ,
87
89
simulation_arguments : dict [Any , Any ] | None = None ,
88
90
name : str | None = None ,
91
+ templates : list [tuple [str , str ]] | None = None ,
89
92
) -> LocalExperiment :
90
93
"""
91
94
Create a new LocalExperiment and store its configuration data.
@@ -108,6 +111,8 @@ def create(
108
111
Simulation arguments for the experiment.
109
112
name : str, optional
110
113
Experiment name. Defaults to current date if None.
114
+ templates : list of tuple[str, str], optional
115
+ Run templates for the experiment. Defaults to None.
111
116
112
117
Returns
113
118
-------
@@ -130,6 +135,22 @@ def create(
130
135
json .dumps (parameter_data , indent = 2 ).encode ("utf-8" ),
131
136
)
132
137
138
+ if templates :
139
+ templates_path = path / "templates"
140
+ templates_path .mkdir (parents = True , exist_ok = True )
141
+ templates_abs : list [tuple [str , str ]] = []
142
+ for src , dst in templates :
143
+ incoming_template_file_path = Path (src )
144
+ template_file_path = Path (
145
+ templates_path / incoming_template_file_path .name
146
+ )
147
+ shutil .copyfile (incoming_template_file_path , template_file_path )
148
+ templates_abs .append ((str (template_file_path .resolve ()), dst ))
149
+ storage ._write_transaction (
150
+ path / cls ._templates_file ,
151
+ json .dumps (templates_abs ).encode ("utf-8" ),
152
+ )
153
+
133
154
response_data = {}
134
155
for response in responses or []:
135
156
response_data .update ({response .response_type : response .to_dict ()})
@@ -248,6 +269,16 @@ def parameter_info(self) -> dict[str, Any]:
248
269
info = json .load (f )
249
270
return info
250
271
272
+ @cached_property
273
+ def templates_configuration (self ) -> list [tuple [str , str ]]:
274
+ try :
275
+ with open (self .mount_point / self ._templates_file , encoding = "utf-8" ) as f :
276
+ return json .load (f )
277
+ except (FileNotFoundError , json .JSONDecodeError ):
278
+ pass
279
+ # If the file is missing or broken, we return an empty list
280
+ return []
281
+
251
282
@property
252
283
def response_info (self ) -> dict [str , Any ]:
253
284
info : dict [str , Any ]
0 commit comments