Skip to content

Commit e2abdf7

Browse files
committed
Validate SET statements
1 parent cb0cf41 commit e2abdf7

6 files changed

Lines changed: 496 additions & 0 deletions

File tree

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Collections.ObjectModel;
4+
using System.Linq;
5+
using System.Text.RegularExpressions;
6+
using Dibix.Sdk.Sql;
7+
using Microsoft.SqlServer.TransactSql.ScriptDom;
8+
9+
namespace Dibix.Sdk.CodeAnalysis.Rules
10+
{
11+
[SqlCodeAnalysisRule(id: 42)]
12+
public sealed class SetStatementSqlCodeAnalysisRule : SqlCodeAnalysisRule
13+
{
14+
private static readonly IDictionary<SetOptions, bool> SupportedOptions = new Dictionary<SetOptions,bool>
15+
{
16+
[SetOptions.NoCount] = true
17+
, [SetOptions.XactAbort] = true
18+
};
19+
private string _procedureName;
20+
21+
protected override string ErrorMessageTemplate => "{0}";
22+
23+
public override void Visit(PredicateSetStatement node)
24+
{
25+
// Validate supported SET statements
26+
string expression = CollectExpression(node);
27+
if (!SupportedOptions.TryGetValue(node.Options, out bool on))
28+
{
29+
this.ReportUnsupportedSetStatement(node, expression);
30+
return;
31+
}
32+
33+
if (node.IsOn != on)
34+
this.ReportUnsupportedSetOption(node, on ? "ON" : "OFF", expression);
35+
}
36+
37+
public override void Visit(SetCommandStatement node)
38+
{
39+
foreach (SetCommand setCommand in node.Commands)
40+
Visit(node, setCommand);
41+
}
42+
43+
public override void Visit(SetErrorLevelStatement node) => this.ReportUnsupportedSetStatement(node);
44+
45+
//public override void Visit(SetIdentityInsertStatement node) => this.VisitOtherSetStatement(node);
46+
47+
public override void Visit(SetOffsetsStatement node) => this.ReportUnsupportedSetStatement(node);
48+
49+
public override void Visit(SetRowCountStatement node) => this.ReportUnsupportedSetStatement(node);
50+
51+
public override void Visit(SetStatisticsStatement node) => this.ReportUnsupportedSetStatement(node);
52+
53+
public override void Visit(SetTextSizeStatement node) => this.ReportUnsupportedSetStatement(node);
54+
55+
//public override void Visit(SetTransactionIsolationLevelStatement node) => this.ReportUnsupportedSetStatement(node);
56+
57+
public override void Visit(SetUserStatement node) => this.ReportUnsupportedSetStatement(node);
58+
59+
public override void Visit(CreateProcedureStatement node)
60+
{
61+
this._procedureName = node.ProcedureReference.Name.BaseIdentifier.Value;
62+
63+
// Verify that SET XACT_ABORT ON is set when BEGIN TRANSACTION is used without a custom CATCH block that contains ROLLBACK TRANSACTION.
64+
// This ensures that the transaction is properly rolled back whenever an error occurs.
65+
XactAbortVisitor xactAbortVisitor = new XactAbortVisitor();
66+
node.Accept(xactAbortVisitor);
67+
if (xactAbortVisitor.HasXactAbortOn)
68+
return;
69+
70+
// Collect all BEGIN TRANSACTION statements (whether inside TRY..CATCH or not)
71+
ICollection<BeginTransactionStatement> allBeginTransactionStatements = new Collection<BeginTransactionStatement>();
72+
BeginTransactionStatementVisitor beginTransactionStatementVisitor = new BeginTransactionStatementVisitor();
73+
node.Accept(beginTransactionStatementVisitor);
74+
beginTransactionStatementVisitor.BeginTransactionStatements.Each(allBeginTransactionStatements.Add);
75+
76+
// Collect TRY..CATCH statements
77+
ICollection<TryCatchStatementDescriptor> tryCatchStatements = new Collection<TryCatchStatementDescriptor>();
78+
TryCatchStatementVisitor tryCatchStatementVisitor = new TryCatchStatementVisitor(parent: null, tryCatchStatements);
79+
node.Accept(tryCatchStatementVisitor);
80+
81+
// Walk each TRY..CATCH statement bottom up
82+
ICollection<TryCatchStatementDescriptor> nestedTryCatchStatements = tryCatchStatements.Where(x => x.IsLeaf).ToArray();
83+
ICollection<BeginTransactionStatementDescriptor> currentBeginTransactionStatements = new HashSet<BeginTransactionStatementDescriptor>();
84+
foreach (TryCatchStatementDescriptor tryCatchStatementDescriptor in nestedTryCatchStatements)
85+
{
86+
VisitTryCatch(allBeginTransactionStatements, currentBeginTransactionStatements, tryCatchStatementDescriptor);
87+
}
88+
89+
// Populate errors for remaining violations
90+
foreach (BeginTransactionStatement beginTransactionStatement in allBeginTransactionStatements)
91+
{
92+
base.Fail(beginTransactionStatement, "SET XACT_ABORT ON should be set when working with BEGIN TRANSACTION without a custom TRY..CATCH block to ensure the transaction is rolled back in case of an error");
93+
}
94+
}
95+
96+
private void Visit(TSqlFragment node, SetCommand command)
97+
{
98+
switch (command)
99+
{
100+
case GeneralSetCommand generalSetCommand:
101+
this.Visit(node, generalSetCommand);
102+
break;
103+
104+
default:
105+
this.ReportUnsupportedSetStatement(node);
106+
break;
107+
}
108+
}
109+
110+
private void Visit(TSqlFragment node, GeneralSetCommand generalSetCommand)
111+
{
112+
switch (generalSetCommand.CommandType)
113+
{
114+
// SET DEADLOCK_PRIORITY LOW
115+
case GeneralSetCommandType.DeadlockPriority:
116+
if (!(generalSetCommand.Parameter is IdentifierLiteral identifierLiteral))
117+
return; // ???
118+
119+
const string expectedDeadlockPriority = "LOW";
120+
if (!String.Equals(identifierLiteral.Value, expectedDeadlockPriority, StringComparison.OrdinalIgnoreCase))
121+
this.ReportUnsupportedSetOption(node, expectedDeadlockPriority);
122+
123+
break;
124+
125+
// SET CONTEXT_INFO
126+
case GeneralSetCommandType.ContextInfo:
127+
break;
128+
129+
// SET DATEFORMAT MDY
130+
case GeneralSetCommandType.DateFormat:
131+
string expression = CollectExpression(node);
132+
base.FailIfUnsuppressed(node, this._procedureName, BuildUnsupportedSetStatementMessage(expression));
133+
break;
134+
135+
default:
136+
this.ReportUnsupportedSetStatement(node);
137+
break;
138+
}
139+
}
140+
141+
private void ReportUnsupportedSetStatement(TSqlFragment node)
142+
{
143+
string expression = CollectExpression(node);
144+
this.ReportUnsupportedSetStatement(node, expression);
145+
}
146+
private void ReportUnsupportedSetStatement(TSqlFragment node, string expression) => base.Fail(node, BuildUnsupportedSetStatementMessage(expression));
147+
148+
private static string BuildUnsupportedSetStatementMessage(string expression) => $"Unsupported SET statement: {expression}";
149+
150+
private void ReportUnsupportedSetOption(TSqlFragment node, string expectedOption)
151+
{
152+
string expression = CollectExpression(node);
153+
this.ReportUnsupportedSetOption(node, expectedOption, expression);
154+
}
155+
private void ReportUnsupportedSetOption(TSqlFragment node, string expectedOption, string expression) => base.Fail(node, $"Only {expectedOption} is supported for SET statement: {expression}");
156+
157+
private static string CollectExpression(TSqlFragment node) => Regex.Replace(node.Dump(), @"\s{2,}", " ");
158+
159+
private static void VisitTryCatch(ICollection<BeginTransactionStatement> allBeginTransactionStatements, ICollection<BeginTransactionStatementDescriptor> currentBeginTransactionStatements, TryCatchStatementDescriptor tryCatchStatementDescriptor)
160+
{
161+
TryCatchStatementDescriptor current = tryCatchStatementDescriptor;
162+
do
163+
{
164+
// Collect BEGIN TRANSACTION within TRY
165+
BeginTransactionStatementVisitor beginTransactionStatementVisitor = new BeginTransactionStatementVisitor();
166+
current.Statement.TryStatements.Accept(beginTransactionStatementVisitor);
167+
beginTransactionStatementVisitor.BeginTransactionStatements.Each(x => currentBeginTransactionStatements.Add(new BeginTransactionStatementDescriptor(x)));
168+
169+
// Collect THROW statement within CATCH
170+
ThrowStatementVisitor throwStatementVisitor = new ThrowStatementVisitor();
171+
current.Statement.CatchStatements.Accept(throwStatementVisitor);
172+
173+
// Collect ROLLBACK statement within CATCH
174+
RollbackTransactionStatementVisitor rollbackTransactionStatementVisitor = new RollbackTransactionStatementVisitor();
175+
current.Statement.CatchStatements.Accept(rollbackTransactionStatementVisitor);
176+
177+
// If the CATCH contains ROLLBACK, all nested BEGIN TRANSACTION statements, can be treated as rolled back, if there was no CATCH block that ignored the error
178+
if (rollbackTransactionStatementVisitor.HasRollback)
179+
{
180+
// If there was a nested CATCH block without rethrow, the parent CATCH containing ROLLBACK won't be hit,
181+
// The affected BEGIN TRANSACTION statements in this case, can't be treated as rolled back and are marked as Keep = true.
182+
currentBeginTransactionStatements.Where(x => !x.Keep).Each(x => allBeginTransactionStatements.Remove(x.Statement));
183+
currentBeginTransactionStatements.Clear();
184+
}
185+
186+
// If the CATCH did not rethrow, the error is ignored and not bubbled up.
187+
// Therefore all BEGIN TRANSACTION statements in the TRY block won't be rolled back.
188+
if (!throwStatementVisitor.HasThrow)
189+
currentBeginTransactionStatements.Each(x => x.Keep = true);
190+
191+
current = current.Parent;
192+
} while (current != null);
193+
}
194+
195+
private sealed class BeginTransactionStatementVisitor : TSqlFragmentVisitor
196+
{
197+
public ICollection<BeginTransactionStatement> BeginTransactionStatements { get; } = new Collection<BeginTransactionStatement>();
198+
199+
public override void ExplicitVisit(BeginTransactionStatement node)
200+
{
201+
this.BeginTransactionStatements.Add(node);
202+
}
203+
}
204+
205+
private sealed class ThrowStatementVisitor : TSqlFragmentVisitor
206+
{
207+
public bool HasThrow { get; private set; }
208+
209+
public override void ExplicitVisit(ThrowStatement node)
210+
{
211+
this.HasThrow = true;
212+
}
213+
}
214+
215+
private sealed class RollbackTransactionStatementVisitor : TSqlFragmentVisitor
216+
{
217+
public bool HasRollback { get; private set; }
218+
219+
public override void ExplicitVisit(RollbackTransactionStatement node)
220+
{
221+
this.HasRollback = true;
222+
}
223+
}
224+
225+
private sealed class XactAbortVisitor : TSqlFragmentVisitor
226+
{
227+
public bool HasXactAbortOn { get; private set; }
228+
229+
public override void Visit(PredicateSetStatement node)
230+
{
231+
this.HasXactAbortOn = node.Options == SetOptions.XactAbort && node.IsOn;
232+
}
233+
}
234+
235+
private sealed class TryCatchStatementVisitor : TSqlFragmentVisitor
236+
{
237+
private readonly TryCatchStatementDescriptor _parent;
238+
private readonly ICollection<TryCatchStatementDescriptor> _target;
239+
240+
public bool HasMatch { get; private set; }
241+
242+
public TryCatchStatementVisitor(TryCatchStatementDescriptor parent, ICollection<TryCatchStatementDescriptor> target)
243+
{
244+
this._parent = parent;
245+
this._target = target;
246+
}
247+
248+
public override void ExplicitVisit(TryCatchStatement node)
249+
{
250+
this.HasMatch = true;
251+
252+
TryCatchStatementDescriptor descriptor = new TryCatchStatementDescriptor(node);
253+
descriptor.Parent = this._parent;
254+
this._target.Add(descriptor);
255+
256+
TryCatchStatementVisitor visitor = new TryCatchStatementVisitor(descriptor, this._target);
257+
node.TryStatements.Accept(visitor);
258+
node.CatchStatements.Accept(visitor);
259+
descriptor.IsLeaf = !visitor.HasMatch;
260+
}
261+
}
262+
263+
private sealed class TryCatchStatementDescriptor
264+
{
265+
public TryCatchStatement Statement { get; }
266+
public TryCatchStatementDescriptor Parent { get; set; }
267+
public bool IsLeaf { get; set; }
268+
269+
public TryCatchStatementDescriptor(TryCatchStatement statement)
270+
{
271+
this.Statement = statement;
272+
}
273+
}
274+
275+
private sealed class BeginTransactionStatementDescriptor
276+
{
277+
public BeginTransactionStatement Statement { get; }
278+
public bool Keep { get; set; }
279+
280+
public BeginTransactionStatementDescriptor(BeginTransactionStatement statement)
281+
{
282+
this.Statement = statement;
283+
}
284+
285+
public override bool Equals(object obj)
286+
{
287+
return ReferenceEquals(this, obj) || obj is BeginTransactionStatementDescriptor other && Equals(other);
288+
}
289+
290+
public override int GetHashCode()
291+
{
292+
return this.Statement.GetHashCode();
293+
}
294+
295+
private bool Equals(BeginTransactionStatementDescriptor other)
296+
{
297+
return this.Statement.Equals(other.Statement);
298+
}
299+
}
300+
}
301+
}

src/Dibix.Sdk/Environment/lockfile

80 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)