Skip to content

Commit 1202fab

Browse files
committed
!! (WIP; squish) make Thir responsible for walking THIR patterns
1 parent 62fd815 commit 1202fab

File tree

7 files changed

+63
-34
lines changed

7 files changed

+63
-34
lines changed

compiler/rustc_middle/src/thir.rs

+33-14
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,17 @@ impl<'tcx> Pat<'tcx> {
640640
_ => None,
641641
}
642642
}
643+
}
643644

645+
impl<'tcx> Thir<'tcx> {
644646
/// Call `f` on every "binding" in a pattern, e.g., on `a` in
645647
/// `match foo() { Some(a) => (), None => () }`
646-
pub fn each_binding(&self, mut f: impl FnMut(Symbol, ByRef, Ty<'tcx>, Span)) {
647-
self.walk_always(|p| {
648+
pub fn for_each_binding_in_pat(
649+
&self,
650+
pat: &Pat<'tcx>,
651+
mut f: impl FnMut(Symbol, ByRef, Ty<'tcx>, Span),
652+
) {
653+
self.walk_pat_always(pat, |p| {
648654
if let PatKind::Binding { name, mode, ty, .. } = p.kind {
649655
f(name, mode.0, ty, p.span);
650656
}
@@ -654,22 +660,22 @@ impl<'tcx> Pat<'tcx> {
654660
/// Walk the pattern in left-to-right order.
655661
///
656662
/// If `it(pat)` returns `false`, the children are not visited.
657-
pub fn walk(&self, mut it: impl FnMut(&Pat<'tcx>) -> bool) {
658-
self.walk_(&mut it)
663+
pub fn walk_pat(&self, pat: &Pat<'tcx>, mut it: impl FnMut(&Pat<'tcx>) -> bool) {
664+
self.walk_pat_inner(pat, &mut it)
659665
}
660666

661-
fn walk_(&self, it: &mut impl FnMut(&Pat<'tcx>) -> bool) {
662-
if !it(self) {
667+
fn walk_pat_inner(&self, pat: &Pat<'tcx>, it: &mut impl FnMut(&Pat<'tcx>) -> bool) {
668+
if !it(pat) {
663669
return;
664670
}
665671

666-
for_each_immediate_subpat(self, |p| p.walk_(it));
672+
for_each_immediate_subpat(pat, |p| self.walk_pat_inner(p, it));
667673
}
668674

669675
/// Whether the pattern has a `PatKind::Error` nested within.
670-
pub fn pat_error_reported(&self) -> Result<(), ErrorGuaranteed> {
676+
pub fn pat_error_reported(&self, pat: &Pat<'tcx>) -> Result<(), ErrorGuaranteed> {
671677
let mut error = None;
672-
self.walk(|pat| {
678+
self.walk_pat(pat, |pat| {
673679
if let PatKind::Error(e) = pat.kind
674680
&& error.is_none()
675681
{
@@ -683,26 +689,39 @@ impl<'tcx> Pat<'tcx> {
683689
}
684690
}
685691

692+
pub fn pat_references_error(&self, pat: &Pat<'tcx>) -> bool {
693+
use rustc_type_ir::visit::TypeVisitableExt;
694+
695+
let mut references_error = TypeVisitableExt::references_error(pat);
696+
if !references_error {
697+
for_each_immediate_subpat(pat, |p| {
698+
references_error = references_error || self.pat_references_error(p);
699+
});
700+
}
701+
702+
references_error
703+
}
704+
686705
/// Walk the pattern in left-to-right order.
687706
///
688707
/// If you always want to recurse, prefer this method over `walk`.
689-
pub fn walk_always(&self, mut it: impl FnMut(&Pat<'tcx>)) {
690-
self.walk(|p| {
708+
pub fn walk_pat_always(&self, pat: &Pat<'tcx>, mut it: impl FnMut(&Pat<'tcx>)) {
709+
self.walk_pat(pat, |p| {
691710
it(p);
692711
true
693712
})
694713
}
695714

696715
/// Whether this a never pattern.
697-
pub fn is_never_pattern(&self) -> bool {
716+
pub fn is_never_pattern(&self, pat: &Pat<'tcx>) -> bool {
698717
let mut is_never_pattern = false;
699-
self.walk(|pat| match &pat.kind {
718+
self.walk_pat(pat, |pat| match &pat.kind {
700719
PatKind::Never => {
701720
is_never_pattern = true;
702721
false
703722
}
704723
PatKind::Or { pats } => {
705-
is_never_pattern = pats.iter().all(|p| p.is_never_pattern());
724+
is_never_pattern = pats.iter().all(|p| self.is_never_pattern(p));
706725
false
707726
}
708727
_ => true,

compiler/rustc_mir_build/src/builder/matches/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
861861
pattern: &Pat<'tcx>,
862862
f: &mut impl FnMut(&mut Self, LocalVarId, Span),
863863
) {
864-
pattern.walk_always(|pat| {
864+
self.thir.walk_pat_always(pattern, |pat| {
865865
if let PatKind::Binding { var, is_primary: true, .. } = pat.kind {
866866
f(self, var, pat.span);
867867
}
@@ -1037,7 +1037,7 @@ impl<'tcx> FlatPat<'tcx> {
10371037
span: pattern.span,
10381038
bindings: Vec::new(),
10391039
ascriptions: Vec::new(),
1040-
is_never: pattern.is_never_pattern(),
1040+
is_never: cx.thir.is_never_pattern(pattern),
10411041
};
10421042
// Recursively remove irrefutable match pairs, while recording their
10431043
// bindings/ascriptions, and sort or-patterns after other match pairs.

compiler/rustc_mir_build/src/thir/cx/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ impl<'tcx> ThirBuildCx<'tcx> {
120120

121121
#[instrument(level = "debug", skip(self))]
122122
fn pattern_from_hir(&mut self, p: &'tcx hir::Pat<'tcx>) -> Box<Pat<'tcx>> {
123-
pat_from_hir(self.tcx, self.typing_env, self.typeck_results, p)
123+
pat_from_hir(self.tcx, &mut self.thir, self.typing_env, self.typeck_results, p)
124124
}
125125

126126
fn closure_env_param(&self, owner_def: LocalDefId, expr_id: HirId) -> Option<Param<'tcx>> {

compiler/rustc_mir_build/src/thir/pattern/check_match.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,14 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
277277
cx: &PatCtxt<'p, 'tcx>,
278278
pat: &'p Pat<'tcx>,
279279
) -> Result<&'p DeconstructedPat<'p, 'tcx>, ErrorGuaranteed> {
280-
if let Err(err) = pat.pat_error_reported() {
280+
if let Err(err) = cx.thir.pat_error_reported(pat) {
281281
self.error = Err(err);
282282
Err(err)
283283
} else {
284284
// Check the pattern for some things unrelated to exhaustiveness.
285285
let refutable = if cx.refutable { Refutable } else { Irrefutable };
286286
let mut err = Ok(());
287-
pat.walk_always(|pat| {
287+
cx.thir.walk_pat_always(pat, |pat| {
288288
check_borrow_conflicts_in_at_patterns(self, pat);
289289
check_for_bindings_named_same_as_variants(self, pat, refutable);
290290
err = err.and(check_never_pattern(cx, pat));
@@ -385,6 +385,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
385385
scrutinee.map(|scrut| self.is_known_valid_scrutinee(scrut)).unwrap_or(true);
386386
PatCtxt {
387387
tcx: self.tcx,
388+
thir: self.thir,
388389
typeck_results: self.typeck_results,
389390
typing_env: self.typing_env,
390391
module: self.tcx.parent_module(self.lint_level).to_def_id(),
@@ -704,7 +705,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
704705
&& scrut.is_some()
705706
{
706707
let mut bindings = vec![];
707-
pat.each_binding(|name, _, _, _| bindings.push(name));
708+
self.thir.for_each_binding_in_pat(pat, |name, _, _, _| bindings.push(name));
708709

709710
let semi_span = span.shrink_to_hi();
710711
let start_span = span.shrink_to_lo();
@@ -780,7 +781,7 @@ fn check_borrow_conflicts_in_at_patterns<'tcx>(cx: &MatchVisitor<'_, 'tcx>, pat:
780781
ByRef::No if is_binding_by_move(ty) => {
781782
// We have `x @ pat` where `x` is by-move. Reject all borrows in `pat`.
782783
let mut conflicts_ref = Vec::new();
783-
sub.each_binding(|_, mode, _, span| {
784+
cx.thir.for_each_binding_in_pat(sub, |_, mode, _, span| {
784785
if matches!(mode, ByRef::Yes(_)) {
785786
conflicts_ref.push(span)
786787
}
@@ -809,7 +810,7 @@ fn check_borrow_conflicts_in_at_patterns<'tcx>(cx: &MatchVisitor<'_, 'tcx>, pat:
809810
let mut conflicts_move = Vec::new();
810811
let mut conflicts_mut_mut = Vec::new();
811812
let mut conflicts_mut_ref = Vec::new();
812-
sub.each_binding(|name, mode, ty, span| {
813+
cx.thir.for_each_binding_in_pat(sub, |name, mode, ty, span| {
813814
match mode {
814815
ByRef::Yes(mut_inner) => match (mut_outer, mut_inner) {
815816
// Both sides are `ref`.

compiler/rustc_mir_build/src/thir/pattern/const_to_pat.rs

+14-9
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@ use rustc_index::Idx;
99
use rustc_infer::infer::TyCtxtInferExt;
1010
use rustc_infer::traits::Obligation;
1111
use rustc_middle::mir::interpret::ErrorHandled;
12-
use rustc_middle::thir::{FieldPat, Pat, PatKind};
13-
use rustc_middle::ty::{
14-
self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitableExt, TypeVisitor, ValTree,
15-
};
12+
use rustc_middle::thir::{FieldPat, Pat, PatKind, Thir};
13+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitor, ValTree};
1614
use rustc_middle::{mir, span_bug};
1715
use rustc_span::def_id::DefId;
1816
use rustc_span::{Span, sym};
@@ -36,7 +34,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
3634
/// so we have to carry one ourselves.
3735
#[instrument(level = "debug", skip(self), ret)]
3836
pub(super) fn const_to_pat(
39-
&self,
37+
&mut self,
4038
c: ty::Const<'tcx>,
4139
ty: Ty<'tcx>,
4240
id: hir::HirId,
@@ -52,8 +50,9 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
5250
}
5351
}
5452

55-
struct ConstToPat<'tcx> {
53+
struct ConstToPat<'a, 'tcx> {
5654
tcx: TyCtxt<'tcx>,
55+
thir: &'a Thir<'tcx>,
5756
typing_env: ty::TypingEnv<'tcx>,
5857
span: Span,
5958
id: hir::HirId,
@@ -63,11 +62,17 @@ struct ConstToPat<'tcx> {
6362
c: ty::Const<'tcx>,
6463
}
6564

66-
impl<'tcx> ConstToPat<'tcx> {
67-
fn new(pat_ctxt: &PatCtxt<'_, 'tcx>, id: hir::HirId, span: Span, c: ty::Const<'tcx>) -> Self {
65+
impl<'a, 'tcx> ConstToPat<'a, 'tcx> {
66+
fn new(
67+
pat_ctxt: &'a PatCtxt<'_, 'tcx>,
68+
id: hir::HirId,
69+
span: Span,
70+
c: ty::Const<'tcx>,
71+
) -> Self {
6872
trace!(?pat_ctxt.typeck_results.hir_owner);
6973
ConstToPat {
7074
tcx: pat_ctxt.tcx,
75+
thir: pat_ctxt.thir,
7176
typing_env: pat_ctxt.typing_env,
7277
span,
7378
id,
@@ -187,7 +192,7 @@ impl<'tcx> ConstToPat<'tcx> {
187192
// Convert the valtree to a const.
188193
let inlined_const_as_pat = self.valtree_to_pat(valtree, ty);
189194

190-
if !inlined_const_as_pat.references_error() {
195+
if !self.thir.pat_references_error(&inlined_const_as_pat) {
191196
// Always check for `PartialEq` if we had no other errors yet.
192197
if !type_has_partial_eq_impl(self.tcx, typing_env, ty).has_impl {
193198
let mut err = self.tcx.dcx().create_err(TypeNotPartialEq { span: self.span, ty });

compiler/rustc_mir_build/src/thir/pattern/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use rustc_hir::{self as hir, RangeEnd};
1515
use rustc_index::Idx;
1616
use rustc_middle::mir::interpret::LitToConstInput;
1717
use rustc_middle::thir::{
18-
Ascription, FieldPat, LocalVarId, Pat, PatKind, PatRange, PatRangeBoundary,
18+
Ascription, FieldPat, LocalVarId, Pat, PatKind, PatRange, PatRangeBoundary, Thir,
1919
};
2020
use rustc_middle::ty::layout::IntegerExt;
2121
use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, TyCtxt, TypeVisitableExt};
@@ -30,6 +30,7 @@ use crate::errors::*;
3030

3131
struct PatCtxt<'a, 'tcx> {
3232
tcx: TyCtxt<'tcx>,
33+
thir: &'a Thir<'tcx>,
3334
typing_env: ty::TypingEnv<'tcx>,
3435
typeck_results: &'a ty::TypeckResults<'tcx>,
3536

@@ -39,12 +40,14 @@ struct PatCtxt<'a, 'tcx> {
3940

4041
pub(super) fn pat_from_hir<'a, 'tcx>(
4142
tcx: TyCtxt<'tcx>,
43+
thir: &'a Thir<'tcx>,
4244
typing_env: ty::TypingEnv<'tcx>,
4345
typeck_results: &'a ty::TypeckResults<'tcx>,
4446
pat: &'tcx hir::Pat<'tcx>,
4547
) -> Box<Pat<'tcx>> {
4648
let mut pcx = PatCtxt {
4749
tcx,
50+
thir,
4851
typing_env,
4952
typeck_results,
5053
rust_2024_migration: typeck_results

compiler/rustc_pattern_analysis/src/rustc.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustc_hir::def_id::DefId;
88
use rustc_index::{Idx, IndexVec};
99
use rustc_middle::middle::stability::EvalResult;
1010
use rustc_middle::mir::{self, Const};
11-
use rustc_middle::thir::{self, Pat, PatKind, PatRange, PatRangeBoundary};
11+
use rustc_middle::thir::{self, Pat, PatKind, PatRange, PatRangeBoundary, Thir};
1212
use rustc_middle::ty::layout::IntegerExt;
1313
use rustc_middle::ty::{
1414
self, FieldDef, OpaqueTypeKey, ScalarInt, Ty, TyCtxt, TypeVisitableExt, VariantDef,
@@ -76,8 +76,9 @@ impl<'tcx> RevealedTy<'tcx> {
7676
}
7777

7878
#[derive(Clone)]
79-
pub struct RustcPatCtxt<'p, 'tcx: 'p> {
79+
pub struct RustcPatCtxt<'p, 'tcx> {
8080
pub tcx: TyCtxt<'tcx>,
81+
pub thir: &'p Thir<'tcx>,
8182
pub typeck_results: &'tcx ty::TypeckResults<'tcx>,
8283
/// The module in which the match occurs. This is necessary for
8384
/// checking inhabited-ness of types because whether a type is (visibly)

0 commit comments

Comments
 (0)