Skip to content

Commit 19783b8

Browse files
stroxlermeta-codesync[bot]
authored and committed
make call-scope context a required solve boundary
Summary: This is part of the solver-boundary plan: constraint solving is usually per-call, but we also rely on ad-hoc subset boundaries where call-like behavior can occur with residual tracking disabled, so we must make boundary ownership explicit and always drain call-scoped state before leaving the boundary. This change makes callable finishing consume the active call context, adds explicit boundary-consumption state with debug-time drop invariants, and tracks fresh quantified vars created during call-scoped solving so boundary finishing still sees them even when residual reachability is lost. NOTE: there's a lot of new code here and not much deletion, but a lot of the new logic is just the preparation / plumbing phase of a mini-stack that shifts responsibility for finishing around - there's offsetting deletion of a lot of legacy finishing code in D103337567 (which is the end goal - that commit fixes the `reduce` regression by storing all fresh vars in context and finishing them at the appropriate time). Reviewed By: rchen152 Differential Revision: D103328383 fbshipit-source-id: d5bf4bbff28154cb64f7365a9e64c66c78fe90fd
1 parent f024130 commit 19783b8

4 files changed

Lines changed: 204 additions & 8 deletions

File tree

pyrefly/lib/alt/callable.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
13851385
) {
13861386
let mut call_context = CallContext::outside();
13871387
call_context.set_argument_side(ArgumentSide::Got);
1388+
call_context.require_boundary_consumption();
13881389

13891390
// Look up meta-shape early so we can conditionally collect bound args.
13901391
// Only consult the registry when tensor_shapes is enabled to avoid
@@ -1422,6 +1423,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
14221423
} else {
14231424
(QuantifiedHandle::empty(), callable)
14241425
};
1426+
call_context.register_fresh_quantified_vars(callable_qs.vars());
14251427
let (self_qs, remaining_callable_qs) = if self_obj.is_some()
14261428
&& let Some(first_param) = callable.get_first_param()
14271429
// TODO(https://github.com/facebook/pyrefly/issues/105): handle nested vars
@@ -1437,6 +1439,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
14371439
};
14381440
let ctor_qs = if let Some(targs) = ctor_targs.as_mut() {
14391441
let qs = self.solver().freshen_class_targs(targs, self.uniques);
1442+
call_context.register_fresh_quantified_vars(qs.vars());
14401443
let mp = targs.substitution_map();
14411444
callable.params.visit_mut(&mut |t| t.subst_mut(&mp));
14421445
if let Some(obj) = self_obj.as_mut() {
@@ -1564,10 +1567,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
15641567
}
15651568
let mut errors = self
15661569
.solver()
1567-
.finish_quantified_with_type_order(
1570+
.finish_quantified_with_type_order_and_call_context(
15681571
remaining_callable_qs,
15691572
self.solver().infer_with_first_use,
15701573
self.type_order(),
1574+
&call_context,
15711575
)
15721576
.map_or_else(|e| e.to_vec(), |_| Vec::new());
15731577
if let Err(e) = self.solver().finish_quantified_with_type_order(

pyrefly/lib/solver/solver.rs

Lines changed: 196 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use std::hash::Hash;
1717
use std::hash::Hasher;
1818
use std::mem;
1919
use std::sync::Arc;
20+
use std::sync::atomic::AtomicBool;
21+
use std::sync::atomic::Ordering;
2022

2123
use itertools::Either;
2224
use itertools::Itertools;
@@ -315,6 +317,10 @@ impl QuantifiedHandle {
315317
Self(Vec::new())
316318
}
317319

320+
pub(crate) fn vars(&self) -> &[Var] {
321+
&self.0
322+
}
323+
318324
/// Split the handle into (vars in ty, vars not in ty)
319325
pub fn partition_by(self, ty: &Type) -> (Self, Self) {
320326
let vars_in_ty = ty.collect_maybe_placeholder_vars();
@@ -2570,6 +2576,118 @@ impl Solver {
25702576
})
25712577
}
25722578

2579+
/// Finish quantified vars at a call boundary by consuming tracked fresh
2580+
/// quantified vars from `CallContext`, then finishing the reachable
2581+
/// quantified closure from explicit roots plus tracked roots.
2582+
pub fn finish_quantified_with_type_order_and_call_context<Ans: LookupAnswer>(
2583+
&self,
2584+
vs: QuantifiedHandle,
2585+
infer_with_first_use: bool,
2586+
type_order: TypeOrder<Ans>,
2587+
call_context: &CallContext,
2588+
) -> Result<(), Vec1<TypeVarSpecializationError>> {
2589+
let tracked_fresh_vars = call_context.take_deferred_quantified_vars();
2590+
let overload_witness_payloads = call_context.take_overload_witness_payloads();
2591+
call_context.mark_boundary_consumed_and_drained();
2592+
let payload_vars: SmallSet<Var> = overload_witness_payloads
2593+
.values()
2594+
.flat_map(|branch_captures| branch_captures.iter())
2595+
.flat_map(|capture| capture.values.keys().copied())
2596+
.collect();
2597+
let mut roots: SmallSet<Var> = vs.0.into_iter().collect();
2598+
roots.extend(tracked_fresh_vars.0);
2599+
// Forall instantiation during call analysis can unify call-scope vars
2600+
// with additional fresh vars that are only visible through Answer /
2601+
// ResidualAnswer payloads. Finish the full reachable closure so no
2602+
// reachable Variable::Quantified can leak to pinning.
2603+
//
2604+
// We also must include reachable vars that already became Answer but
2605+
// still have pending instantiation errors. Those errors are surfaced by
2606+
// finish_quantified, so excluding Answer vars here can silently drop
2607+
// call-site specialization diagnostics.
2608+
let roots = roots.into_iter().collect::<Vec<_>>();
2609+
let mut already_finished: SmallSet<Var> = SmallSet::new();
2610+
loop {
2611+
// Fixed-point: finishing can mutate solver state and expose new
2612+
// reachable vars that also require finishing.
2613+
let reachable_finish_vars = self.reachable_finish_vars_from_roots(&roots);
2614+
let mut next_round: Vec<Var> = reachable_finish_vars
2615+
.into_iter()
2616+
.filter(|var| !already_finished.contains(var))
2617+
.collect();
2618+
// Payload-driven overload pruning must consider solved vars even if
2619+
// they already collapsed to `Answer` before finishing and therefore
2620+
// are not selected by reachability-based finishing alone.
2621+
next_round.extend(
2622+
payload_vars
2623+
.iter()
2624+
.copied()
2625+
.filter(|var| !already_finished.contains(var)),
2626+
);
2627+
next_round.sort_unstable();
2628+
next_round.dedup();
2629+
if next_round.is_empty() {
2630+
break;
2631+
}
2632+
already_finished.extend(next_round.iter().copied());
2633+
let mut subset = self.subset(type_order);
2634+
self.finish_quantified_with_subset_and_payloads(
2635+
QuantifiedHandle(next_round),
2636+
infer_with_first_use,
2637+
&mut |got, want| subset.is_subset_eq_probe_for_pruning(got, want),
2638+
Some(&overload_witness_payloads),
2639+
)?;
2640+
}
2641+
Ok(())
2642+
}
2643+
2644+
fn reachable_finish_vars_from_roots(&self, roots: &[Var]) -> SmallSet<Var> {
2645+
if roots.is_empty() {
2646+
return SmallSet::new();
2647+
}
2648+
let variables = self.variables.lock();
2649+
let instantiation_errors = self.instantiation_errors.read();
2650+
let mut visited: SmallSet<Var> = SmallSet::new();
2651+
let mut reachable_finish_vars: SmallSet<Var> = SmallSet::new();
2652+
let mut stack = roots.to_vec();
2653+
while let Some(var) = stack.pop() {
2654+
if !visited.insert(var) {
2655+
continue;
2656+
}
2657+
let variable = variables.get(var);
2658+
let needs_finish = match &*variable {
2659+
Variable::Quantified { .. } => true,
2660+
Variable::Answer(_) | Variable::ResidualAnswer { .. } => {
2661+
instantiation_errors.contains_key(&var)
2662+
}
2663+
_ => false,
2664+
};
2665+
if needs_finish {
2666+
reachable_finish_vars.insert(var);
2667+
}
2668+
match &*variable {
2669+
Variable::Answer(ty) => {
2670+
stack.extend(ty.collect_maybe_placeholder_vars());
2671+
}
2672+
Variable::ResidualAnswer { target_vars: _, ty } => {
2673+
// `target_vars` are read-gates for residual visibility,
2674+
// not ownership edges for finishing reachability.
2675+
stack.extend(ty.collect_maybe_placeholder_vars());
2676+
}
2677+
Variable::Quantified { .. } => {}
2678+
Variable::Unwrap(bounds) => {
2679+
for ty in bounds.lower.iter().chain(bounds.upper.iter()) {
2680+
stack.extend(ty.collect_maybe_placeholder_vars());
2681+
}
2682+
}
2683+
Variable::PartialQuantified(_)
2684+
| Variable::PartialContained(_)
2685+
| Variable::Recursive => {}
2686+
}
2687+
}
2688+
reachable_finish_vars
2689+
}
2690+
25732691
/// Finish all quantified vars reachable from `ty` using the solver default
25742692
/// inference mode.
25752693
///
@@ -3499,6 +3617,10 @@ pub struct CallContext {
34993617
/// Invariant: payload entries are scoped to this call-context lineage and
35003618
/// must not leak across `with_outside_call_context` boundaries.
35013619
overload_witness_payloads: Arc<Mutex<OverloadWitnessPayloadByHash>>,
3620+
/// Whether this context must be consumed at a solve boundary.
3621+
require_boundary_consumption: Arc<AtomicBool>,
3622+
/// Whether deferred state from this context lineage was consumed/drained.
3623+
boundary_consumed_and_drained: Arc<AtomicBool>,
35023624
}
35033625

35043626
impl Default for CallContext {
@@ -3508,6 +3630,8 @@ impl Default for CallContext {
35083630
argument_side: ArgumentSide::default(),
35093631
deferred_quantified_vars: Arc::new(Mutex::new(SmallSet::new())),
35103632
overload_witness_payloads: Arc::new(Mutex::new(SmallMap::new())),
3633+
require_boundary_consumption: Arc::new(AtomicBool::new(false)),
3634+
boundary_consumed_and_drained: Arc::new(AtomicBool::new(false)),
35113635
}
35123636
}
35133637
}
@@ -3521,6 +3645,18 @@ impl CallContext {
35213645
self.argument_side = argument_side;
35223646
}
35233647

3648+
pub(crate) fn register_fresh_quantified_vars(&self, vars: &[Var]) {
3649+
let mut deferred_quantified_vars = self.deferred_quantified_vars.lock();
3650+
deferred_quantified_vars.extend(vars.iter().copied());
3651+
}
3652+
3653+
pub fn require_boundary_consumption(&self) {
3654+
self.require_boundary_consumption
3655+
.store(true, Ordering::Relaxed);
3656+
self.boundary_consumed_and_drained
3657+
.store(false, Ordering::Relaxed);
3658+
}
3659+
35243660
pub fn residual_witness(&self) -> Option<&ResidualWitnessContext> {
35253661
self.witness.as_ref()
35263662
}
@@ -3556,6 +3692,53 @@ impl CallContext {
35563692
SubsetCacheContext::Default
35573693
}
35583694
}
3695+
3696+
pub(crate) fn take_deferred_quantified_vars(&self) -> QuantifiedHandle {
3697+
let mut deferred_quantified_vars = self.deferred_quantified_vars.lock();
3698+
QuantifiedHandle(
3699+
mem::take(&mut *deferred_quantified_vars)
3700+
.into_iter()
3701+
.collect(),
3702+
)
3703+
}
3704+
3705+
pub(crate) fn take_overload_witness_payloads(&self) -> OverloadWitnessPayloadByHash {
3706+
let mut overload_witness_payloads = self.overload_witness_payloads.lock();
3707+
mem::take(&mut *overload_witness_payloads)
3708+
}
3709+
3710+
fn mark_boundary_consumed_and_drained(&self) {
3711+
self.boundary_consumed_and_drained
3712+
.store(true, Ordering::Relaxed);
3713+
}
3714+
}
3715+
3716+
impl Drop for CallContext {
3717+
fn drop(&mut self) {
3718+
#[cfg(debug_assertions)]
3719+
{
3720+
if std::thread::panicking()
3721+
|| Arc::strong_count(&self.require_boundary_consumption) != 1
3722+
{
3723+
return;
3724+
}
3725+
if !self.require_boundary_consumption.load(Ordering::Relaxed) {
3726+
return;
3727+
}
3728+
assert!(
3729+
self.boundary_consumed_and_drained.load(Ordering::Relaxed),
3730+
"CallContext dropped without boundary consume/drain",
3731+
);
3732+
assert!(
3733+
self.deferred_quantified_vars.lock().is_empty(),
3734+
"CallContext dropped with deferred quantified vars still pending",
3735+
);
3736+
assert!(
3737+
self.overload_witness_payloads.lock().is_empty(),
3738+
"CallContext dropped with overload witness payloads still pending",
3739+
);
3740+
}
3741+
}
35593742
}
35603743

35613744
/// A helper to implement subset ergonomically.
@@ -3704,11 +3887,21 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> {
37043887
&mut self.active_call_context.overload_witness_payloads,
37053888
call_context.overload_witness_payloads.clone(),
37063889
);
3890+
let old_require_boundary_consumption = mem::replace(
3891+
&mut self.active_call_context.require_boundary_consumption,
3892+
call_context.require_boundary_consumption.clone(),
3893+
);
3894+
let old_boundary_consumed_and_drained = mem::replace(
3895+
&mut self.active_call_context.boundary_consumed_and_drained,
3896+
call_context.boundary_consumed_and_drained.clone(),
3897+
);
37073898
let res = f(self);
37083899
self.active_call_context.witness = old_witness;
37093900
self.active_call_context.argument_side = old_argument_side;
37103901
self.active_call_context.deferred_quantified_vars = old_deferred_quantified_vars;
37113902
self.active_call_context.overload_witness_payloads = old_overload_witness_payloads;
3903+
self.active_call_context.require_boundary_consumption = old_require_boundary_consumption;
3904+
self.active_call_context.boundary_consumed_and_drained = old_boundary_consumed_and_drained;
37123905
res
37133906
}
37143907

@@ -3746,18 +3939,16 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> {
37463939
&mut self.active_call_context.argument_side,
37473940
ArgumentSide::NotAnalyzingACall,
37483941
);
3749-
let old_deferred_quantified_vars = mem::replace(
3750-
&mut self.active_call_context.deferred_quantified_vars,
3751-
Arc::new(Mutex::new(SmallSet::new())),
3752-
);
3942+
// Keep fresh-var tracking attached to the same boundary while
3943+
// temporarily disabling residual hooks. Fresh quantified vars created in
3944+
// this scope must still be finished when the outer boundary drains.
37533945
let old_overload_witness_payloads = mem::replace(
37543946
&mut self.active_call_context.overload_witness_payloads,
37553947
Arc::new(Mutex::new(SmallMap::new())),
37563948
);
37573949
let res = f(self);
37583950
self.active_call_context.witness = old_witness;
37593951
self.active_call_context.argument_side = old_argument_side;
3760-
self.active_call_context.deferred_quantified_vars = old_deferred_quantified_vars;
37613952
self.active_call_context.overload_witness_payloads = old_overload_witness_payloads;
37623953
res
37633954
}

pyrefly/lib/solver/subset.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2248,6 +2248,8 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> {
22482248
let forall_type = Type::Forall(forall.clone());
22492249
// Finalizing the quantified vars returns instantiation errors
22502250
let (vs, got) = self.type_order.instantiate_fresh_forall((**forall).clone());
2251+
self.active_call_context
2252+
.register_fresh_quantified_vars(vs.vars());
22512253
let argument_side = self.active_argument_side();
22522254
let witness = self.make_forall_witness(&forall_type, &vs, want);
22532255
let (result, mut maybe_witness) =

pyrefly/lib/test/callable_residuals.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,6 @@ reveal_type(result) # E: revealed type: () -> tuple[int, int]
602602
);
603603

604604
testcase!(
605-
bug = "Constrained type vars can collapse to Answer before finishing, so overload pruning can miss them",
606605
test_overload_pruning_ignored_for_constrained_tvar_solved_early,
607606
r#"
608607
from typing import Callable, overload, reveal_type
@@ -615,7 +614,7 @@ def f(x: float) -> float: ... # E: Overload return type `float` is not assignab
615614
def f(x: bytes) -> bytes: ... # E: Overload return type `bytes` is not assignable to implementation return type `None`
616615
def f(x): ...
617616
618-
result = project(f, 1)
617+
result = project(f, 1) # E: Overload type was not compatible with solved type variables: unknown = int
619618
reveal_type(result) # E: revealed type: (int) -> int
620619
"#,
621620
);

0 commit comments

Comments
 (0)