Skip to content

Commit 0573ff5

Browse files
authored
impl(auth): make mds format an enum (#3792)
Towards #3449
1 parent 80b732a commit 0573ff5

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

src/auth/integration-tests/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ pub mod unstable {
531531
use super::*;
532532
use auth::credentials::idtoken::{
533533
Builder as IDTokenCredentialBuilder, impersonated::Builder as ImpersonatedIDTokenBuilder,
534-
mds::Builder as IDTokenMDSBuilder,
534+
mds::Builder as IDTokenMDSBuilder, mds::Format,
535535
service_account::Builder as ServiceAccountIDTokenBuilder,
536536
verifier::Builder as VerifierBuilder,
537537
};
@@ -554,7 +554,7 @@ pub mod unstable {
554554

555555
// Only works when running on an env that has MDS.
556556
let id_token_creds = IDTokenMDSBuilder::new(audience)
557-
.with_format("full")
557+
.with_format(Format::Full)
558558
.build()
559559
.expect("failed to create id token credentials");
560560

src/auth/src/credentials/idtoken.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ fn build_id_token_credentials(
225225
match json {
226226
None => {
227227
// TODO(#3587): pass context that is being built from ADC flow.
228-
mds::Builder::new(audience).with_format("full").build()
228+
mds::Builder::new(audience)
229+
.with_format(mds::Format::Full)
230+
.build()
229231
}
230232
Some(json) => {
231233
let cred_type = extract_credential_type(&json)?;

src/auth/src/credentials/idtoken/mds.rs

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,32 @@ where
106106
}
107107
}
108108

109+
/// Specifies what assertions are included in ID Tokens fetched from the Metadata Service.
110+
#[derive(Debug, Clone)]
111+
pub enum Format {
112+
/// Omit project and instance details from the payload. It's the default value.
113+
Standard,
114+
/// Include project and instance details in the payload.
115+
Full,
116+
/// Use this variant to handle new values that are not yet known to this library.
117+
UnknownValue(String),
118+
}
119+
120+
impl Format {
121+
fn as_str(&self) -> &str {
122+
match self {
123+
Format::Standard => "standard",
124+
Format::Full => "full",
125+
Format::UnknownValue(value) => value.as_str(),
126+
}
127+
}
128+
}
129+
109130
/// Creates [`IDTokenCredentials`] instances that fetch ID tokens from the
110131
/// metadata service.
111132
pub struct Builder {
112133
endpoint: Option<String>,
113-
format: Option<String>,
134+
format: Option<Format>,
114135
licenses: Option<String>,
115136
target_audience: String,
116137
}
@@ -146,8 +167,8 @@ impl Builder {
146167
/// from the payload. The default value is `standard`.
147168
///
148169
/// [format]: https://cloud.google.com/compute/docs/instances/verifying-instance-identity#token_format
149-
pub fn with_format<S: Into<String>>(mut self, format: S) -> Self {
150-
self.format = Some(format.into());
170+
pub fn with_format(mut self, format: Format) -> Self {
171+
self.format = Some(format);
151172
self
152173
}
153174

@@ -204,7 +225,7 @@ impl Builder {
204225
#[derive(Debug, Clone, Default)]
205226
struct MDSTokenProvider {
206227
endpoint: String,
207-
format: Option<String>,
228+
format: Option<Format>,
208229
licenses: Option<String>,
209230
target_audience: String,
210231
}
@@ -221,8 +242,9 @@ impl TokenProvider for MDSTokenProvider {
221242
HeaderValue::from_static(METADATA_FLAVOR_VALUE),
222243
)
223244
.query(&[("audience", audience)]);
245+
224246
let request = self.format.iter().fold(request, |builder, format| {
225-
builder.query(&[("format", format)])
247+
builder.query(&[("format", format.as_str())])
226248
});
227249
let request = self.licenses.iter().fold(request, |builder, licenses| {
228250
builder.query(&[("licenses", licenses)])
@@ -258,21 +280,25 @@ mod tests {
258280
use reqwest::StatusCode;
259281
use scoped_env::ScopedEnv;
260282
use serial_test::{parallel, serial};
283+
use test_case::test_case;
261284

262285
type TestResult = anyhow::Result<()>;
263286

264287
#[tokio::test]
288+
#[test_case(Format::Standard)]
289+
#[test_case(Format::Full)]
290+
#[test_case(Format::UnknownValue("minimal".to_string()))]
265291
#[parallel]
266-
async fn test_idtoken_builder_build() -> TestResult {
292+
async fn test_idtoken_builder_build(format: Format) -> TestResult {
267293
let server = Server::run();
268294
let audience = "test-audience";
269-
let format = "format";
270295
let token_string = generate_test_id_token(audience);
296+
let format_str = format.as_str().to_string();
271297
server.expect(
272298
Expectation::matching(all_of![
273299
request::path(format!("{MDS_DEFAULT_URI}/identity")),
274300
request::query(url_decoded(contains(("audience", audience)))),
275-
request::query(url_decoded(contains(("format", format)))),
301+
request::query(url_decoded(contains(("format", format_str)))),
276302
request::query(url_decoded(contains(("licenses", "TRUE"))))
277303
])
278304
.respond_with(status_code(200).body(token_string.clone())),

0 commit comments

Comments
 (0)