@@ -27,26 +27,22 @@ pub(crate) struct StageOneOutput<'a, C: ProverContext> {
2727 pub witness_holder : TraceHolder < BF , C > ,
2828 pub memory_holder : TraceHolder < BF , C > ,
2929 pub generic_lookup_mapping : Option < C :: Allocation < u32 > > ,
30- pub callbacks : Callbacks < ' a > ,
30+ pub callbacks : Option < Callbacks < ' a > > ,
3131 pub public_inputs : Option < Arc < Mutex < Vec < BF > > > > ,
3232}
3333
3434impl < ' a , C : ProverContext > StageOneOutput < ' a , C > {
35- pub fn new (
35+ pub fn allocate_trace_holders (
3636 circuit : & CompiledCircuitArtifact < BF > ,
37- setup : & SetupPrecomputations < C > ,
38- tracing_data_transfer : TracingDataTransfer < ' a , C > ,
3937 log_lde_factor : u32 ,
4038 log_tree_cap_size : u32 ,
41- circuit_sequence : usize ,
4239 context : & C ,
4340 ) -> CudaResult < Self > {
4441 let trace_len = circuit. trace_len ;
4542 assert ! ( trace_len. is_power_of_two( ) ) ;
4643 let log_domain_size = trace_len. trailing_zeros ( ) ;
47- let witness_subtree = & circuit. witness_layout ;
48- let witness_columns_count = witness_subtree. total_width ;
49- let mut witness_holder = TraceHolder :: new (
44+ let witness_columns_count = circuit. witness_layout . total_width ;
45+ let witness_holder = TraceHolder :: new (
5046 log_domain_size,
5147 log_lde_factor,
5248 0 ,
@@ -55,9 +51,8 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
5551 true ,
5652 context,
5753 ) ?;
58- let memory_subtree = & circuit. memory_layout ;
59- let memory_columns_count = memory_subtree. total_width ;
60- let mut memory_holder = TraceHolder :: new (
54+ let memory_columns_count = circuit. memory_layout . total_width ;
55+ let memory_holder = TraceHolder :: new (
6156 log_domain_size,
6257 log_lde_factor,
6358 0 ,
@@ -66,6 +61,28 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
6661 true ,
6762 context,
6863 ) ?;
64+ Ok ( Self {
65+ witness_holder,
66+ memory_holder,
67+ generic_lookup_mapping : None ,
68+ callbacks : None ,
69+ public_inputs : None ,
70+ } )
71+ }
72+
73+ pub fn generate_witness (
74+ & mut self ,
75+ circuit : & CompiledCircuitArtifact < BF > ,
76+ setup : & SetupPrecomputations < C > ,
77+ tracing_data_transfer : TracingDataTransfer < ' a , C > ,
78+ circuit_sequence : usize ,
79+ context : & C ,
80+ ) -> CudaResult < ( ) > {
81+ let trace_len = circuit. trace_len ;
82+ assert ! ( trace_len. is_power_of_two( ) ) ;
83+ let log_domain_size = trace_len. trailing_zeros ( ) ;
84+ let witness_subtree = & circuit. witness_layout ;
85+ let memory_subtree = & circuit. memory_layout ;
6986 let generic_lookup_mapping_size = witness_subtree. width_3_lookups . len ( ) << log_domain_size;
7087 let mut generic_lookup_mapping = context. alloc ( generic_lookup_mapping_size) ?;
7188 let TracingDataTransfer {
@@ -75,6 +92,7 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
7592 transfer,
7693 } = tracing_data_transfer;
7794 transfer. ensure_transferred ( context) ?;
95+ self . callbacks = Some ( transfer. callbacks ) ;
7896 let stream = context. get_exec_stream ( ) ;
7997 assert_eq ! ( COMMON_TABLE_WIDTH , 3 ) ;
8098 assert_eq ! ( NUM_COLUMNS_FOR_COMMON_TABLE_WIDTH_SETUP , 4 ) ;
@@ -100,6 +118,8 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
100118 + timestamp_range_check_multiplicities_columns. num_elements,
101119 generic_multiplicities_columns. start
102120 ) ;
121+ let witness_holder = & mut self . witness_holder ;
122+ let memory_holder = & mut self . memory_holder ;
103123 match data_device {
104124 TracingDataDevice :: Main {
105125 setup_and_teardown,
@@ -172,15 +192,26 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
172192 trace_len,
173193 context,
174194 ) ?;
175- memory_holder. make_evaluations_sum_to_zero_extend_and_commit ( context) ?;
176- witness_holder. make_evaluations_sum_to_zero_extend_and_commit ( context) ?;
177- Ok ( Self {
178- witness_holder,
179- memory_holder,
180- generic_lookup_mapping : Some ( generic_lookup_mapping) ,
181- callbacks : transfer. callbacks ,
182- public_inputs : None ,
183- } )
195+ self . generic_lookup_mapping = Some ( generic_lookup_mapping) ;
196+ Ok ( ( ) )
197+ }
198+
199+ pub fn commit_witness (
200+ & mut self ,
201+ circuit : & ' a CompiledCircuitArtifact < BF > ,
202+ context : & C ,
203+ ) -> CudaResult < ( ) >
204+ where
205+ C :: HostAllocator : ' a ,
206+ {
207+ self . memory_holder
208+ . make_evaluations_sum_to_zero_extend_and_commit ( context) ?;
209+ self . memory_holder . produce_tree_caps ( context) ?;
210+ self . witness_holder
211+ . make_evaluations_sum_to_zero_extend_and_commit ( context) ?;
212+ self . witness_holder . produce_tree_caps ( context) ?;
213+ self . produce_public_inputs ( circuit, context) ?;
214+ Ok ( ( ) )
184215 }
185216
186217 pub fn produce_public_inputs (
@@ -257,7 +288,10 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
257288 guard. extend ( first_row_public_inputs) ;
258289 guard. extend ( one_before_last_row_public_inputs) ;
259290 } ;
260- self . callbacks . schedule ( function, stream) ?;
291+ self . callbacks
292+ . as_mut ( )
293+ . unwrap ( )
294+ . schedule ( function, stream) ?;
261295 self . public_inputs = Some ( public_inputs) ;
262296 Ok ( ( ) )
263297 }
0 commit comments