Skip to content

Commit 29c36c2

Browse files
committed
set platform on login based on user permissions
1 parent 6786dce commit 29c36c2

2 files changed

Lines changed: 32 additions & 5 deletions

File tree

coman/src/config.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use color_eyre::{
99
use directories::ProjectDirs;
1010
use lazy_static::lazy_static;
1111
use serde::{Deserialize, Serialize};
12-
use strum_macros::{EnumString, VariantNames};
12+
use strum_macros::{EnumIter, EnumString, VariantNames};
1313
use toml_edit::DocumentMut;
1414

1515
const DEFAULT_CONFIG_TOML: &str = include_str!("../.config/config.toml");
@@ -33,7 +33,7 @@ pub struct SystemDescription {
3333
pub architecture: Vec<String>,
3434
}
3535

36-
#[derive(Clone, Debug, Serialize, Deserialize, Default, strum::Display, EnumString, VariantNames)]
36+
#[derive(Clone, Debug, Serialize, Deserialize, Default, strum::Display, EnumString, VariantNames, EnumIter)]
3737
#[strum(serialize_all = "lowercase")]
3838
#[allow(clippy::upper_case_acronyms)]
3939
pub enum ComputePlatform {
@@ -331,6 +331,18 @@ impl Config {
331331
let _cfg: ComanConfig = builder.build()?.try_deserialize().wrap_err("invalid config")?;
332332
Ok(())
333333
}
334+
335+
// Returns tuple of bool saying whether a values is set in (default, global, project local) config
336+
pub fn value_source(&self, key_path: &str) -> Result<(bool, bool, bool)> {
337+
Ok((
338+
self.default_layer.get(key_path).is_some(),
339+
self.global_layer.get(key_path).unwrap_or_default().is_some(),
340+
self.project_layer
341+
.as_ref()
342+
.map(|l| l.get(key_path).unwrap_or_default().is_some())
343+
.unwrap_or(false),
344+
))
345+
}
334346
}
335347

336348
pub fn global_config_layer() -> Result<Layer> {

coman/src/cscs/handlers.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use itertools::Itertools;
1919
use regex::Regex;
2020
use reqwest::Url;
2121
use sha2::{Digest, Sha256};
22+
use strum::IntoEnumIterator;
2223
use tarpc::{client, context, serde_transport, tokio_serde::formats::Bincode};
2324
use tokio::{
2425
fs::File,
@@ -84,9 +85,23 @@ pub(crate) async fn cscs_login(client_id: String, client_secret: String) -> Resu
8485
store_secret(CLIENT_ID_SECRET_NAME, client_id_secret.clone()).await?;
8586
let client_secret_secret = Secret::new(client_secret);
8687
store_secret(CLIENT_SECRET_SECRET_NAME, client_secret_secret.clone()).await?;
87-
client_credentials_login(client_id_secret, client_secret_secret)
88-
.await
89-
.map(|_| ())
88+
let token = client_credentials_login(client_id_secret, client_secret_secret).await?;
89+
90+
// figure out what platform the user has access to and set it in config
91+
let mut config = Config::new()?;
92+
let source = config.value_source("cscs.current_platform")?;
93+
if source.1 || source.2 {
94+
// don't override setting if it's already provided
95+
return Ok(());
96+
}
97+
for platform in ComputePlatform::iter() {
98+
let api_client = CscsApi::new(token.0.0.clone(), Some(platform.clone())).unwrap();
99+
if (api_client.list_systems().await).is_ok() {
100+
config.set("cscs.current_platform", platform.to_string(), true)?;
101+
break;
102+
}
103+
}
104+
Ok(())
90105
}
91106

92107
#[allow(dead_code)]

0 commit comments

Comments
 (0)