-
Notifications
You must be signed in to change notification settings - Fork 185
Expand file tree
/
Copy pathsorting_bench.rs
More file actions
83 lines (72 loc) · 3.08 KB
/
sorting_bench.rs
File metadata and controls
83 lines (72 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use std::{hint::black_box, path::Path};
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use futures::FutureExt;
use rattler_conda_types::{Channel, MatchSpec};
use rattler_repodata_gateway::sparse::{PackageFormatSelection, SparseRepoData};
use rattler_solve::{resolvo::CondaDependencyProvider, ChannelPriority};
use resolvo::SolverCache;
/// Benchmarks fetching and sorting the solver candidates of a single package.
///
/// `spec` is parsed leniently as a [`MatchSpec`] and must name one exact
/// package. All records reachable from that package are loaded once from
/// `sparse_repo_data` before measuring.
///
/// For every measured iteration, a fresh [`CondaDependencyProvider`] and
/// [`SolverCache`] are built in the *untimed* setup phase. The timed routine
/// therefore covers only `get_or_cache_sorted_candidates`. Building the pool
/// and dropping the cache are excluded from the measurement, because
/// `iter_batched_ref` drops its inputs outside the timed region.
///
/// # Panics
///
/// Panics if `spec` does not parse or does not name an exact package. Also
/// panics if the records cannot be loaded, if the provider cannot be
/// constructed, or if candidate sorting does not complete synchronously.
fn bench_sort(c: &mut Criterion, sparse_repo_data: &SparseRepoData, spec: &str) {
    let match_spec = MatchSpec::from_str(spec, rattler_conda_types::ParseStrictness::Lenient)
        .expect("benchmark spec must be a valid match spec");
    let package_name = match_spec
        .name
        .as_exact()
        .expect("benchmark spec must name an exact package")
        .clone();

    // Load every record transitively reachable from the package once; each
    // iteration builds its own provider from these shared records.
    let repodata = SparseRepoData::load_records_recursive(
        [sparse_repo_data],
        [package_name.clone()],
        None,
        PackageFormatSelection::default(),
    )
    .expect("failed to load records");

    c.bench_function(&format!("sort {spec}"), |b| {
        b.iter_batched_ref(
            // Untimed setup: a fresh provider + cache per iteration, so nothing
            // is cached yet, plus the interned requirement we are going to query.
            || {
                let dependency_provider = CondaDependencyProvider::new(
                    repodata.iter().map(|r| r.iter().collect()),
                    &[],
                    &[],
                    &[],
                    std::slice::from_ref(&match_spec),
                    None,
                    ChannelPriority::default(),
                    None,
                    None,
                    rattler_solve::SolveStrategy::Highest,
                    Vec::new(),
                    &[],
                )
                .expect("failed to create dependency provider");
                let name = dependency_provider.pool.intern_package_name(&package_name);
                let version_set = dependency_provider
                    .pool
                    .intern_version_set(name, match_spec.clone().into_nameless().1.into());
                (SolverCache::new(dependency_provider), version_set)
            },
            // Timed routine: fetch and sort the candidates for the requirement.
            |(cache, version_set)| {
                let candidates = cache
                    .get_or_cache_sorted_candidates((*version_set).into())
                    .now_or_never()
                    .expect("candidate sorting did not complete synchronously")
                    .expect("solver requested cancellation");
                black_box(candidates);
            },
            // Each input owns a complete package pool, so keep batches small to
            // bound memory usage.
            BatchSize::LargeInput,
        );
    });
}
/// Criterion entry point for the sorting benchmarks.
///
/// Loads the `linux-64` subdir of the `conda-forge` test channel, located in
/// `test-data/channels` two directories above this crate's manifest. It then
/// runs [`bench_sort`] for `pytorch`, `python` and `tensorflow`, in that order.
///
/// # Panics
///
/// Panics if the sparse repodata cannot be loaded.
fn criterion_benchmark(c: &mut Criterion) {
    let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
    let channel_path = manifest_dir
        .join("../..")
        .join("test-data")
        .join("channels")
        .join("conda-forge");
    let subdir_repodata = channel_path.join("linux-64").join("repodata.json");

    let sparse_repo_data = SparseRepoData::from_file(
        Channel::from_directory(&channel_path),
        "linux-64",
        subdir_repodata,
        None,
    )
    .expect("failed to load sparse repodata");

    for spec in ["pytorch", "python", "tensorflow"] {
        bench_sort(c, &sparse_repo_data, spec);
    }
}
// Register `criterion_benchmark` in the `benches` group and generate the
// benchmark binary's `main`, which runs that group.
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);