diff --git a/executor/src/witgen/machines/machine_extractor.rs b/executor/src/witgen/machines/machine_extractor.rs index 8e62758c85..2c8fd6c4b0 100644 --- a/executor/src/witgen/machines/machine_extractor.rs +++ b/executor/src/witgen/machines/machine_extractor.rs @@ -18,6 +18,7 @@ use crate::witgen::data_structures::identity::Identity; use crate::witgen::machines::dynamic_machine::DynamicMachine; use crate::witgen::machines::second_stage_machine::SecondStageMachine; use crate::witgen::machines::{write_once_memory::WriteOnceMemory, MachineParts}; +use powdr_ast::analyzed::AlgebraicExpression; use powdr_ast::analyzed::{ self, AlgebraicExpression as Expression, PolyID, PolynomialReference, Reference, @@ -136,11 +137,6 @@ impl<'a, T: FieldElement> MachineExtractor<'a, T> { publics.add_all(machine_identities.as_slice()).unwrap(); - let machine_intermediates = intermediates_in_identities( - &machine_identities, - &self.fixed.intermediate_definitions, - ); - // Connections that call into the current machine let machine_receives = self .fixed @@ -158,6 +154,14 @@ impl<'a, T: FieldElement> MachineExtractor<'a, T> { .collect::>(); assert!(machine_receives.contains_key(&bus_receive.bus_id)); + let machine_intermediates = intermediates_in_expressions( + machine_identities + .iter() + .flat_map(|i| i.all_children()) + .chain(machine_receives.values().flat_map(|r| r.all_children())), + &self.fixed.intermediate_definitions, + ); + let prover_functions = prover_functions .iter() .copied() @@ -241,8 +245,10 @@ impl<'a, T: FieldElement> MachineExtractor<'a, T> { .difference(&multiplicity_columns) .cloned() .collect::>(); - let main_intermediates = - intermediates_in_identities(&base_identities, &self.fixed.intermediate_definitions); + let main_intermediates = intermediates_in_expressions( + base_identities.iter().flat_map(|i| i.all_children()), + &self.fixed.intermediate_definitions, + ); log::trace!( "\nThe base machine contains the following witnesses:\n{}\n identities:\n{}\n and prover functions:\n{}", @@ -488,15 +494,13 @@ fn try_as_intermediate_ref(expr: &Expression) -> Option<(Pol } } -/// Returns all intermediate columns referenced in the identities as a map to their name. +/// Returns all intermediate columns referenced in the expression as a map to their name. /// Follows intermediate references recursively. -fn intermediates_in_identities( - identities: &[&Identity], +fn intermediates_in_expressions<'a, T: FieldElement>( + expressions: impl Iterator>, intermediate_definitions: &BTreeMap>, ) -> HashMap { - let mut queue = identities - .iter() - .flat_map(|id| id.all_children()) + let mut queue = expressions .filter_map(try_as_intermediate_ref) .collect::>(); let mut intermediates = HashMap::new();