-
Notifications
You must be signed in to change notification settings - Fork 185
Expand file tree
/
Copy pathsorting.rs
More file actions
104 lines (91 loc) · 3.61 KB
/
sorting.rs
File metadata and controls
104 lines (91 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
//! Tests that the sorting of candidates remains the same.
use std::path::Path;
use futures::FutureExt;
use itertools::Itertools;
use rattler_conda_types::{
Channel, MatchSpec, PackageName, ParseStrictness::Lenient, RepoDataRecord,
};
use rattler_repodata_gateway::sparse::{PackageFormatSelection, SparseRepoData};
use rattler_solve::{resolvo::CondaDependencyProvider, ChannelPriority, SolveStrategy};
use resolvo::{Interner, SolverCache};
use rstest::*;
/// Reads the `linux-64` repodata of the checked-in `conda-forge` test channel
/// and returns the records produced by `SparseRepoData::load_records_recursive`
/// when seeded with `package_name`.
///
/// The channel lives at `<CARGO_MANIFEST_DIR>/../../test-data/channels/conda-forge`.
///
/// # Panics
///
/// Panics if the repodata file cannot be opened, or if loading the records fails.
fn load_repodata(package_name: &PackageName) -> Vec<Vec<RepoDataRecord>> {
    let workspace_root = Path::new(env!("CARGO_MANIFEST_DIR")).join("../..");
    let channel_dir = workspace_root
        .join("test-data")
        .join("channels")
        .join("conda-forge");

    let subdir = "linux-64";
    let repodata_file = channel_dir.join(subdir).join("repodata.json");

    let sparse = SparseRepoData::from_file(
        Channel::from_directory(&channel_dir),
        subdir,
        repodata_file,
        None,
    )
    .expect("failed to load sparse repodata");

    let records = SparseRepoData::load_records_recursive(
        &[sparse],
        [package_name.clone()],
        None,
        PackageFormatSelection::default(),
    );
    records.expect("failed to load records")
}
/// Renders the candidates matching a spec, one per line, in the order the
/// resolvo `SolverCache` sorts them under `strategy`.
///
/// `package_name` is parsed leniently as a full match-spec string (for example
/// `"pytorch >=1.12.0"`), not just a bare package name.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the spec does not parse;
/// - the spec does not name an exact package;
/// - the dependency provider cannot be built;
/// - candidate sorting does not finish on the first poll;
/// - the solver reports cancellation.
fn create_sorting_snapshot(package_name: &str, strategy: SolveStrategy) -> String {
    let spec = MatchSpec::from_str(package_name, Lenient).unwrap();
    let exact_name = spec.name.as_exact().unwrap().clone();

    let record_groups = load_repodata(&exact_name);

    let conda_provider = CondaDependencyProvider::new(
        record_groups.iter().map(|group| group.iter().collect()),
        &[],
        &[],
        &[],
        std::slice::from_ref(&spec),
        None,
        ChannelPriority::default(),
        None,
        None, // min_age
        strategy,
        Vec::new(), // dependency_overrides
        &[],        // channel_package_names
    )
    .expect("failed to create dependency provider");

    // Intern the name and the spec's nameless constraints so they can be used
    // as a version-set key for the solver cache.
    let interned_name = conda_provider.pool.intern_package_name(&exact_name);
    let nameless = spec.into_nameless().1;
    let version_set = conda_provider
        .pool
        .intern_version_set(interned_name, nameless.into());

    let cache = SolverCache::new(conda_provider);

    // The future is expected to resolve immediately; `now_or_never` yields
    // `None` otherwise, and the inner `Err` signals cancellation.
    let ordered = cache
        .get_or_cache_sorted_candidates(version_set.into())
        .now_or_never()
        .expect("failed to get candidates")
        .expect("solver requested cancellation");

    let provider = cache.provider();
    ordered
        .iter()
        .copied()
        .map(|solvable| provider.display_solvable(solvable))
        .join("\n")
}
#[rstest]
#[case::pytorch("pytorch >=1.12.0", SolveStrategy::Highest)]
#[case::pytorch("pytorch >=1.12.0", SolveStrategy::LowestVersion)]
#[case::pytorch("pytorch >=1.12.0", SolveStrategy::LowestVersionDirect)]
#[case::python("python ~=3.10.*", SolveStrategy::Highest)]
#[case::libuuid("libuuid", SolveStrategy::Highest)]
#[case::abess("abess", SolveStrategy::Highest)]
#[case::libgcc("libgcc-ng", SolveStrategy::Highest)]
#[case::certifi("certifi >=2016.9.26", SolveStrategy::Highest)]
fn test_ordering(#[case] spec: &str, #[case] solve_strategy: SolveStrategy) {
insta::assert_snapshot!(
format!(
"test_ordering_{}_{}",
spec.split_whitespace().next().unwrap_or(spec),
match solve_strategy {
SolveStrategy::Highest => "highest",
SolveStrategy::LowestVersion => "lowest",
SolveStrategy::LowestVersionDirect => "lowest_direct",
}
),
create_sorting_snapshot(spec, solve_strategy)
);
}