@@ -115,41 +115,101 @@ impl<'tcx> LateLintPass<'tcx> for DuplicateMutableAccounts {
115115 }
116116 }
117117 }
118+ } else {
119+ // perform alternate constraint check, e.g., check fn bodies, then check key checks
120+ self. check_fn ( )
121+ }
122+
123+ // TODO: how to enforce that this is only called when necessary?
124+ fn check_fn (
125+ & mut self ,
126+ cx : & LateContext < ' tcx > ,
127+ _: FnKind < ' tcx > ,
128+ _: & ' tcx FnDecl < ' tcx > ,
129+ body : & ' tcx Body < ' tcx > ,
130+ span : Span ,
131+ _: HirId ,
132+ ) {
133+ if !span. from_expansion ( ) {
134+ let accounts = get_referenced_accounts ( cx, body) ;
135+
136+ accounts. values ( ) . for_each ( |exprs| {
137+ // TODO: figure out handling of >2 accounts
138+ match exprs. len ( ) {
139+ 2 => {
140+ let first = exprs[ 0 ] ;
141+ let second = exprs[ 1 ] ;
142+ if !contains_key_call ( cx, body, first) {
143+ span_lint_and_help (
144+ cx,
145+ DUP_MUTABLE_ACCOUNTS_2 ,
146+ first. span ,
147+ "this expression does not have a key check but has the same account type as another expression" ,
148+ Some ( second. span ) ,
149+ "add a key check to make sure the accounts have different keys, e.g., x.key() != y.key()" ,
150+ ) ;
151+ }
152+ if !contains_key_call ( cx, body, second) {
153+ span_lint_and_help (
154+ cx,
155+ DUP_MUTABLE_ACCOUNTS_2 ,
156+ second. span ,
157+ "this expression does not have a key check but has the same account type as another expression" ,
158+ Some ( first. span ) ,
159+ "add a key check to make sure the accounts have different keys, e.g., x.key() != y.key()" ,
160+ ) ;
161+ }
162+ } ,
163+ n if n > 2 => {
164+ span_lint_and_note (
165+ cx,
166+ DUP_MUTABLE_ACCOUNTS_2 ,
167+ exprs[ 0 ] . span ,
168+ & format ! ( "the following expression has the same account type as {} other accounts" , exprs. len( ) ) ,
169+ None ,
170+ "might not check that each account has a unique key"
171+ )
172+ } ,
173+ _ => { }
174+ }
175+ } ) ;
176+ }
118177 }
119178 }
120179}
121180
122- /// Returns the `DefId` of the anchor account type, ie, `T` in `Account<'info, T>`.
123- /// Returns `None` if the type of `field` is not an anchor account.
124- fn get_anchor_account_type_def_id ( field : & FieldDef ) -> Option < DefId > {
125- if_chain ! {
126- if let TyKind :: Path ( qpath) = & field. ty. kind;
127- if let QPath :: Resolved ( _, path) = qpath;
128- if !path. segments. is_empty( ) ;
129- if let Some ( generic_args) = path. segments[ 0 ] . args;
130- if generic_args. args. len( ) == ANCHOR_ACCOUNT_GENERIC_ARG_COUNT ;
131- if let GenericArg :: Type ( hir_ty) = & generic_args. args[ 1 ] ;
132- then {
133- get_def_id( hir_ty)
134- } else {
135- None
181+ mod anchor_constraint_check {
182+ /// Returns the `DefId` of the anchor account type, ie, `T` in `Account<'info, T>`.
183+ /// Returns `None` if the type of `field` is not an anchor account.
184+ fn get_anchor_account_type_def_id ( field : & FieldDef ) -> Option < DefId > {
185+ if_chain ! {
186+ if let TyKind :: Path ( qpath) = & field. ty. kind;
187+ if let QPath :: Resolved ( _, path) = qpath;
188+ if !path. segments. is_empty( ) ;
189+ if let Some ( generic_args) = path. segments[ 0 ] . args;
190+ if generic_args. args. len( ) == ANCHOR_ACCOUNT_GENERIC_ARG_COUNT ;
191+ if let GenericArg :: Type ( hir_ty) = & generic_args. args[ 1 ] ;
192+ then {
193+ get_def_id( hir_ty)
194+ } else {
195+ None
196+ }
136197 }
137198 }
138- }
139199
140- /// Returns the `DefId` of `ty`, an hir type. Returns `None` if cannot resolve type.
141- fn get_def_id ( ty : & rustc_hir:: Ty ) -> Option < DefId > {
142- if_chain ! {
143- if let TyKind :: Path ( qpath) = & ty. kind;
144- if let QPath :: Resolved ( _, path) = qpath;
145- if let Res :: Def ( _, def_id) = path. res;
146- then {
147- Some ( def_id)
148- } else {
149- None
200+ /// Returns the `DefId` of `ty`, an hir type. Returns `None` if cannot resolve type.
201+ fn get_def_id ( ty : & rustc_hir:: Ty ) -> Option < DefId > {
202+ if_chain ! {
203+ if let TyKind :: Path ( qpath) = & ty. kind;
204+ if let QPath :: Resolved ( _, path) = qpath;
205+ if let Res :: Def ( _, def_id) = path. res;
206+ then {
207+ Some ( def_id)
208+ } else {
209+ None
210+ }
150211 }
151212 }
152- }
153213
154214/// Returns a `TokenStream` of form: `a`.key() != `b`.key().
155215fn create_key_check_constraint_tokenstream ( a : Symbol , b : Symbol ) -> TokenStream {
@@ -175,17 +235,27 @@ fn create_key_check_constraint_tokenstream(a: Symbol, b: Symbol) -> TokenStream
175235 ) ) ,
176236 ] ;
177237
178- TokenStream :: new ( constraint)
179- }
238+ TokenStream :: new ( constraint)
239+ }
240+
241+ /// Returns a `TokenTree::Token` which has `TokenKind::Ident`, with the string set to `s`.
242+ fn create_token_from_ident ( s : & str ) -> TokenTree {
243+ let ident = Ident :: from_str ( s) ;
244+ TokenTree :: Token ( Token :: from_ast_ident ( ident) )
245+ }
246+
247+ #[ derive( Debug , Default ) ]
248+ pub struct Streams ( Vec < TokenStream > ) ;
180249
181- /// Returns a `TokenTree::Token` which has `TokenKind::Ident`, with the string set to `s`.
182- fn create_token_from_ident ( s : & str ) -> TokenTree {
183- let ident = Ident :: from_str ( s) ;
184- TokenTree :: Token ( Token :: from_ast_ident ( ident) )
250+ impl Streams {
251+ /// Returns true if `self` contains `other`, by comparing if there is an
252+ /// identical `TokenStream` in `self` regardless of span.
253+ fn contains ( & self , other : & TokenStream ) -> bool {
254+ self . 0 . iter ( ) . any ( |stream| stream. eq_unspanned ( other) )
255+ }
256+ }
185257}
186258
187- #[ derive( Debug , Default ) ]
188- pub struct Streams ( Vec < TokenStream > ) ;
189259
190260impl Streams {
191261 /// Returns true if `self` has a TokenStream that `other` is a substream of
@@ -222,6 +292,98 @@ impl Streams {
222292 }
223293}
224294
295+ mod alternate_constraint_check {
296+ struct AccountUses < ' cx , ' tcx > {
297+ cx : & ' cx LateContext < ' tcx > ,
298+ uses : HashMap < DefId , Vec < & ' tcx Expr < ' tcx > > > ,
299+ }
300+
301+ fn get_referenced_accounts < ' tcx > (
302+ cx : & LateContext < ' tcx > ,
303+ body : & ' tcx Body < ' tcx > ,
304+ ) -> HashMap < DefId , Vec < & ' tcx Expr < ' tcx > > > {
305+ let mut accounts = AccountUses {
306+ cx,
307+ uses : HashMap :: new ( ) ,
308+ } ;
309+
310+ accounts. visit_expr ( & body. value ) ;
311+ accounts. uses
312+ }
313+
314+ impl < ' cx , ' tcx > Visitor < ' tcx > for AccountUses < ' cx , ' tcx > {
315+ fn visit_expr ( & mut self , expr : & ' tcx Expr < ' tcx > ) {
316+ if_chain ! {
317+ // get mutable reference expressions
318+ if let ExprKind :: AddrOf ( _, mutability, mut_expr) = expr. kind;
319+ if let Mutability :: Mut = mutability;
320+ // check type of expr == Account<'info, T>
321+ let middle_ty = self . cx. typeck_results( ) . expr_ty( mut_expr) ;
322+ if match_type( self . cx, middle_ty, & paths:: ANCHOR_ACCOUNT ) ;
323+ // grab T generic parameter
324+ if let TyKind :: Adt ( _adt_def, substs) = middle_ty. kind( ) ;
325+ if substs. len( ) == ANCHOR_ACCOUNT_GENERIC_ARG_COUNT ;
326+ let account_type = substs[ 1 ] . expect_ty( ) ; // TODO: could just store middle::Ty instead of DefId?
327+ if let Some ( adt_def) = account_type. ty_adt_def( ) ;
328+ then {
329+ let def_id = adt_def. did( ) ;
330+ if let Some ( exprs) = self . uses. get_mut( & def_id) {
331+ let mut spanless_eq = SpanlessEq :: new( self . cx) ;
332+ // check that expr is not a duplicate within its particular key-pair
333+ if exprs. iter( ) . all( |e| !spanless_eq. eq_expr( e, mut_expr) ) {
334+ exprs. push( mut_expr) ;
335+ }
336+ } else {
337+ self . uses. insert( def_id, vec![ mut_expr] ) ;
338+ }
339+ }
340+ }
341+ walk_expr ( self , expr) ;
342+ }
343+ }
344+
345+ /// Performs a walk on `body`, checking whether there exists an expression that contains
346+ /// a `key()` method call on `account_expr`.
347+ fn contains_key_call < ' tcx > (
348+ cx : & LateContext < ' tcx > ,
349+ body : & ' tcx Body < ' tcx > ,
350+ account_expr : & Expr < ' tcx > ,
351+ ) -> bool {
352+ visit_expr_no_bodies ( & body. value , |expr| {
353+ if_chain ! {
354+ if let ExprKind :: MethodCall ( path_seg, exprs, _span) = expr. kind;
355+ if path_seg. ident. name. as_str( ) == "key" ;
356+ if !exprs. is_empty( ) ;
357+ let mut spanless_eq = SpanlessEq :: new( cx) ;
358+ if spanless_eq. eq_expr( & exprs[ 0 ] , account_expr) ;
359+ then {
360+ true
361+ } else {
362+ false
363+ }
364+ }
365+ } )
366+ }
367+ }
368+
369+ // /// Splits `stream` into a vector of substreams, separated by `delimiter`.
370+ // fn split(stream: CursorRef, delimiter: TokenKind) -> Vec<TokenStream> {
371+ // let mut split_streams: Vec<TokenStream> = Vec::new();
372+ // let mut temp: Vec<TreeAndSpacing> = Vec::new();
373+ // let delim = TokenTree::Token(Token::new(delimiter, DUMMY_SP));
374+
375+ // stream.for_each(|t| {
376+ // if t.eq_unspanned(&delim) {
377+ // split_streams.push(TokenStream::new(temp.clone()));
378+ // temp.clear();
379+ // } else {
380+ // temp.push(TreeAndSpacing::from(t.clone()));
381+ // }
382+ // });
383+ // split_streams.push(TokenStream::new(temp));
384+ // split_streams
385+ // }
386+
225387#[ test]
226388fn insecure ( ) {
227389 dylint_testing:: ui_test_example ( env ! ( "CARGO_PKG_NAME" ) , "insecure" ) ;
0 commit comments