Skip to content

Commit ab66759

Browse files
committed
add wizard to cscs login command
1 parent a41d583 commit ab66759

4 files changed

Lines changed: 91 additions & 25 deletions

File tree

coman/src/app/model.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ use crate::{
2424
login_popup::LoginPopup, resource_usage::ResourceUsage, system_select_popup::SystemSelectPopup,
2525
workload_details::WorkloadDetails, workload_list::WorkloadList, workload_log::WorkloadLog,
2626
},
27+
config::Config,
2728
cscs::{
28-
handlers::{cscs_login, cscs_system_set},
29+
handlers::{cscs_login, cscs_system_set, get_available_compute_platforms},
2930
ports::{BackgroundTask, JobLogAction, JobResourceUsageAction},
3031
},
3132
trace_dbg,
@@ -509,7 +510,42 @@ where
509510
let error_tx = self.error_tx.clone();
510511
tokio::spawn(async move {
511512
match cscs_login(client_id, client_secret).await {
512-
Ok(_) => event_tx.send(UserEvent::Cscs(CscsEvent::LoggedIn)).await.unwrap(),
513+
Ok(_) => {
514+
let mut config = match Config::new() {
515+
Ok(config) => config,
516+
Err(e) => {
517+
error_tx
518+
.send(format!(
519+
"{:?}",
520+
Err::<(), Report>(e).wrap_err("Couldn't create config object")
521+
))
522+
.await
523+
.unwrap();
524+
return;
525+
}
526+
};
527+
let source = config.value_source("cscs.current_platform");
528+
if !source.1 && !source.2 {
529+
// don't override platform if it's already set
530+
if let Ok(available_platforms) = get_available_compute_platforms().await
531+
&& !available_platforms.is_empty()
532+
&& let Err(e) = config.set(
533+
"cscs.current_platform",
534+
available_platforms[0].to_string(),
535+
true,
536+
) {
537+
error_tx
538+
.send(format!(
539+
"{:?}",
540+
Err::<(), Report>(e).wrap_err("Couldn't set currnt platform")
541+
))
542+
.await
543+
.unwrap();
544+
return;
545+
}
546+
}
547+
event_tx.send(UserEvent::Cscs(CscsEvent::LoggedIn)).await.unwrap()
548+
}
513549
Err(e) => error_tx
514550
.send(format!(
515551
"{:?}",

coman/src/config.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use color_eyre::{
99
use directories::ProjectDirs;
1010
use lazy_static::lazy_static;
1111
use serde::{Deserialize, Serialize};
12-
use strum_macros::{EnumIter, EnumString, VariantNames};
12+
use strum_macros::{EnumIter, EnumString, VariantArray, VariantNames};
1313
use toml_edit::DocumentMut;
1414

1515
const DEFAULT_CONFIG_TOML: &str = include_str!("../.config/config.toml");
@@ -33,7 +33,9 @@ pub struct SystemDescription {
3333
pub architecture: Vec<String>,
3434
}
3535

36-
#[derive(Clone, Debug, Serialize, Deserialize, Default, strum::Display, EnumString, VariantNames, EnumIter)]
36+
#[derive(
37+
Clone, Debug, Serialize, Deserialize, Default, strum::Display, EnumString, VariantNames, VariantArray, EnumIter,
38+
)]
3739
#[strum(serialize_all = "lowercase")]
3840
#[allow(clippy::upper_case_acronyms)]
3941
pub enum ComputePlatform {
@@ -333,15 +335,15 @@ impl Config {
333335
}
334336

335337
// Returns tuple of bool saying whether a values is set in (default, global, project local) config
336-
pub fn value_source(&self, key_path: &str) -> Result<(bool, bool, bool)> {
337-
Ok((
338+
pub fn value_source(&self, key_path: &str) -> (bool, bool, bool) {
339+
(
338340
self.default_layer.get(key_path).is_some(),
339341
self.global_layer.get(key_path).unwrap_or_default().is_some(),
340342
self.project_layer
341343
.as_ref()
342344
.map(|l| l.get(key_path).unwrap_or_default().is_some())
343345
.unwrap_or(false),
344-
))
346+
)
345347
}
346348
}
347349

coman/src/cscs/cli.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,24 @@ use bytesize::ByteSize;
88
use color_eyre::{Result, eyre::Context};
99
use eyre::eyre;
1010
use futures::StreamExt;
11-
use inquire::{Password, Text};
11+
use inquire::{Password, Select, Text};
1212
use itertools::Itertools;
1313
use reqwest::Url;
14+
use strum::VariantArray;
1415
use tokio::{
1516
fs::File,
1617
io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader},
1718
};
1819

1920
use crate::{
2021
cli::app::JobIdOrName,
21-
config::ComputePlatform,
22+
config::{ComputePlatform, Config},
2223
cscs::{
2324
api_client::{client::JobStartOptions, types::JobStatus},
2425
handlers::{
2526
cscs_file_delete, cscs_file_download, cscs_file_list, cscs_file_upload, cscs_job_cancel, cscs_job_details,
2627
cscs_job_list, cscs_job_log, cscs_job_start, cscs_login, cscs_port_forward, cscs_resource_usage,
27-
cscs_system_list, cscs_system_set,
28+
cscs_system_list, cscs_system_set, get_available_compute_platforms,
2829
},
2930
},
3031
};
@@ -39,6 +40,31 @@ pub(crate) async fn cli_cscs_login() -> Result<()> {
3940
}
4041
Err(e) => Err(e).wrap_err("couldn't get acccess token")?,
4142
};
43+
44+
// select compute platform
45+
let mut config = Config::new()?;
46+
47+
let source = config.value_source("cscs.current_platform");
48+
if !source.1 && !source.2 {
49+
let available_platforms: Vec<_> = get_available_compute_platforms()
50+
.await
51+
.unwrap_or(<ComputePlatform as VariantArray>::VARIANTS.to_vec())
52+
.iter()
53+
.map(|c| c.to_string())
54+
.collect();
55+
let platform = Select::new("Compute Platform:", available_platforms).prompt()?;
56+
57+
config.set("cscs.current_platform", platform, true)?;
58+
}
59+
60+
// select cscs account
61+
let source = config.value_source("cscs.account");
62+
if !source.1 && !source.2
63+
&& let Ok(Some(account)) = Text::new("CSCS Account:").prompt_skippable()
64+
&& !account.is_empty()
65+
{
66+
config.set("cscs.account", account, true)?;
67+
}
4268
Ok(())
4369
}
4470
pub(crate) async fn cli_cscs_job_list(system: Option<String>, platform: Option<ComputePlatform>) -> Result<()> {

coman/src/cscs/handlers.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,28 +80,30 @@ async fn get_access_token() -> Result<Secret> {
8080
let token = client_credentials_login(client_id, client_secret).await?;
8181
Ok(token.0)
8282
}
83+
8384
pub(crate) async fn cscs_login(client_id: String, client_secret: String) -> Result<()> {
8485
let client_id_secret = Secret::new(client_id);
8586
store_secret(CLIENT_ID_SECRET_NAME, client_id_secret.clone()).await?;
8687
let client_secret_secret = Secret::new(client_secret);
8788
store_secret(CLIENT_SECRET_SECRET_NAME, client_secret_secret.clone()).await?;
88-
let token = client_credentials_login(client_id_secret, client_secret_secret).await?;
89-
90-
// figure out what platform the user has access to and set it in config
91-
let mut config = Config::new()?;
92-
let source = config.value_source("cscs.current_platform")?;
93-
if source.1 || source.2 {
94-
// don't override setting if it's already provided
95-
return Ok(());
96-
}
97-
for platform in ComputePlatform::iter() {
98-
let api_client = CscsApi::new(token.0.0.clone(), Some(platform.clone())).unwrap();
99-
if (api_client.list_systems().await).is_ok() {
100-
config.set("cscs.current_platform", platform.to_string(), true)?;
101-
break;
89+
client_credentials_login(client_id_secret, client_secret_secret)
90+
.await
91+
.map(|_| ())
92+
}
93+
pub(crate) async fn get_available_compute_platforms() -> Result<Vec<ComputePlatform>> {
94+
match get_access_token().await {
95+
Ok(access_token) => {
96+
let mut platforms = Vec::new();
97+
for platform in ComputePlatform::iter() {
98+
let api_client = CscsApi::new(access_token.0.clone(), Some(platform.clone())).unwrap();
99+
if (api_client.list_systems().await).is_ok() {
100+
platforms.push(platform);
101+
}
102+
}
103+
Ok(platforms)
102104
}
105+
Err(e) => Err(e),
103106
}
104-
Ok(())
105107
}
106108

107109
#[allow(dead_code)]

0 commit comments

Comments
 (0)