Skip to content

Commit 47bde07

Browse files
authored
feat: add solve command and create platform option (#2528)
1 parent cac271f commit 47bde07

4 files changed

Lines changed: 284 additions & 19 deletions

File tree

crates/rattler-bin/src/commands/create.rs

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use rattler_solve::{
2424
libsolv_c::{self},
2525
resolvo,
2626
};
27+
use rattler_virtual_packages::{VirtualPackageOverrides, VirtualPackages};
2728

2829
use crate::{
2930
commands::progress::{wrap_in_async_progress, wrap_in_progress},
@@ -47,13 +48,13 @@ pub struct Opt {
4748
#[clap(required = true)]
4849
specs: Vec<String>,
4950

50-
/// Simulute command without installation
51+
/// Simulate command without installation
5152
#[clap(long)]
5253
dry_run: bool,
5354

54-
/// Target platform (e.g., linux-64, osx-arm64)
55-
#[clap(long)]
56-
platform: Option<String>,
55+
/// The platform to create the environment for.
56+
#[clap(long, default_value_t = Platform::current())]
57+
platform: Platform,
5758

5859
#[clap(long)]
5960
virtual_package: Option<Vec<String>>,
@@ -130,14 +131,9 @@ pub async fn create(opt: Opt) -> miette::Result<()> {
130131
// Make the target prefix absolute
131132
let target_prefix = std::path::absolute(opt.target_prefix).into_diagnostic()?;
132133

133-
// Determine the platform we're going to install for
134-
let install_platform = if let Some(platform) = opt.platform {
135-
Platform::from_str(&platform).into_diagnostic()?
136-
} else {
137-
Platform::current()
138-
};
134+
let install_platform = opt.platform;
139135

140-
println!("Installing for platform: {install_platform:?}");
136+
println!("Installing for platform: {install_platform}");
141137

142138
// Parse the specs from the command line. We do this explicitly instead of allow
143139
// clap to deal with this because we need to parse the `channel_config` when
@@ -241,15 +237,11 @@ pub async fn create(opt: Opt) -> miette::Result<()> {
241237
})
242238
.collect::<miette::Result<Vec<_>>>()?)
243239
} else {
244-
rattler_virtual_packages::VirtualPackage::detect(
245-
&rattler_virtual_packages::VirtualPackageOverrides::from_env(),
240+
VirtualPackages::detect_for_platform(
241+
install_platform,
242+
&VirtualPackageOverrides::from_env(),
246243
)
247-
.map(|vpkgs| {
248-
vpkgs
249-
.iter()
250-
.map(|vpkg| GenericVirtualPackage::from(vpkg.clone()))
251-
.collect::<Vec<_>>()
252-
})
244+
.map(|vpkgs| vpkgs.into_generic_virtual_packages().collect::<Vec<_>>())
253245
.into_diagnostic()
254246
}
255247
})?;

crates/rattler-bin/src/commands/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ pub mod progress;
1515
pub mod run;
1616
pub mod search;
1717
pub mod shell_hook;
18+
pub mod solve;
1819
pub mod virtual_packages;
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
use std::{
2+
collections::HashMap,
3+
env,
4+
str::FromStr,
5+
time::{Duration, Instant},
6+
};
7+
8+
use clap::ValueEnum;
9+
use itertools::Itertools;
10+
use miette::{Context, IntoDiagnostic};
11+
use rattler::{default_cache_dir, package_cache::PackageCache};
12+
use rattler_conda_types::{
13+
Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, Matches, PackageName,
14+
ParseMatchSpecOptions, Platform, RepoDataRecord, Version,
15+
};
16+
use rattler_repodata_gateway::{Gateway, RepoData, SourceConfig};
17+
use rattler_solve::{
18+
SolverImpl, SolverTask,
19+
libsolv_c::{self},
20+
resolvo,
21+
};
22+
use rattler_virtual_packages::{VirtualPackageOverrides, VirtualPackages};
23+
24+
use crate::{
25+
commands::progress::{wrap_in_async_progress, wrap_in_progress},
26+
exclude_newer::ExcludeNewer,
27+
};
28+
29+
/// Solve a conda environment without installing it.
30+
///
31+
/// Resolves the specified package specs for a target platform and prints the
32+
/// resulting package set.
33+
#[derive(Debug, clap::Parser)]
34+
pub struct Opt {
35+
/// Channel to search for packages.
36+
///
37+
/// Example: -c conda-forge -c main
38+
#[clap(short, long = "channel")]
39+
channels: Option<Vec<String>>,
40+
41+
/// Package specs to solve.
42+
#[clap(required = true)]
43+
specs: Vec<String>,
44+
45+
/// The platform to solve the environment for.
46+
#[clap(long, default_value_t = Platform::current())]
47+
platform: Platform,
48+
49+
/// Virtual packages to use for solving, e.g. __glibc=2.28.
50+
#[clap(long)]
51+
virtual_package: Option<Vec<String>>,
52+
53+
/// SAT Solver backend to use.
54+
#[clap(long)]
55+
solver: Option<Solver>,
56+
57+
/// Request solver timeout in milliseconds.
58+
#[clap(long)]
59+
timeout: Option<u64>,
60+
61+
/// Solver strategy to use.
62+
#[clap(long)]
63+
strategy: Option<SolveStrategy>,
64+
65+
/// Only include dependencies of package specs in the output.
66+
#[clap(long, group = "deps_mode")]
67+
only_deps: bool,
68+
69+
/// Only include package specifications without dependencies in the output.
70+
#[clap(long, group = "deps_mode")]
71+
no_deps: bool,
72+
73+
/// Exclude packages that have been published after the specified timestamp.
74+
/// Can be specified as a timestamp (e.g., "2006-12-02T02:07:43Z") or as a date (e.g., "2006-12-02").
75+
/// When using a date, packages from the entire day are included.
76+
#[clap(long)]
77+
exclude_newer: Option<ExcludeNewer>,
78+
}
79+
80+
#[derive(Debug, Clone, Copy, ValueEnum)]
81+
pub enum SolveStrategy {
82+
/// Resolve the highest compatible version for every package.
83+
Highest,
84+
85+
/// Resolve the lowest compatible version for every package.
86+
Lowest,
87+
88+
/// Resolve the lowest compatible version for direct dependencies but the
89+
/// highest compatible for transitive dependencies.
90+
LowestDirect,
91+
}
92+
93+
#[derive(Default, Debug, Clone, Copy, ValueEnum)]
94+
pub enum Solver {
95+
#[default]
96+
Resolvo,
97+
#[value(name = "libsolv")]
98+
LibSolv,
99+
}
100+
101+
impl From<SolveStrategy> for rattler_solve::SolveStrategy {
102+
fn from(value: SolveStrategy) -> Self {
103+
match value {
104+
SolveStrategy::Highest => rattler_solve::SolveStrategy::Highest,
105+
SolveStrategy::Lowest => rattler_solve::SolveStrategy::LowestVersion,
106+
SolveStrategy::LowestDirect => rattler_solve::SolveStrategy::LowestVersionDirect,
107+
}
108+
}
109+
}
110+
111+
pub async fn solve(opt: Opt) -> miette::Result<()> {
112+
let channel_config =
113+
ChannelConfig::default_with_root_dir(env::current_dir().into_diagnostic()?);
114+
115+
println!("Solving for platform: {}", opt.platform);
116+
117+
let match_spec_options = ParseMatchSpecOptions::strict()
118+
.with_extras(true)
119+
.with_conditionals(true)
120+
.with_flags(true);
121+
122+
let specs = opt
123+
.specs
124+
.iter()
125+
.map(|spec| MatchSpec::from_str(spec, match_spec_options))
126+
.collect::<Result<Vec<_>, _>>()
127+
.into_diagnostic()?;
128+
129+
let cache_dir = default_cache_dir()
130+
.map_err(|e| miette::miette!("could not determine default cache directory: {}", e))?;
131+
rattler_cache::ensure_cache_dir(&cache_dir)
132+
.map_err(|e| miette::miette!("could not create cache directory: {}", e))?;
133+
134+
let channels = opt
135+
.channels
136+
.unwrap_or_else(|| vec![String::from("conda-forge")])
137+
.into_iter()
138+
.map(|channel_str| Channel::from_str(channel_str, &channel_config))
139+
.collect::<Result<Vec<_>, _>>()
140+
.into_diagnostic()?;
141+
142+
let download_client = super::client::create_client_with_middleware()?;
143+
144+
let gateway = Gateway::builder()
145+
.with_cache_dir(cache_dir.join(rattler_cache::REPODATA_CACHE_DIR))
146+
.with_package_cache(PackageCache::new(
147+
cache_dir.join(rattler_cache::PACKAGE_CACHE_DIR),
148+
))
149+
.with_client(download_client)
150+
.with_channel_config(rattler_repodata_gateway::ChannelConfig {
151+
default: SourceConfig {
152+
sharded_enabled: true,
153+
..SourceConfig::default()
154+
},
155+
per_channel: HashMap::new(),
156+
})
157+
.finish();
158+
159+
let start_load_repo_data = Instant::now();
160+
let repo_data = wrap_in_async_progress(
161+
"loading repodata",
162+
gateway
163+
.query(channels, [opt.platform, Platform::NoArch], specs.clone())
164+
.recursive(true),
165+
)
166+
.await
167+
.into_diagnostic()
168+
.context("failed to load repodata")?;
169+
170+
let total_records: usize = repo_data.iter().map(RepoData::len).sum();
171+
println!(
172+
"Loaded {} records in {:?}",
173+
total_records,
174+
start_load_repo_data.elapsed()
175+
);
176+
177+
let virtual_packages = wrap_in_progress("determining virtual packages", || {
178+
if let Some(virtual_packages) = &opt.virtual_package {
179+
parse_virtual_packages(virtual_packages)
180+
} else {
181+
VirtualPackages::detect_for_platform(opt.platform, &VirtualPackageOverrides::from_env())
182+
.map(|vpkgs| vpkgs.into_generic_virtual_packages().collect::<Vec<_>>())
183+
.into_diagnostic()
184+
}
185+
})?;
186+
187+
println!(
188+
"Virtual packages:\n{}\n",
189+
virtual_packages
190+
.iter()
191+
.format_with("\n", |i, f| f(&format_args!(" - {i}",)))
192+
);
193+
194+
let solver_task = SolverTask {
195+
virtual_packages,
196+
specs: specs.clone(),
197+
timeout: opt.timeout.map(Duration::from_millis),
198+
strategy: opt.strategy.map_or_else(Default::default, Into::into),
199+
exclude_newer: opt.exclude_newer.map(Into::into),
200+
..SolverTask::from_iter(&repo_data)
201+
};
202+
203+
let solver_result = wrap_in_progress("solving", || match opt.solver.unwrap_or_default() {
204+
Solver::Resolvo => resolvo::Solver.solve(solver_task),
205+
Solver::LibSolv => libsolv_c::Solver.solve(solver_task),
206+
})
207+
.into_diagnostic()?;
208+
209+
let mut solved_packages: Vec<RepoDataRecord> = solver_result.records;
210+
211+
if opt.no_deps {
212+
solved_packages.retain(|r| specs.iter().any(|s| s.matches(&r.package_record)));
213+
} else if opt.only_deps {
214+
solved_packages.retain(|r| !specs.iter().any(|s| s.matches(&r.package_record)));
215+
}
216+
217+
if solved_packages.is_empty() {
218+
println!("No packages solved");
219+
} else {
220+
println!("Solved {} packages:", solved_packages.len());
221+
print_records(&solved_packages, solver_result.extras);
222+
}
223+
224+
Ok(())
225+
}
226+
227+
fn parse_virtual_packages(
228+
virtual_packages: &[String],
229+
) -> miette::Result<Vec<GenericVirtualPackage>> {
230+
virtual_packages
231+
.iter()
232+
.map(|virt_pkg| {
233+
let elems = virt_pkg.split('=').collect::<Vec<&str>>();
234+
Ok(GenericVirtualPackage {
235+
name: elems[0].try_into().into_diagnostic()?,
236+
version: elems
237+
.get(1)
238+
.map_or(Version::from_str("0"), |s| Version::from_str(s))
239+
.into_diagnostic()?,
240+
build_string: (*elems.get(2).unwrap_or(&"")).to_string(),
241+
})
242+
})
243+
.collect::<miette::Result<Vec<_>>>()
244+
}
245+
246+
fn print_records(records: &[RepoDataRecord], features: HashMap<PackageName, Vec<String>>) {
247+
for record in records {
248+
let direct_url_print = record.channel.clone().unwrap_or_default();
249+
if let Some(features) = features.get(&record.package_record.name) {
250+
println!(
251+
"{}[{}] {} {} {} {}",
252+
record.package_record.name.as_normalized(),
253+
features.join(", "),
254+
record.package_record.version,
255+
record.package_record.build,
256+
record.package_record.subdir,
257+
direct_url_print,
258+
);
259+
} else {
260+
println!(
261+
"{} {} {} {} {}",
262+
record.package_record.name.as_normalized(),
263+
record.package_record.version,
264+
record.package_record.build,
265+
record.package_record.subdir,
266+
direct_url_print,
267+
);
268+
}
269+
}
270+
}

crates/rattler-bin/src/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ enum Command {
4848
FetchFile(commands::fetch_file::Opt),
4949
Inspect(commands::inspect::Opt),
5050
Search(commands::search::Opt),
51+
Solve(commands::solve::Opt),
5152
ShellHook(commands::shell_hook::Opt),
5253
VirtualPackages(commands::virtual_packages::Opt),
5354
InstallMenu(commands::menu::InstallOpt),
@@ -114,6 +115,7 @@ async fn async_main() -> miette::Result<()> {
114115
Command::FetchFile(opts) => commands::fetch_file::fetch_file(opts).await,
115116
Command::Inspect(opts) => commands::inspect::inspect(opts).await,
116117
Command::Search(opts) => commands::search::search(opts).await,
118+
Command::Solve(opts) => commands::solve::solve(opts).await,
117119
Command::List(opts) => commands::list::list(opts).await,
118120
Command::ShellHook(opts) => commands::shell_hook::shell_hook(opts).await,
119121
Command::VirtualPackages(opts) => commands::virtual_packages::virtual_packages(opts),

0 commit comments

Comments
 (0)