55
66#![ allow( clippy:: single_match, clippy:: new_without_default) ]
77
8- mod outputs;
98mod type_registry;
109
11- use outputs:: Outputs ;
1210use proc_macro2:: { Punct , Spacing , Span } ;
13- use quote:: { TokenStreamExt , quote} ;
14- use std:: { collections:: BTreeMap as Map , env, fs, ops:: Deref } ;
15- use syn:: {
16- Expr , ExprCall , ExprPath , ExprReference , Fields , FnArg , Ident , Item , ItemFn , Local , LocalInit ,
17- Meta , Pat , PatIdent , PatTuple , Path , Stmt , TypeReference , parse_quote,
18- punctuated:: Punctuated ,
19- token:: { Const , Eq , Let , Paren , Semi } ,
20- } ;
11+ use quote:: TokenStreamExt ;
12+ use std:: { env, fs, ops:: Deref } ;
13+ use syn:: { parse_quote, token:: Const , FnArg , Ident , Item , ItemFn , Meta , Pat , Stmt , TypeReference } ;
2114use type_registry:: TypeRegistry ;
2215
2316fn main ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
@@ -30,21 +23,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
3023 let code = fs:: read_to_string ( & args[ 1 ] ) ?;
3124 let mut ast = syn:: parse_file ( & code) ?;
3225 ast. attrs . push ( parse_quote ! {
33- #![ allow(
34- clippy:: identity_op,
35- clippy:: unnecessary_cast,
36- dead_code,
37- rustdoc:: broken_intra_doc_links,
38- unused_assignments,
39- unused_mut,
40- unused_variables
41- ) ]
26+ #![ allow( dead_code) ]
4227 } ) ;
4328
4429 let mut type_registry = TypeRegistry :: new ( ) ;
4530
4631 // Iterate over functions, transforming them into `const fn`
47- let mut const_deref = Vec :: new ( ) ;
4832 for item in & mut ast. items {
4933 match item {
5034 Item :: Fn ( func) => rewrite_fn_as_const ( func, & type_registry) ,
@@ -67,32 +51,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
6751 } ) ;
6852 }
6953
70- let ident = & ty. ident ;
71- if let Fields :: Unnamed ( unnamed) = & ty. fields {
72- if let Some ( unit) = unnamed. unnamed . first ( ) {
73- let unit_ty = & unit. ty ;
74- const_deref. push ( parse_quote ! {
75- impl #ident {
76- #[ inline]
77- pub const fn as_inner( & self ) -> & #unit_ty {
78- & self . 0
79- }
80-
81- #[ inline]
82- pub const fn into_inner( self ) -> #unit_ty {
83- self . 0
84- }
85- }
86- } ) ;
87- }
88- }
89-
90- type_registry. add_new_type ( ty)
54+ type_registry. add_newtype ( ty)
9155 }
9256 _ => ( ) ,
9357 }
9458 }
95- ast. items . extend_from_slice ( & const_deref) ;
9659
9760 println ! (
9861 "//! fiat-crypto output postprocessed by fiat-constify: <https://github.com/rustcrypto/utils>"
@@ -116,8 +79,6 @@ fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) {
11679 func. sig . constness = Some ( Const :: default ( ) ) ;
11780
11881 // Transform mutable arguments into return values.
119- let mut inputs = Punctuated :: new ( ) ;
120- let mut outputs = Outputs :: new ( type_registry) ;
12182 let mut stmts = Vec :: < Stmt > :: new ( ) ;
12283
12384 for arg in & func. sig . inputs {
@@ -129,8 +90,17 @@ fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) {
12990 elem,
13091 ..
13192 } ) => {
132- outputs. add ( get_ident_from_pat ( & t. pat ) , elem. deref ( ) . clone ( ) ) ;
133- continue ;
93+ if matches ! ( elem. deref( ) , syn:: Type :: Path ( _) ) {
94+ // Generation of reborrows, LLVM should optimize this out, and it definitely
95+ // will if `#[repr(transparent)]` is used.
96+ let ty = type_registry:: type_to_ident ( elem) . unwrap ( ) ;
97+ let ident = get_ident_from_pat ( & t. pat ) ;
98+ if type_registry. is_newtype ( ty) {
99+ stmts. push ( parse_quote ! {
100+ let #ident = & mut #ident. 0 ;
101+ } ) ;
102+ }
103+ }
134104 }
135105 syn:: Type :: Reference ( TypeReference {
136106 mutability : None ,
@@ -141,177 +111,17 @@ fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) {
141111 // will if `#[repr(transparent)]` is used.
142112 let ty = type_registry:: type_to_ident ( elem) . unwrap ( ) ;
143113 let ident = get_ident_from_pat ( & t. pat ) ;
144- if outputs . type_registry ( ) . is_new_type ( ty) {
114+ if type_registry. is_newtype ( ty) {
145115 stmts. push ( parse_quote ! {
146- let #ident = #ident. as_inner ( ) ;
116+ let #ident = & #ident. 0 ;
147117 } ) ;
148118 }
149119 }
150120 _ => ( ) ,
151121 }
152122 }
153-
154- // If the argument wasn't a mutable reference, add it as an input.
155- inputs. push ( arg. clone ( ) ) ;
156123 }
157124
158- // Replace inputs with ones where the mutable references have been filtered out
159- func. sig . inputs = inputs;
160- func. sig . output = outputs. to_return_type ( ) ;
161- stmts. extend ( rewrite_fn_body ( & func. block . stmts , & outputs) ) ;
125+ stmts. extend ( func. block . stmts . clone ( ) ) ;
162126 func. block . stmts = stmts;
163127}
164-
165- /// Rewrite the function body, adding let bindings with `Default::default()`
166- /// values for outputs, removing mutable references, and adding a return
167- /// value/tuple.
168- fn rewrite_fn_body ( stmts : & [ Stmt ] , outputs : & Outputs ) -> Vec < Stmt > {
169- let mut ident_assignments: Map < & Ident , Vec < & Expr > > = Map :: new ( ) ;
170- let mut rewritten = Vec :: new ( ) ;
171-
172- for stmt in stmts {
173- if let Stmt :: Expr ( Expr :: Assign ( assignment) , Some ( _) ) = stmt {
174- let lhs_path = match assignment. left . as_ref ( ) {
175- Expr :: Unary ( lhs) => {
176- if let Expr :: Path ( exprpath) = lhs. expr . as_ref ( ) {
177- Some ( exprpath)
178- } else {
179- panic ! ( "All unary exprpaths should have the LHS as the path" ) ;
180- }
181- }
182- Expr :: Index ( lhs) => {
183- if let Expr :: Path ( exprpath) = lhs. expr . as_ref ( ) {
184- Some ( exprpath)
185- } else {
186- panic ! ( "All unary exprpaths should have the LHS as the path" ) ;
187- }
188- }
189- Expr :: Call ( expr) => {
190- rewritten. push ( Stmt :: Local ( rewrite_fn_call ( expr. clone ( ) ) ) ) ;
191- None
192- }
193- _ => None ,
194- } ;
195- if let Some ( lhs_path) = lhs_path {
196- ident_assignments
197- . entry ( Path :: get_ident ( & lhs_path. path ) . unwrap ( ) )
198- . or_default ( )
199- . push ( & assignment. right ) ;
200- }
201- } else if let Stmt :: Expr ( Expr :: Call ( expr) , Some ( _) ) = stmt {
202- rewritten. push ( Stmt :: Local ( rewrite_fn_call ( expr. clone ( ) ) ) ) ;
203- } else if let Stmt :: Local ( Local {
204- pat : Pat :: Type ( pat) ,
205- ..
206- } ) = stmt
207- {
208- let unboxed = pat. pat . as_ref ( ) ;
209- if let Pat :: Ident ( PatIdent {
210- mutability : Some ( _) ,
211- ..
212- } ) = unboxed
213- {
214- // This is a mut var, in the case of fiat-crypto transformation dead code
215- } else {
216- rewritten. push ( stmt. clone ( ) ) ;
217- }
218- } else {
219- rewritten. push ( stmt. clone ( ) ) ;
220- }
221- }
222-
223- let mut asts = Vec :: new ( ) ;
224- for ( ident, ty) in outputs. ident_type_pairs ( ) {
225- let value = ident_assignments. get ( ident) . unwrap ( ) ;
226- let type_prefix = match type_registry:: type_to_ident ( ty) {
227- Some ( ident) if outputs. type_registry ( ) . is_new_type ( ident) => Some ( ty) ,
228- _ => None ,
229- } ;
230-
231- let ast = match ( type_prefix, value. len ( ) ) {
232- ( None , 1 ) => {
233- let first = value. first ( ) . unwrap ( ) ;
234- quote ! ( #first)
235- }
236- ( Some ( prefix) , 1 ) => {
237- let first = value. first ( ) . unwrap ( ) ;
238- quote ! ( #prefix( #first) )
239- }
240-
241- ( None , _) => {
242- quote ! ( [ #( #value) , * ] )
243- }
244- ( Some ( prefix) , _) => {
245- quote ! ( #prefix( [ #( #value) , * ] ) )
246- }
247- } ;
248- asts. push ( ast) ;
249- }
250-
251- let expr: Expr = parse_quote ! {
252- ( #( #asts) , * )
253- } ;
254-
255- rewritten. push ( Stmt :: Expr ( expr, None ) ) ;
256- rewritten
257- }
258-
259- /// Rewrite a function call, removing the mutable reference arguments and
260- /// let-binding return values for them instead.
261- fn rewrite_fn_call ( mut call : ExprCall ) -> Local {
262- let mut args = Punctuated :: new ( ) ;
263- let mut output = Punctuated :: new ( ) ;
264-
265- for arg in & call. args {
266- if let Expr :: Reference ( ExprReference {
267- mutability : Some ( _) ,
268- expr,
269- ..
270- } ) = arg
271- {
272- match expr. deref ( ) {
273- Expr :: Path ( ExprPath {
274- path : Path { segments, .. } ,
275- ..
276- } ) => {
277- assert_eq ! ( segments. len( ) , 1 , "expected only one segment in fn arg" ) ;
278- let ident = segments. first ( ) . unwrap ( ) . ident . clone ( ) ;
279-
280- output. push ( Pat :: Ident ( PatIdent {
281- attrs : Vec :: new ( ) ,
282- by_ref : None ,
283- mutability : None ,
284- ident,
285- subpat : None ,
286- } ) ) ;
287- }
288- other => panic ! ( "unexpected expr in fn arg: {:?}" , other) ,
289- }
290-
291- continue ;
292- }
293-
294- args. push ( arg. clone ( ) ) ;
295- }
296-
297- // Overwrite call arguments with the ones that aren't mutable references
298- call. args = args;
299-
300- let pat = Pat :: Tuple ( PatTuple {
301- attrs : Vec :: new ( ) ,
302- paren_token : Paren :: default ( ) ,
303- elems : output,
304- } ) ;
305-
306- Local {
307- attrs : Vec :: new ( ) ,
308- let_token : Let :: default ( ) ,
309- pat,
310- init : Some ( LocalInit {
311- eq_token : Eq :: default ( ) ,
312- expr : Box :: new ( Expr :: Call ( call) ) ,
313- diverge : None ,
314- } ) ,
315- semi_token : Semi :: default ( ) ,
316- }
317- }
0 commit comments