Skip to content

Commit a585440

Browse files
Auto merge of #142322 - compiler-errors:perf-instantiate-uwu2, r=<try>
Perf instantiate uwu2 This is horrible
2 parents 8ce2287 + b864867 commit a585440

File tree

5 files changed

+192
-38
lines changed

5 files changed

+192
-38
lines changed

compiler/rustc_infer/src/infer/canonical/instantiate.rs

Lines changed: 169 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
//! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html
88
99
use rustc_macros::extension;
10-
use rustc_middle::bug;
11-
use rustc_middle::ty::{self, FnMutDelegate, GenericArgKind, TyCtxt, TypeFoldable};
10+
use rustc_middle::ty::{
11+
self, DelayedMap, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeSuperVisitable,
12+
TypeVisitableExt, TypeVisitor,
13+
};
14+
use rustc_type_ir::TypeVisitable;
1215

1316
use crate::infer::canonical::{Canonical, CanonicalVarValues};
1417

@@ -18,11 +21,15 @@ use crate::infer::canonical::{Canonical, CanonicalVarValues};
1821
impl<'tcx, V> Canonical<'tcx, V> {
1922
/// Instantiate the wrapped value, replacing each canonical value
2023
/// with the value given in `var_values`.
21-
fn instantiate(&self, tcx: TyCtxt<'tcx>, var_values: &CanonicalVarValues<'tcx>) -> V
24+
fn instantiate<const UWU: bool>(
25+
&self,
26+
tcx: TyCtxt<'tcx>,
27+
var_values: &CanonicalVarValues<'tcx>,
28+
) -> V
2229
where
2330
V: TypeFoldable<TyCtxt<'tcx>>,
2431
{
25-
self.instantiate_projected(tcx, var_values, |value| value.clone())
32+
self.instantiate_projected::<_, UWU>(tcx, var_values, |value| value.clone())
2633
}
2734

2835
/// Allows one to apply a instantiation to some subset of
@@ -31,7 +38,7 @@ impl<'tcx, V> Canonical<'tcx, V> {
3138
/// variables bound in `self` (usually this extracts from subset
3239
/// of `self`). Apply the instantiation `var_values` to this value
3340
/// V, replacing each of the canonical variables.
34-
fn instantiate_projected<T>(
41+
fn instantiate_projected<T, const UWU: bool>(
3542
&self,
3643
tcx: TyCtxt<'tcx>,
3744
var_values: &CanonicalVarValues<'tcx>,
@@ -42,14 +49,14 @@ impl<'tcx, V> Canonical<'tcx, V> {
4249
{
4350
assert_eq!(self.variables.len(), var_values.len());
4451
let value = projection_fn(&self.value);
45-
instantiate_value(tcx, var_values, value)
52+
instantiate_value_0::<_, UWU>(tcx, var_values, value)
4653
}
4754
}
4855

4956
/// Instantiate the values from `var_values` into `value`. `var_values`
5057
/// must be values for the set of canonical variables that appear in
5158
/// `value`.
52-
pub(super) fn instantiate_value<'tcx, T>(
59+
pub(super) fn instantiate_value_0<'tcx, T, const UWU: bool>(
5360
tcx: TyCtxt<'tcx>,
5461
var_values: &CanonicalVarValues<'tcx>,
5562
value: T,
@@ -58,23 +65,160 @@ where
5865
T: TypeFoldable<TyCtxt<'tcx>>,
5966
{
6067
if var_values.var_values.is_empty() {
61-
value
62-
} else {
63-
let delegate = FnMutDelegate {
64-
regions: &mut |br: ty::BoundRegion| match var_values[br.var].kind() {
65-
GenericArgKind::Lifetime(l) => l,
66-
r => bug!("{:?} is a region but value is {:?}", br, r),
67-
},
68-
types: &mut |bound_ty: ty::BoundTy| match var_values[bound_ty.var].kind() {
69-
GenericArgKind::Type(ty) => ty,
70-
r => bug!("{:?} is a type but value is {:?}", bound_ty, r),
71-
},
72-
consts: &mut |bound_ct: ty::BoundVar| match var_values[bound_ct].kind() {
73-
GenericArgKind::Const(ct) => ct,
74-
c => bug!("{:?} is a const but value is {:?}", bound_ct, c),
75-
},
76-
};
77-
78-
tcx.replace_escaping_bound_vars_uncached(value, delegate)
68+
return value;
69+
}
70+
71+
value.fold_with(&mut BoundVarReplacer::<UWU> {
72+
tcx,
73+
current_index: ty::INNERMOST,
74+
var_values: var_values.var_values,
75+
cache: Default::default(),
76+
})
77+
}
78+
79+
/// Replaces the escaping bound vars (late bound regions or bound types) in a type.
80+
struct BoundVarReplacer<'tcx, const UWU: bool> {
81+
tcx: TyCtxt<'tcx>,
82+
83+
/// As with `RegionFolder`, represents the index of a binder *just outside*
84+
/// the ones we have visited.
85+
current_index: ty::DebruijnIndex,
86+
87+
var_values: ty::GenericArgsRef<'tcx>,
88+
89+
/// This cache only tracks the `DebruijnIndex` and assumes that it does not matter
90+
/// for the delegate how often its methods get used.
91+
cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
92+
}
93+
94+
impl<'tcx, const UWU: bool> TypeFolder<TyCtxt<'tcx>> for BoundVarReplacer<'tcx, UWU> {
95+
fn cx(&self) -> TyCtxt<'tcx> {
96+
self.tcx
97+
}
98+
99+
fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
100+
&mut self,
101+
t: ty::Binder<'tcx, T>,
102+
) -> ty::Binder<'tcx, T> {
103+
self.current_index.shift_in(1);
104+
let t = t.super_fold_with(self);
105+
self.current_index.shift_out(1);
106+
t
107+
}
108+
109+
fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
110+
match *t.kind() {
111+
ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
112+
self.var_values[bound_ty.var.as_usize()].expect_ty()
113+
}
114+
_ => {
115+
if !t.has_vars_bound_at_or_above(self.current_index) {
116+
t
117+
} else if let Some(&t) = self.cache.get(&(self.current_index, t)) {
118+
t
119+
} else {
120+
let res = t.super_fold_with(self);
121+
assert!(self.cache.insert((self.current_index, t), res));
122+
res
123+
}
124+
}
125+
}
126+
}
127+
128+
fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
129+
match r.kind() {
130+
ty::ReBound(debruijn, br) if debruijn == self.current_index => {
131+
self.var_values[br.var.as_usize()].expect_region()
132+
}
133+
_ => r,
134+
}
135+
}
136+
137+
fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
138+
match ct.kind() {
139+
ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
140+
self.var_values[bound_const.as_usize()].expect_const()
141+
}
142+
_ => ct.super_fold_with(self),
143+
}
144+
}
145+
146+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
147+
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
148+
}
149+
150+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
151+
if c.has_vars_bound_at_or_above(self.current_index) {
152+
if self.current_index == ty::INNERMOST {
153+
if UWU {
154+
let index = *self
155+
.tcx
156+
.highest_var_in_clauses_cache
157+
.lock()
158+
.entry(c)
159+
.or_insert_with(|| highest_var_in_clauses(c));
160+
let c_args = &self.var_values[..=index];
161+
if let Some(c2) = self.tcx.clauses_cache.lock().get(&(c, c_args)) {
162+
c2
163+
} else {
164+
let folded = c.super_fold_with(self);
165+
self.tcx.clauses_cache.lock().insert((c, c_args), folded);
166+
folded
167+
}
168+
} else {
169+
c.super_fold_with(self)
170+
}
171+
} else {
172+
c.super_fold_with(self)
173+
}
174+
} else {
175+
c
176+
}
177+
}
178+
}
179+
180+
fn highest_var_in_clauses<'tcx>(c: ty::Clauses<'tcx>) -> usize {
181+
struct HighestVarInClauses {
182+
max_var: usize,
183+
current_index: ty::DebruijnIndex,
184+
}
185+
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HighestVarInClauses {
186+
fn visit_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
187+
&mut self,
188+
t: &ty::Binder<'tcx, T>,
189+
) -> Self::Result {
190+
self.current_index.shift_in(1);
191+
let t = t.super_visit_with(self);
192+
self.current_index.shift_out(1);
193+
t
194+
}
195+
fn visit_ty(&mut self, t: Ty<'tcx>) {
196+
if let ty::Bound(debruijn, bound_ty) = *t.kind()
197+
&& debruijn == self.current_index
198+
{
199+
self.max_var = self.max_var.max(bound_ty.var.as_usize());
200+
} else if t.has_vars_bound_at_or_above(self.current_index) {
201+
t.super_visit_with(self);
202+
}
203+
}
204+
fn visit_region(&mut self, r: ty::Region<'tcx>) {
205+
if let ty::ReBound(debruijn, bound_region) = r.kind()
206+
&& debruijn == self.current_index
207+
{
208+
self.max_var = self.max_var.max(bound_region.var.as_usize());
209+
}
210+
}
211+
fn visit_const(&mut self, ct: ty::Const<'tcx>) {
212+
if let ty::ConstKind::Bound(debruijn, bound_const) = ct.kind()
213+
&& debruijn == self.current_index
214+
{
215+
self.max_var = self.max_var.max(bound_const.as_usize());
216+
} else if ct.has_vars_bound_at_or_above(self.current_index) {
217+
ct.super_visit_with(self);
218+
}
219+
}
79220
}
221+
let mut visitor = HighestVarInClauses { max_var: 0, current_index: ty::INNERMOST };
222+
c.visit_with(&mut visitor);
223+
visitor.max_var
80224
}

compiler/rustc_infer/src/infer/canonical/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ impl<'tcx> InferCtxt<'tcx> {
6969

7070
let canonical_inference_vars =
7171
self.instantiate_canonical_vars(span, canonical.variables, |ui| universes[ui]);
72-
let result = canonical.instantiate(self.tcx, &canonical_inference_vars);
72+
let result = canonical.instantiate::<true>(self.tcx, &canonical_inference_vars);
7373
(result, canonical_inference_vars)
7474
}
7575

compiler/rustc_infer/src/infer/canonical/query_response.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use rustc_middle::mir::ConstraintCategory;
1717
use rustc_middle::ty::{self, BoundVar, GenericArg, GenericArgKind, Ty, TyCtxt, TypeFoldable};
1818
use tracing::{debug, instrument};
1919

20-
use crate::infer::canonical::instantiate::{CanonicalExt, instantiate_value};
20+
use crate::infer::canonical::instantiate::{CanonicalExt, instantiate_value_0};
2121
use crate::infer::canonical::{
2222
Canonical, CanonicalQueryResponse, CanonicalVarValues, Certainty, OriginalQueryValues,
2323
QueryRegionConstraints, QueryResponse,
@@ -170,12 +170,13 @@ impl<'tcx> InferCtxt<'tcx> {
170170
self.query_response_instantiation(cause, param_env, original_values, query_response)?;
171171

172172
for (predicate, _category) in &query_response.value.region_constraints.outlives {
173-
let predicate = instantiate_value(self.tcx, &result_args, *predicate);
173+
let predicate = instantiate_value_0::<_, false>(self.tcx, &result_args, *predicate);
174174
self.register_outlives_constraint(predicate, cause);
175175
}
176176

177177
let user_result: R =
178-
query_response.instantiate_projected(self.tcx, &result_args, |q_r| q_r.value.clone());
178+
query_response
179+
.instantiate_projected::<_, false>(self.tcx, &result_args, |q_r| q_r.value.clone());
179180

180181
Ok(InferOk { value: user_result, obligations })
181182
}
@@ -242,9 +243,10 @@ impl<'tcx> InferCtxt<'tcx> {
242243

243244
for (index, original_value) in original_values.var_values.iter().enumerate() {
244245
// ...with the value `v_r` of that variable from the query.
245-
let result_value = query_response.instantiate_projected(self.tcx, &result_args, |v| {
246-
v.var_values[BoundVar::new(index)]
247-
});
246+
let result_value =
247+
query_response.instantiate_projected::<_, false>(self.tcx, &result_args, |v| {
248+
v.var_values[BoundVar::new(index)]
249+
});
248250
match (original_value.kind(), result_value.kind()) {
249251
(GenericArgKind::Lifetime(re1), GenericArgKind::Lifetime(re2))
250252
if re1.is_erased() && re2.is_erased() =>
@@ -289,7 +291,7 @@ impl<'tcx> InferCtxt<'tcx> {
289291
// ...also include the other query region constraints from the query.
290292
output_query_region_constraints.outlives.extend(
291293
query_response.value.region_constraints.outlives.iter().filter_map(|&r_c| {
292-
let r_c = instantiate_value(self.tcx, &result_args, r_c);
294+
let r_c = instantiate_value_0::<_, false>(self.tcx, &result_args, r_c);
293295

294296
// Screen out `'a: 'a` cases.
295297
let ty::OutlivesPredicate(k1, r2) = r_c.0;
@@ -298,7 +300,8 @@ impl<'tcx> InferCtxt<'tcx> {
298300
);
299301

300302
let user_result: R =
301-
query_response.instantiate_projected(self.tcx, &result_args, |q_r| q_r.value.clone());
303+
query_response
304+
.instantiate_projected::<_, false>(self.tcx, &result_args, |q_r| q_r.value.clone());
302305

303306
Ok(InferOk { value: user_result, obligations })
304307
}
@@ -469,8 +472,8 @@ impl<'tcx> InferCtxt<'tcx> {
469472

470473
// Carry all newly resolved opaque types to the caller's scope
471474
for &(a, b) in &query_response.value.opaque_types {
472-
let a = instantiate_value(self.tcx, &result_args, a);
473-
let b = instantiate_value(self.tcx, &result_args, b);
475+
let a = instantiate_value_0::<_, false>(self.tcx, &result_args, a);
476+
let b = instantiate_value_0::<_, false>(self.tcx, &result_args, b);
474477
debug!(?a, ?b, "constrain opaque type");
475478
// We use equate here instead of, for example, just registering the
476479
// opaque type's hidden value directly, because the hidden type may have been an inference
@@ -512,7 +515,8 @@ impl<'tcx> InferCtxt<'tcx> {
512515
// `query_response.var_values` after applying the instantiation
513516
// by `result_args`.
514517
let instantiated_query_response = |index: BoundVar| -> GenericArg<'tcx> {
515-
query_response.instantiate_projected(self.tcx, result_args, |v| v.var_values[index])
518+
query_response
519+
.instantiate_projected::<_, false>(self.tcx, result_args, |v| v.var_values[index])
516520
};
517521

518522
// Unify the original value for each variable with the value

compiler/rustc_middle/src/ty/context.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,10 @@ pub struct GlobalCtxt<'tcx> {
14601460

14611461
pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,
14621462

1463+
pub highest_var_in_clauses_cache: Lock<FxHashMap<ty::Clauses<'tcx>, usize>>,
1464+
pub clauses_cache:
1465+
Lock<FxHashMap<(ty::Clauses<'tcx>, &'tcx [ty::GenericArg<'tcx>]), ty::Clauses<'tcx>>>,
1466+
14631467
/// Data layout specification for the current target.
14641468
pub data_layout: TargetDataLayout,
14651469

@@ -1707,6 +1711,8 @@ impl<'tcx> TyCtxt<'tcx> {
17071711
new_solver_evaluation_cache: Default::default(),
17081712
new_solver_canonical_param_env_cache: Default::default(),
17091713
canonical_param_env_cache: Default::default(),
1714+
highest_var_in_clauses_cache: Default::default(),
1715+
clauses_cache: Default::default(),
17101716
data_layout,
17111717
alloc_map: interpret::AllocMap::new(),
17121718
current_gcx,

compiler/rustc_trait_selection/src/solve/delegate.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ impl<'tcx> rustc_next_trait_solver::delegate::SolverDelegate for SolverDelegate<
222222
where
223223
V: TypeFoldable<TyCtxt<'tcx>>,
224224
{
225-
canonical.instantiate(self.tcx, &values)
225+
canonical.instantiate::<false>(self.tcx, &values)
226226
}
227227

228228
fn instantiate_canonical_var_with_infer(

0 commit comments

Comments
 (0)