Skip to content

Commit aad8296

Browse files
feat: support templating other types in default
1 parent 0ef6a68 commit aad8296

4 files changed

Lines changed: 194 additions & 41 deletions

File tree

agent-control/src/agent_type/render.rs

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -663,15 +663,22 @@ variables:
663663
description: "registry url"
664664
type: string
665665
required: false
666-
default: "${nr-default:registry_url}"
666+
default: "${nr-default:registry_url}/test-repository"
667+
enable_file_logging:
668+
description: "enable file logging"
669+
type: bool
670+
required: false
671+
default: ${nr-default:enable_file_logging}
667672
deployment:
668673
linux:
674+
enable_file_logging: ${nr-var:enable_file_logging}
669675
executables:
670676
- id: first
671677
path: /opt/first
672678
args:
673679
- "${nr-var:registry}"
674680
windows:
681+
enable_file_logging: ${nr-var:enable_file_logging}
675682
executables:
676683
- id: first
677684
path: /opt/first
@@ -683,10 +690,16 @@ deployment:
683690
let values = testing_values("");
684691
let attributes = testing_agent_attributes(&agent_id);
685692

686-
let global_defaults = HashMap::from([(
687-
"registry_url".to_string(),
688-
serde_yaml::to_value("random-registry-url").unwrap(),
689-
)]);
693+
let global_defaults = HashMap::from([
694+
(
695+
"registry_url".to_string(),
696+
serde_yaml::to_value("test-registry-url").unwrap(),
697+
),
698+
(
699+
"enable_file_logging".to_string(),
700+
serde_yaml::to_value(true).unwrap(),
701+
),
702+
]);
690703

691704
let renderer =
692705
TemplateRenderer::default().with_agent_control_variables(HashMap::new().into_iter());
@@ -701,14 +714,15 @@ deployment:
701714
)
702715
.unwrap();
703716
assert_eq!(
704-
rendered::Args(vec!("random-registry-url".to_string())),
705-
extract_runtime_by_environment(runtime_config)
717+
rendered::Args(vec!("test-registry-url/test-repository".to_string())),
718+
extract_runtime_by_environment(runtime_config.clone())
706719
.executables
707720
.first()
708721
.unwrap()
709722
.args
710723
.clone()
711724
);
725+
assert!(extract_runtime_by_environment(runtime_config.clone()).enable_file_logging);
712726
}
713727

714728
#[test]
@@ -726,15 +740,22 @@ variables:
726740
description: "registry url"
727741
type: string
728742
required: false
729-
default: "${nr-default:registry_url}"
743+
default: "${nr-default:registry_url}/test-repository"
744+
enable_file_logging:
745+
description: "enable file logging"
746+
type: bool
747+
required: false
748+
default: ${nr-default:enable_file_logging}
730749
deployment:
731750
linux:
751+
enable_file_logging: ${nr-var:enable_file_logging}
732752
executables:
733753
- id: first
734754
path: /opt/first
735755
args:
736756
- "${nr-var:registry}"
737757
windows:
758+
enable_file_logging: ${nr-var:enable_file_logging}
738759
executables:
739760
- id: first
740761
path: /opt/first
@@ -746,15 +767,27 @@ deployment:
746767
let values = testing_values("");
747768
let attributes = testing_agent_attributes(&agent_id);
748769

749-
let global_defaults = HashMap::from([(
750-
"registry_url".to_string(),
751-
serde_yaml::to_value("${nr-env:REGISTRY_URL}").unwrap(),
752-
)]);
770+
let global_defaults = HashMap::from([
771+
(
772+
"registry_url".to_string(),
773+
serde_yaml::to_value("${nr-env:REGISTRY_URL}").unwrap(),
774+
),
775+
(
776+
"enable_file_logging".to_string(),
777+
serde_yaml::to_value("${nr-env:ENABLE_FILE_LOGGING}").unwrap(),
778+
),
779+
]);
753780

754-
let secrets = HashMap::from([(
755-
Namespace::EnvironmentVariable.namespaced_name("REGISTRY_URL"),
756-
Variable::new_final_string_variable("random-registry-url".to_string()),
757-
)]);
781+
let secrets = HashMap::from([
782+
(
783+
Namespace::EnvironmentVariable.namespaced_name("REGISTRY_URL"),
784+
Variable::new_final_string_variable("test-registry-url".to_string()),
785+
),
786+
(
787+
Namespace::EnvironmentVariable.namespaced_name("ENABLE_FILE_LOGGING"),
788+
Variable::from(serde_yaml::Value::Bool(true)),
789+
),
790+
]);
758791

759792
let renderer =
760793
TemplateRenderer::default().with_agent_control_variables(HashMap::new().into_iter());
@@ -769,14 +802,15 @@ deployment:
769802
)
770803
.unwrap();
771804
assert_eq!(
772-
rendered::Args(vec!("random-registry-url".to_string())),
773-
extract_runtime_by_environment(runtime_config)
805+
rendered::Args(vec!("test-registry-url/test-repository".to_string())),
806+
extract_runtime_by_environment(runtime_config.clone())
774807
.executables
775808
.first()
776809
.unwrap()
777810
.args
778811
.clone()
779812
);
813+
assert!(extract_runtime_by_environment(runtime_config.clone()).enable_file_logging);
780814
}
781815

782816
// Agent Type and Values definitions

agent-control/src/agent_type/variable.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ mod tests {
211211

212212
#[test]
213213
fn variable_definition_tree_deserialize() {
214+
use super::fields::DefaultValue;
214215
let value = r#"
215216
foo:
216217
bar:
@@ -235,7 +236,7 @@ foo:
235236
variable_type: VariableTypeDefinition::String(StringFieldsDefinition {
236237
inner: FieldsDefinition {
237238
required: false,
238-
default: Some("a".to_string()),
239+
default: Some(DefaultValue::Value("a".to_string())),
239240
},
240241
variants: VariantsConfig {
241242
ac_config_field: Some("foo.bar.var_name".to_string()),

agent-control/src/agent_type/variable/fields.rs

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::agent_type::{
88
error::AgentTypeError,
99
variable::{
1010
constraints::{VariableConstraints, VariantsConstraints},
11+
namespace::Namespace,
1112
variants::{Variants, VariantsConfig},
1213
},
1314
};
@@ -19,7 +20,17 @@ where
1920
T: PartialEq,
2021
{
2122
pub(crate) required: bool,
22-
pub(crate) default: Option<T>,
23+
pub(crate) default: Option<DefaultValue<T>>,
24+
}
25+
26+
#[derive(Debug, PartialEq, Clone, Serialize)]
27+
#[serde(untagged)]
28+
pub enum DefaultValue<T>
29+
where
30+
T: PartialEq,
31+
{
32+
Value(T),
33+
Template(String),
2334
}
2435

2536
/// Type to support special default deserialization for 'null' Yaml value in 'default'.
@@ -47,7 +58,8 @@ where
4758
T: PartialEq,
4859
{
4960
pub(crate) required: bool,
50-
pub(crate) default: Option<T>,
61+
#[serde(flatten)]
62+
pub(crate) default: Option<DefaultValue<T>>,
5163
pub(crate) final_value: Option<T>, // TODO: move this outside the struct and avoid mutating the variables
5264
}
5365

@@ -133,7 +145,7 @@ impl StringFields {
133145

134146
impl<'de, T> Deserialize<'de> for FieldsDefinition<T>
135147
where
136-
T: Deserialize<'de> + PartialEq,
148+
T: serde::de::DeserializeOwned + PartialEq,
137149
{
138150
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139151
where
@@ -142,8 +154,8 @@ where
142154
use serde::de::Error;
143155
// intermediate serialization type to validate `default` and `required` fields
144156
#[derive(Debug, Deserialize)]
145-
struct IntermediateValueKind<T: PartialEq> {
146-
default: Option<T>,
157+
struct IntermediateValueKind {
158+
default: Option<serde_yaml::Value>,
147159
required: bool,
148160
}
149161

@@ -162,12 +174,60 @@ where
162174
}
163175

164176
Ok(FieldsDefinition {
165-
default: intermediate_spec.default,
166177
required: intermediate_spec.required,
178+
default: intermediate_spec
179+
.default
180+
.map(serde_yaml::from_value::<DefaultValue<T>>)
181+
.transpose()
182+
.map_err(|_| {
183+
D::Error::custom(AgentTypeError::Parse(
184+
"default value is not of the correct type".to_string(),
185+
))
186+
})?,
167187
})
168188
}
169189
}
170190

191+
impl<T> DefaultValue<T>
192+
where
193+
T: PartialEq,
194+
{
195+
pub fn as_value(&self) -> Option<&T> {
196+
match self {
197+
DefaultValue::Value(v) => Some(v),
198+
DefaultValue::Template(_) => None,
199+
}
200+
}
201+
}
202+
203+
impl<'de, T> Deserialize<'de> for DefaultValue<T>
204+
where
205+
T: serde::de::DeserializeOwned + PartialEq,
206+
{
207+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
208+
where
209+
D: Deserializer<'de>,
210+
{
211+
let yaml_val = serde_yaml::Value::deserialize(deserializer)?;
212+
213+
match serde_yaml::from_value::<T>(yaml_val.clone()) {
214+
Ok(value) => Ok(DefaultValue::Value(value)),
215+
Err(_) => {
216+
if let serde_yaml::Value::String(s) = yaml_val
217+
&& s.starts_with(format!("${{{}", Namespace::Default).as_str())
218+
&& s.ends_with("}")
219+
{
220+
Ok(DefaultValue::Template(s))
221+
} else {
222+
Err(serde::de::Error::custom(
223+
"default value must be of the correct type or an `nr-default` template string",
224+
))
225+
}
226+
}
227+
}
228+
}
229+
}
230+
171231
// An special deserializer is used in order to consider the absence of default as a 'null' Yaml default value.
172232
impl<'de> Deserialize<'de> for YamlFieldsDefinition {
173233
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@@ -198,7 +258,7 @@ impl<'de> Deserialize<'de> for YamlFieldsDefinition {
198258
Ok(YamlFieldsDefinition {
199259
inner: FieldsDefinition {
200260
required: intermediate_spec.required,
201-
default: intermediate_spec.default,
261+
default: intermediate_spec.default.map(DefaultValue::Value),
202262
},
203263
})
204264
}
@@ -221,7 +281,7 @@ mod tests {
221281
},
222282
};
223283

224-
use super::Fields;
284+
use super::{DefaultValue, Fields};
225285

226286
impl<T> Fields<T>
227287
where
@@ -230,7 +290,7 @@ mod tests {
230290
pub(crate) fn new(required: bool, default: Option<T>, final_value: Option<T>) -> Self {
231291
Self {
232292
required,
233-
default,
293+
default: default.map(DefaultValue::Value),
234294
final_value,
235295
}
236296
}
@@ -246,7 +306,7 @@ mod tests {
246306
Self {
247307
inner: Fields::<String> {
248308
required,
249-
default,
309+
default: default.map(DefaultValue::Value),
250310
final_value,
251311
},
252312
variants,
@@ -257,7 +317,10 @@ mod tests {
257317
impl YamlFieldsDefinition {
258318
pub(crate) fn new(required: bool, default: Option<serde_yaml::Value>) -> Self {
259319
Self {
260-
inner: FieldsDefinition { required, default },
320+
inner: FieldsDefinition {
321+
required,
322+
default: default.map(DefaultValue::Value),
323+
},
261324
}
262325
}
263326
}

0 commit comments

Comments
 (0)