Skip to content

Commit ec359a9

Browse files
committed
fix: use cluster SSH key config and simplify BACKENDAI environment handling
1 parent 4431675 commit ec359a9

2 files changed

Lines changed: 66 additions & 19 deletions

File tree

src/config.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,16 +259,25 @@ impl Config {
259259
);
260260
}
261261

262-
// Try Backend.AI environment first
262+
// Check for Backend.AI environment first
263263
if let Some(backendai_cluster) = Self::from_backendai_env() {
264264
tracing::debug!("Using Backend.AI cluster configuration from environment");
265265
let mut config = Self::default();
266266
config
267267
.clusters
268-
.insert("backendai".to_string(), backendai_cluster);
268+
.insert("bai_auto".to_string(), backendai_cluster);
269269
return Ok(config);
270270
}
271271

272+
// Load configuration from standard locations
273+
Self::load_from_standard_locations().await.or_else(|_| {
274+
tracing::debug!("No config file found, using default empty configuration");
275+
Ok(Self::default())
276+
})
277+
}
278+
279+
/// Load configuration from standard locations (helper method)
280+
async fn load_from_standard_locations() -> Result<Self> {
272281
// Try current directory config.yaml
273282
let current_dir_config = PathBuf::from("config.yaml");
274283
if current_dir_config.exists() {
@@ -308,9 +317,8 @@ impl Config {
308317
}
309318
}
310319

311-
// Finally, try the default path (will create empty config if needed)
312-
tracing::debug!("No config file found, using default empty configuration");
313-
Ok(Self::default())
320+
// No config file found
321+
anyhow::bail!("No configuration file found")
314322
}
315323

316324
pub fn get_cluster(&self, name: &str) -> Option<&Cluster> {
@@ -585,7 +593,7 @@ pub struct InteractiveConfigUpdate {
585593
pub colors: Option<HashMap<String, String>>,
586594
}
587595

588-
fn expand_tilde(path: &Path) -> PathBuf {
596+
pub fn expand_tilde(path: &Path) -> PathBuf {
589597
if let Some(path_str) = path.to_str() {
590598
if path_str.starts_with("~/") {
591599
if let Ok(home) = std::env::var("HOME") {

src/main.rs

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
use anyhow::Result;
1616
use clap::Parser;
17-
use std::path::PathBuf;
17+
use std::path::{Path, PathBuf};
1818
use std::time::Duration;
1919

2020
use bssh::{
@@ -98,7 +98,7 @@ async fn main() -> Result<()> {
9898
}
9999

100100
// Determine nodes to execute on
101-
let nodes = resolve_nodes(&cli, &config).await?;
101+
let (nodes, actual_cluster_name) = resolve_nodes(&cli, &config).await?;
102102

103103
if nodes.is_empty() {
104104
anyhow::bail!(
@@ -124,10 +124,19 @@ async fn main() -> Result<()> {
124124
// Handle remaining commands
125125
match cli.command {
126126
Some(Commands::Ping) => {
127+
// Determine SSH key path: CLI argument takes precedence over config
128+
let key_path = if let Some(identity) = &cli.identity {
129+
Some(identity.clone())
130+
} else {
131+
config
132+
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
133+
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
134+
};
135+
127136
ping_nodes(
128137
nodes,
129138
cli.parallel,
130-
cli.identity.as_deref(),
139+
key_path.as_deref(),
131140
strict_mode,
132141
cli.use_agent,
133142
cli.password,
@@ -139,10 +148,19 @@ async fn main() -> Result<()> {
139148
destination,
140149
recursive,
141150
}) => {
151+
// Determine SSH key path: CLI argument takes precedence over config
152+
let key_path = if let Some(identity) = &cli.identity {
153+
Some(identity.clone())
154+
} else {
155+
config
156+
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
157+
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
158+
};
159+
142160
let params = FileTransferParams {
143161
nodes,
144162
max_parallel: cli.parallel,
145-
key_path: cli.identity.as_deref(),
163+
key_path: key_path.as_deref(),
146164
strict_mode,
147165
use_agent: cli.use_agent,
148166
use_password: cli.password,
@@ -155,10 +173,19 @@ async fn main() -> Result<()> {
155173
destination,
156174
recursive,
157175
}) => {
176+
// Determine SSH key path: CLI argument takes precedence over config
177+
let key_path = if let Some(identity) = &cli.identity {
178+
Some(identity.clone())
179+
} else {
180+
config
181+
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
182+
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
183+
};
184+
158185
let params = FileTransferParams {
159186
nodes,
160187
max_parallel: cli.parallel,
161-
key_path: cli.identity.as_deref(),
188+
key_path: key_path.as_deref(),
162189
strict_mode,
163190
use_agent: cli.use_agent,
164191
use_password: cli.password,
@@ -238,14 +265,23 @@ async fn main() -> Result<()> {
238265
let timeout = if cli.timeout > 0 {
239266
Some(cli.timeout)
240267
} else {
241-
config.get_timeout(cli.cluster.as_deref())
268+
config.get_timeout(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
269+
};
270+
271+
// Determine SSH key path: CLI argument takes precedence over config
272+
let key_path = if let Some(identity) = &cli.identity {
273+
Some(identity.clone())
274+
} else {
275+
config
276+
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
277+
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
242278
};
243279

244280
let params = ExecuteCommandParams {
245281
nodes,
246282
command: &command,
247283
max_parallel: cli.parallel,
248-
key_path: cli.identity.as_deref(),
284+
key_path: key_path.as_deref(),
249285
verbose: cli.verbose > 0,
250286
strict_mode,
251287
use_agent: cli.use_agent,
@@ -258,8 +294,9 @@ async fn main() -> Result<()> {
258294
}
259295
}
260296

261-
async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<Vec<Node>> {
297+
async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<(Vec<Node>, Option<String>)> {
262298
let mut nodes = Vec::new();
299+
let mut cluster_name = None;
263300

264301
if let Some(hosts) = &cli.hosts {
265302
// Parse hosts from CLI
@@ -270,16 +307,18 @@ async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<Vec<Node>> {
270307
nodes.push(node);
271308
}
272309
}
273-
} else if let Some(cluster_name) = &cli.cluster {
310+
} else if let Some(cli_cluster_name) = &cli.cluster {
274311
// Get nodes from cluster configuration
275-
nodes = config.resolve_nodes(cluster_name)?;
312+
nodes = config.resolve_nodes(cli_cluster_name)?;
313+
cluster_name = Some(cli_cluster_name.clone());
276314
} else {
277315
// Check if Backend.AI environment is detected (automatic cluster)
278-
if config.clusters.contains_key("backendai") {
316+
if config.clusters.contains_key("bai_auto") {
279317
// Automatically use Backend.AI cluster when no explicit cluster is specified
280-
nodes = config.resolve_nodes("backendai")?;
318+
nodes = config.resolve_nodes("bai_auto")?;
319+
cluster_name = Some("bai_auto".to_string());
281320
}
282321
}
283322

284-
Ok(nodes)
323+
Ok((nodes, cluster_name))
285324
}

0 commit comments

Comments
 (0)