Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,16 +259,25 @@ impl Config {
);
}

// Try Backend.AI environment first
// Check for Backend.AI environment first
if let Some(backendai_cluster) = Self::from_backendai_env() {
tracing::debug!("Using Backend.AI cluster configuration from environment");
let mut config = Self::default();
config
.clusters
.insert("backendai".to_string(), backendai_cluster);
.insert("bai_auto".to_string(), backendai_cluster);
return Ok(config);
}

// Load configuration from standard locations
Self::load_from_standard_locations().await.or_else(|_| {
tracing::debug!("No config file found, using default empty configuration");
Ok(Self::default())
})
}

/// Load configuration from standard locations (helper method)
async fn load_from_standard_locations() -> Result<Self> {
// Try current directory config.yaml
let current_dir_config = PathBuf::from("config.yaml");
if current_dir_config.exists() {
Expand Down Expand Up @@ -308,9 +317,8 @@ impl Config {
}
}

// Finally, try the default path (will create empty config if needed)
tracing::debug!("No config file found, using default empty configuration");
Ok(Self::default())
// No config file found
anyhow::bail!("No configuration file found")
}

pub fn get_cluster(&self, name: &str) -> Option<&Cluster> {
Expand Down Expand Up @@ -585,7 +593,7 @@ pub struct InteractiveConfigUpdate {
pub colors: Option<HashMap<String, String>>,
}

fn expand_tilde(path: &Path) -> PathBuf {
pub fn expand_tilde(path: &Path) -> PathBuf {
if let Some(path_str) = path.to_str() {
if path_str.starts_with("~/") {
if let Ok(home) = std::env::var("HOME") {
Expand Down
65 changes: 52 additions & 13 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use anyhow::Result;
use clap::Parser;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::time::Duration;

use bssh::{
Expand Down Expand Up @@ -98,7 +98,7 @@ async fn main() -> Result<()> {
}

// Determine nodes to execute on
let nodes = resolve_nodes(&cli, &config).await?;
let (nodes, actual_cluster_name) = resolve_nodes(&cli, &config).await?;

if nodes.is_empty() {
anyhow::bail!(
Expand All @@ -124,10 +124,19 @@ async fn main() -> Result<()> {
// Handle remaining commands
match cli.command {
Some(Commands::Ping) => {
// Determine SSH key path: CLI argument takes precedence over config
let key_path = if let Some(identity) = &cli.identity {
Some(identity.clone())
} else {
config
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
};

ping_nodes(
nodes,
cli.parallel,
cli.identity.as_deref(),
key_path.as_deref(),
strict_mode,
cli.use_agent,
cli.password,
Expand All @@ -139,10 +148,19 @@ async fn main() -> Result<()> {
destination,
recursive,
}) => {
// Determine SSH key path: CLI argument takes precedence over config
let key_path = if let Some(identity) = &cli.identity {
Some(identity.clone())
} else {
config
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
};

let params = FileTransferParams {
nodes,
max_parallel: cli.parallel,
key_path: cli.identity.as_deref(),
key_path: key_path.as_deref(),
strict_mode,
use_agent: cli.use_agent,
use_password: cli.password,
Expand All @@ -155,10 +173,19 @@ async fn main() -> Result<()> {
destination,
recursive,
}) => {
// Determine SSH key path: CLI argument takes precedence over config
let key_path = if let Some(identity) = &cli.identity {
Some(identity.clone())
} else {
config
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
};

let params = FileTransferParams {
nodes,
max_parallel: cli.parallel,
key_path: cli.identity.as_deref(),
key_path: key_path.as_deref(),
strict_mode,
use_agent: cli.use_agent,
use_password: cli.password,
Expand Down Expand Up @@ -238,14 +265,23 @@ async fn main() -> Result<()> {
let timeout = if cli.timeout > 0 {
Some(cli.timeout)
} else {
config.get_timeout(cli.cluster.as_deref())
config.get_timeout(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
};

// Determine SSH key path: CLI argument takes precedence over config
let key_path = if let Some(identity) = &cli.identity {
Some(identity.clone())
} else {
config
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
};

let params = ExecuteCommandParams {
nodes,
command: &command,
max_parallel: cli.parallel,
key_path: cli.identity.as_deref(),
key_path: key_path.as_deref(),
verbose: cli.verbose > 0,
strict_mode,
use_agent: cli.use_agent,
Expand All @@ -258,8 +294,9 @@ async fn main() -> Result<()> {
}
}

async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<Vec<Node>> {
async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<(Vec<Node>, Option<String>)> {
let mut nodes = Vec::new();
let mut cluster_name = None;

if let Some(hosts) = &cli.hosts {
// Parse hosts from CLI
Expand All @@ -270,16 +307,18 @@ async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<Vec<Node>> {
nodes.push(node);
}
}
} else if let Some(cluster_name) = &cli.cluster {
} else if let Some(cli_cluster_name) = &cli.cluster {
// Get nodes from cluster configuration
nodes = config.resolve_nodes(cluster_name)?;
nodes = config.resolve_nodes(cli_cluster_name)?;
cluster_name = Some(cli_cluster_name.clone());
} else {
// Check if Backend.AI environment is detected (automatic cluster)
if config.clusters.contains_key("backendai") {
if config.clusters.contains_key("bai_auto") {
// Automatically use Backend.AI cluster when no explicit cluster is specified
nodes = config.resolve_nodes("backendai")?;
nodes = config.resolve_nodes("bai_auto")?;
cluster_name = Some("bai_auto".to_string());
}
}

Ok(nodes)
Ok((nodes, cluster_name))
}
20 changes: 10 additions & 10 deletions tests/backendai_env_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@ async fn test_backendai_env_auto_detection() {
.await
.expect("Config should load with Backend.AI env");

// Check that backendai cluster was created
assert!(config.clusters.contains_key("backendai"));
// Check that bai_auto cluster was created
assert!(config.clusters.contains_key("bai_auto"));

// Get the backendai cluster
let cluster = config.clusters.get("backendai").unwrap();
// Get the bai_auto cluster
let cluster = config.clusters.get("bai_auto").unwrap();

// Verify nodes were parsed correctly
assert_eq!(cluster.nodes.len(), 3);

// Resolve nodes for the backendai cluster
// Resolve nodes for the bai_auto cluster
let nodes = config
.resolve_nodes("backendai")
.expect("Should resolve backendai nodes");
.resolve_nodes("bai_auto")
.expect("Should resolve bai_auto nodes");
assert_eq!(nodes.len(), 3);

// Check node details
Expand Down Expand Up @@ -114,11 +114,11 @@ async fn test_backendai_env_with_single_host() {
.await
.expect("Config should load");

// Verify backendai cluster exists
assert!(config.clusters.contains_key("backendai"));
// Verify bai_auto cluster exists
assert!(config.clusters.contains_key("bai_auto"));

let nodes = config
.resolve_nodes("backendai")
.resolve_nodes("bai_auto")
.expect("Should resolve nodes");
assert_eq!(nodes.len(), 1);
assert_eq!(nodes[0].host, "single-node.ai");
Expand Down