1+ using System ;
2+ using System . Collections . Generic ;
3+ using System . Collections . ObjectModel ;
4+ using System . Linq ;
5+ using System . Text . RegularExpressions ;
6+ using Dibix . Sdk . Sql ;
7+ using Microsoft . SqlServer . TransactSql . ScriptDom ;
8+
9+ namespace Dibix . Sdk . CodeAnalysis . Rules
10+ {
11+ [ SqlCodeAnalysisRule ( id : 42 ) ]
12+ public sealed class SetStatementSqlCodeAnalysisRule : SqlCodeAnalysisRule
13+ {
14+ private static readonly IDictionary < SetOptions , bool > SupportedOptions = new Dictionary < SetOptions , bool >
15+ {
16+ [ SetOptions . NoCount ] = true
17+ , [ SetOptions . XactAbort ] = true
18+ } ;
19+ private string _procedureName ;
20+
21+ protected override string ErrorMessageTemplate => "{0}" ;
22+
23+ public override void Visit ( PredicateSetStatement node )
24+ {
25+ // Validate supported SET statements
26+ string expression = CollectExpression ( node ) ;
27+ if ( ! SupportedOptions . TryGetValue ( node . Options , out bool on ) )
28+ {
29+ this . ReportUnsupportedSetStatement ( node , expression ) ;
30+ return ;
31+ }
32+
33+ if ( node . IsOn != on )
34+ this . ReportUnsupportedSetOption ( node , on ? "ON" : "OFF" , expression ) ;
35+ }
36+
37+ public override void Visit ( SetCommandStatement node )
38+ {
39+ foreach ( SetCommand setCommand in node . Commands )
40+ Visit ( node , setCommand ) ;
41+ }
42+
43+ public override void Visit ( SetErrorLevelStatement node ) => this . ReportUnsupportedSetStatement ( node ) ;
44+
45+ //public override void Visit(SetIdentityInsertStatement node) => this.VisitOtherSetStatement(node);
46+
47+ public override void Visit ( SetOffsetsStatement node ) => this . ReportUnsupportedSetStatement ( node ) ;
48+
49+ public override void Visit ( SetRowCountStatement node ) => this . ReportUnsupportedSetStatement ( node ) ;
50+
51+ public override void Visit ( SetStatisticsStatement node ) => this . ReportUnsupportedSetStatement ( node ) ;
52+
53+ public override void Visit ( SetTextSizeStatement node ) => this . ReportUnsupportedSetStatement ( node ) ;
54+
55+ //public override void Visit(SetTransactionIsolationLevelStatement node) => this.ReportUnsupportedSetStatement(node);
56+
57+ public override void Visit ( SetUserStatement node ) => this . ReportUnsupportedSetStatement ( node ) ;
58+
59+ public override void Visit ( CreateProcedureStatement node )
60+ {
61+ this . _procedureName = node . ProcedureReference . Name . BaseIdentifier . Value ;
62+
63+ // Verify that SET XACT_ABORT ON is set when BEGIN TRANSACTION is used without a custom CATCH block that contains ROLLBACK TRANSACTION.
64+ // This ensures that the transaction is properly rolled back whenever an error occurs.
65+ XactAbortVisitor xactAbortVisitor = new XactAbortVisitor ( ) ;
66+ node . Accept ( xactAbortVisitor ) ;
67+ if ( xactAbortVisitor . HasXactAbortOn )
68+ return ;
69+
70+ // Collect all BEGIN TRANSACTION statements (whether inside TRY..CATCH or not)
71+ ICollection < BeginTransactionStatement > allBeginTransactionStatements = new Collection < BeginTransactionStatement > ( ) ;
72+ BeginTransactionStatementVisitor beginTransactionStatementVisitor = new BeginTransactionStatementVisitor ( ) ;
73+ node . Accept ( beginTransactionStatementVisitor ) ;
74+ beginTransactionStatementVisitor . BeginTransactionStatements . Each ( allBeginTransactionStatements . Add ) ;
75+
76+ // Collect TRY..CATCH statements
77+ ICollection < TryCatchStatementDescriptor > tryCatchStatements = new Collection < TryCatchStatementDescriptor > ( ) ;
78+ TryCatchStatementVisitor tryCatchStatementVisitor = new TryCatchStatementVisitor ( parent : null , tryCatchStatements ) ;
79+ node . Accept ( tryCatchStatementVisitor ) ;
80+
81+ // Walk each TRY..CATCH statement bottom up
82+ ICollection < TryCatchStatementDescriptor > nestedTryCatchStatements = tryCatchStatements . Where ( x => x . IsLeaf ) . ToArray ( ) ;
83+ ICollection < BeginTransactionStatementDescriptor > currentBeginTransactionStatements = new HashSet < BeginTransactionStatementDescriptor > ( ) ;
84+ foreach ( TryCatchStatementDescriptor tryCatchStatementDescriptor in nestedTryCatchStatements )
85+ {
86+ VisitTryCatch ( allBeginTransactionStatements , currentBeginTransactionStatements , tryCatchStatementDescriptor ) ;
87+ }
88+
89+ // Populate errors for remaining violations
90+ foreach ( BeginTransactionStatement beginTransactionStatement in allBeginTransactionStatements )
91+ {
92+ base . Fail ( beginTransactionStatement , "SET XACT_ABORT ON should be set when working with BEGIN TRANSACTION without a custom TRY..CATCH block to ensure the transaction is rolled back in case of an error" ) ;
93+ }
94+ }
95+
96+ private void Visit ( TSqlFragment node , SetCommand command )
97+ {
98+ switch ( command )
99+ {
100+ case GeneralSetCommand generalSetCommand :
101+ this . Visit ( node , generalSetCommand ) ;
102+ break ;
103+
104+ default :
105+ this . ReportUnsupportedSetStatement ( node ) ;
106+ break ;
107+ }
108+ }
109+
110+ private void Visit ( TSqlFragment node , GeneralSetCommand generalSetCommand )
111+ {
112+ switch ( generalSetCommand . CommandType )
113+ {
114+ // SET DEADLOCK_PRIORITY LOW
115+ case GeneralSetCommandType . DeadlockPriority :
116+ if ( ! ( generalSetCommand . Parameter is IdentifierLiteral identifierLiteral ) )
117+ return ; // ???
118+
119+ const string expectedDeadlockPriority = "LOW" ;
120+ if ( ! String . Equals ( identifierLiteral . Value , expectedDeadlockPriority , StringComparison . OrdinalIgnoreCase ) )
121+ this . ReportUnsupportedSetOption ( node , expectedDeadlockPriority ) ;
122+
123+ break ;
124+
125+ // SET CONTEXT_INFO
126+ case GeneralSetCommandType . ContextInfo :
127+ break ;
128+
129+ // SET DATEFORMAT MDY
130+ case GeneralSetCommandType . DateFormat :
131+ string expression = CollectExpression ( node ) ;
132+ base . FailIfUnsuppressed ( node , this . _procedureName , BuildUnsupportedSetStatementMessage ( expression ) ) ;
133+ break ;
134+
135+ default :
136+ this . ReportUnsupportedSetStatement ( node ) ;
137+ break ;
138+ }
139+ }
140+
141+ private void ReportUnsupportedSetStatement ( TSqlFragment node )
142+ {
143+ string expression = CollectExpression ( node ) ;
144+ this . ReportUnsupportedSetStatement ( node , expression ) ;
145+ }
146+ private void ReportUnsupportedSetStatement ( TSqlFragment node , string expression ) => base . Fail ( node , BuildUnsupportedSetStatementMessage ( expression ) ) ;
147+
148+ private static string BuildUnsupportedSetStatementMessage ( string expression ) => $ "Unsupported SET statement: { expression } ";
149+
150+ private void ReportUnsupportedSetOption ( TSqlFragment node , string expectedOption )
151+ {
152+ string expression = CollectExpression ( node ) ;
153+ this . ReportUnsupportedSetOption ( node , expectedOption , expression ) ;
154+ }
155+ private void ReportUnsupportedSetOption ( TSqlFragment node , string expectedOption , string expression ) => base . Fail ( node , $ "Only { expectedOption } is supported for SET statement: { expression } ") ;
156+
157+ private static string CollectExpression ( TSqlFragment node ) => Regex . Replace ( node . Dump ( ) , @"\s{2,}" , " " ) ;
158+
159+ private static void VisitTryCatch ( ICollection < BeginTransactionStatement > allBeginTransactionStatements , ICollection < BeginTransactionStatementDescriptor > currentBeginTransactionStatements , TryCatchStatementDescriptor tryCatchStatementDescriptor )
160+ {
161+ TryCatchStatementDescriptor current = tryCatchStatementDescriptor ;
162+ do
163+ {
164+ // Collect BEGIN TRANSACTION within TRY
165+ BeginTransactionStatementVisitor beginTransactionStatementVisitor = new BeginTransactionStatementVisitor ( ) ;
166+ current . Statement . TryStatements . Accept ( beginTransactionStatementVisitor ) ;
167+ beginTransactionStatementVisitor . BeginTransactionStatements . Each ( x => currentBeginTransactionStatements . Add ( new BeginTransactionStatementDescriptor ( x ) ) ) ;
168+
169+ // Collect THROW statement within CATCH
170+ ThrowStatementVisitor throwStatementVisitor = new ThrowStatementVisitor ( ) ;
171+ current . Statement . CatchStatements . Accept ( throwStatementVisitor ) ;
172+
173+ // Collect ROLLBACK statement within CATCH
174+ RollbackTransactionStatementVisitor rollbackTransactionStatementVisitor = new RollbackTransactionStatementVisitor ( ) ;
175+ current . Statement . CatchStatements . Accept ( rollbackTransactionStatementVisitor ) ;
176+
177+ // If the CATCH contains ROLLBACK, all nested BEGIN TRANSACTION statements, can be treated as rolled back, if there was no CATCH block that ignored the error
178+ if ( rollbackTransactionStatementVisitor . HasRollback )
179+ {
180+ // If there was a nested CATCH block without rethrow, the parent CATCH containing ROLLBACK won't be hit,
181+ // The affected BEGIN TRANSACTION statements in this case, can't be treated as rolled back and are marked as Keep = true.
182+ currentBeginTransactionStatements . Where ( x => ! x . Keep ) . Each ( x => allBeginTransactionStatements . Remove ( x . Statement ) ) ;
183+ currentBeginTransactionStatements . Clear ( ) ;
184+ }
185+
186+ // If the CATCH did not rethrow, the error is ignored and not bubbled up.
187+ // Therefore all BEGIN TRANSACTION statements in the TRY block won't be rolled back.
188+ if ( ! throwStatementVisitor . HasThrow )
189+ currentBeginTransactionStatements . Each ( x => x . Keep = true ) ;
190+
191+ current = current . Parent ;
192+ } while ( current != null ) ;
193+ }
194+
195+ private sealed class BeginTransactionStatementVisitor : TSqlFragmentVisitor
196+ {
197+ public ICollection < BeginTransactionStatement > BeginTransactionStatements { get ; } = new Collection < BeginTransactionStatement > ( ) ;
198+
199+ public override void ExplicitVisit ( BeginTransactionStatement node )
200+ {
201+ this . BeginTransactionStatements . Add ( node ) ;
202+ }
203+ }
204+
205+ private sealed class ThrowStatementVisitor : TSqlFragmentVisitor
206+ {
207+ public bool HasThrow { get ; private set ; }
208+
209+ public override void ExplicitVisit ( ThrowStatement node )
210+ {
211+ this . HasThrow = true ;
212+ }
213+ }
214+
215+ private sealed class RollbackTransactionStatementVisitor : TSqlFragmentVisitor
216+ {
217+ public bool HasRollback { get ; private set ; }
218+
219+ public override void ExplicitVisit ( RollbackTransactionStatement node )
220+ {
221+ this . HasRollback = true ;
222+ }
223+ }
224+
225+ private sealed class XactAbortVisitor : TSqlFragmentVisitor
226+ {
227+ public bool HasXactAbortOn { get ; private set ; }
228+
229+ public override void Visit ( PredicateSetStatement node )
230+ {
231+ this . HasXactAbortOn = node . Options == SetOptions . XactAbort && node . IsOn ;
232+ }
233+ }
234+
235+ private sealed class TryCatchStatementVisitor : TSqlFragmentVisitor
236+ {
237+ private readonly TryCatchStatementDescriptor _parent ;
238+ private readonly ICollection < TryCatchStatementDescriptor > _target ;
239+
240+ public bool HasMatch { get ; private set ; }
241+
242+ public TryCatchStatementVisitor ( TryCatchStatementDescriptor parent , ICollection < TryCatchStatementDescriptor > target )
243+ {
244+ this . _parent = parent ;
245+ this . _target = target ;
246+ }
247+
248+ public override void ExplicitVisit ( TryCatchStatement node )
249+ {
250+ this . HasMatch = true ;
251+
252+ TryCatchStatementDescriptor descriptor = new TryCatchStatementDescriptor ( node ) ;
253+ descriptor . Parent = this . _parent ;
254+ this . _target . Add ( descriptor ) ;
255+
256+ TryCatchStatementVisitor visitor = new TryCatchStatementVisitor ( descriptor , this . _target ) ;
257+ node . TryStatements . Accept ( visitor ) ;
258+ node . CatchStatements . Accept ( visitor ) ;
259+ descriptor . IsLeaf = ! visitor . HasMatch ;
260+ }
261+ }
262+
263+ private sealed class TryCatchStatementDescriptor
264+ {
265+ public TryCatchStatement Statement { get ; }
266+ public TryCatchStatementDescriptor Parent { get ; set ; }
267+ public bool IsLeaf { get ; set ; }
268+
269+ public TryCatchStatementDescriptor ( TryCatchStatement statement )
270+ {
271+ this . Statement = statement ;
272+ }
273+ }
274+
275+ private sealed class BeginTransactionStatementDescriptor
276+ {
277+ public BeginTransactionStatement Statement { get ; }
278+ public bool Keep { get ; set ; }
279+
280+ public BeginTransactionStatementDescriptor ( BeginTransactionStatement statement )
281+ {
282+ this . Statement = statement ;
283+ }
284+
285+ public override bool Equals ( object obj )
286+ {
287+ return ReferenceEquals ( this , obj ) || obj is BeginTransactionStatementDescriptor other && Equals ( other ) ;
288+ }
289+
290+ public override int GetHashCode ( )
291+ {
292+ return this . Statement . GetHashCode ( ) ;
293+ }
294+
295+ private bool Equals ( BeginTransactionStatementDescriptor other )
296+ {
297+ return this . Statement . Equals ( other . Statement ) ;
298+ }
299+ }
300+ }
301+ }
0 commit comments