Skip to content

Commit 9b4355a

Browse files
committed
Merge branch 'strict-bind' of https://github.com/DapperLib/DapperAOT into strict-bind
# Conflicts: # test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.netfx.cs # test/Dapper.AOT.Test/Interceptors/QueryStrictBind.output.netfx.txt
2 parents 5147f54 + 840a4eb commit 9b4355a

17 files changed

+260
-819
lines changed

docs/rules/DAP049.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# DAP049
22

3-
When using `[StrictBind(...)]`, the elements should be the member names on the corresponding type. This error simply means that Dapper
3+
When using `[QueryColumns(...)]`, the elements should be the member names on the corresponding type. This error simply means that Dapper
44
could not find a member you specified. You can skip unwanted columns by passing `null` or `""`.

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs

+7-8
Original file line numberDiff line numberDiff line change
@@ -665,12 +665,11 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO
665665
}
666666
}
667667

668-
if (flags.HasAll(OperationFlags.BindResultsByName) && GetClosestDapperAttribute(ctx, op, Types.StrictBindAttribute) is not null)
668+
if (flags.HasFlag(OperationFlags.Query) && (IsEnabled(ctx, op, Types.StrictTypesAttribute, out _)))
669669
{
670-
flags |= OperationFlags.StrictBind;
670+
flags |= OperationFlags.StrictTypes;
671671
}
672672

673-
674673
if (exitFirstFailure && flags.HasAny(OperationFlags.DoNotGenerate))
675674
{
676675
resultType = null;
@@ -781,7 +780,7 @@ enum ParameterMode
781780
}
782781
}
783782

784-
ImmutableArray<string> strictBind = default;
783+
ImmutableArray<string> queryColumns = default;
785784
int? batchSize = null;
786785
foreach (var attrib in methodAttribs)
787786
{
@@ -820,8 +819,8 @@ enum ParameterMode
820819
batchSize = batchTmp;
821820
}
822821
break;
823-
case Types.StrictBindAttribute:
824-
strictBind = ParseStrictBindColumns(attrib);
822+
case Types.QueryColumnsAttribute:
823+
queryColumns = ParseQueryColumns(attrib);
825824
break;
826825
}
827826
}
@@ -851,8 +850,8 @@ enum ParameterMode
851850
}
852851

853852

854-
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null && strictBind.IsDefault
855-
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps, strictBind);
853+
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null && queryColumns.IsDefault
854+
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps, queryColumns);
856855
}
857856

858857
static void ValidateParameters(MemberMap? parameters, OperationFlags flags, Action<Diagnostic> onDiagnostic)

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Single.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ static void WriteSingleImplementation(
9292
break;
9393
}
9494
}
95-
sb.AppendReader(resultType, readers, additionalCommandState?.StrictBind ?? default);
95+
sb.AppendReader(resultType, readers, additionalCommandState?.QueryColumns ?? default);
9696
}
9797
else if (flags.HasAny(OperationFlags.Execute))
9898
{

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs

+23-22
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ internal void Generate(in GenerateState ctx)
391391

392392
if (flags.HasAny(OperationFlags.GetRowParser))
393393
{
394-
WriteGetRowParser(sb, resultType, readers, grp.Key.AdditionalCommandState?.StrictBind ?? default);
394+
WriteGetRowParser(sb, resultType, readers, grp.Key.AdditionalCommandState?.QueryColumns ?? default);
395395
}
396396
else if (!TryWriteMultiExecImplementation(sb, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, fixedSql, additionalCommandState))
397397
{
@@ -451,7 +451,7 @@ internal void Generate(in GenerateState ctx)
451451

452452
foreach (var tuple in readers)
453453
{
454-
WriteRowFactory(ctx, sb, tuple.Type, tuple.Index, tuple.StrictBind, null /* TODO */);
454+
WriteRowFactory(ctx, sb, tuple.Type, tuple.Index, tuple.QueryColumns, null /* TODO */);
455455
}
456456

457457
foreach (var tuple in factories)
@@ -468,9 +468,9 @@ internal void Generate(in GenerateState ctx)
468468
ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.InterceptorsGenerated, null, callSiteCount, ctx.Nodes.Length, methodIndex, factories.Count(), readers.Count()));
469469
}
470470

471-
private static void WriteGetRowParser(CodeWriter sb, ITypeSymbol? resultType, in RowReaderState readers, ImmutableArray<string> strictBind)
471+
private static void WriteGetRowParser(CodeWriter sb, ITypeSymbol? resultType, in RowReaderState readers, ImmutableArray<string> queryColumns)
472472
{
473-
sb.Append("return ").AppendReader(resultType, readers, strictBind)
473+
sb.Append("return ").AppendReader(resultType, readers, queryColumns)
474474
.Append(".GetRowParser(reader, startIndex, length, returnNullIfFirstMissing);").NewLine();
475475
}
476476

@@ -732,7 +732,7 @@ static bool IsReserved(string name)
732732
}
733733
}
734734

735-
private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITypeSymbol type, int index, ImmutableArray<string> strictBind, Location? location)
735+
private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITypeSymbol type, int index, ImmutableArray<string> queryColumns, Location? location)
736736
{
737737
var map = MemberMap.CreateForResults(type);
738738
if (map is null) return;
@@ -780,7 +780,7 @@ void WriteRowFactoryFooter()
780780
void WriteTokenizeMethod()
781781
{
782782
sb.Append("public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span<int> tokens, int columnOffset)").Indent().NewLine();
783-
if (strictBind.IsDefault) // don't emit any tokens for strict binding
783+
if (queryColumns.IsDefault) // don't emit any tokens for strict binding
784784
{
785785
sb.Append("for (int i = 0; i < tokens.Length; i++)").Indent().NewLine()
786786
.Append("int token = -1;").NewLine()
@@ -808,11 +808,11 @@ void WriteTokenizeMethod()
808808
}
809809
else
810810
{
811-
sb.Append("// strict-bind: ");
812-
for (int i = 0; i < strictBind.Length; i++)
811+
sb.Append("// query columns: ");
812+
for (int i = 0; i < queryColumns.Length; i++)
813813
{
814814
if (i != 0) sb.Append(", ");
815-
var name = strictBind[i];
815+
var name = queryColumns[i];
816816
if (string.IsNullOrWhiteSpace(name))
817817
{
818818
sb.Append("(n/a)");
@@ -826,7 +826,7 @@ void WriteTokenizeMethod()
826826
sb.Append("'").Append(name).Append("'");
827827
}
828828
}
829-
sb.NewLine().Append("global::System.Diagnostics.Debug.Assert(tokens.Length == ").Append(strictBind.Length).Append(""", "Strict-bind column count mismatch");""").NewLine();
829+
sb.NewLine().Append("global::System.Diagnostics.Debug.Assert(tokens.Length >= ").Append(queryColumns.Length).Append(""", "Query columns count mismatch");""").NewLine();
830830
}
831831
sb.Append("return null;").Outdent().NewLine();
832832
}
@@ -879,15 +879,16 @@ void WriteReadMethod(in GenerateState context)
879879
}
880880

881881
ImmutableArray<ElementMember> readMembers;
882-
if (strictBind.IsDefault)
882+
if (queryColumns.IsDefault)
883883
{
884884
readMembers = members; // try to parse everything
885885
sb.Append("foreach (var token in tokens)");
886886
}
887887
else
888888
{
889-
readMembers = MapStrictBind(context, members, strictBind, location);
890-
sb.Append("for (int token = 0; token < tokens.Length; token++) // strict-bind");
889+
readMembers = MapQueryColumns(context, members, queryColumns, location);
890+
sb.Append("int lim = global::System.Math.Min(tokens.Length, ").Append(queryColumns.Length).Append(");").NewLine()
891+
.Append("for (int token = 0; token < lim; token++) // query-columns predefined");
891892
}
892893
sb.Indent().NewLine().Append("switch (token)").Indent().NewLine();
893894

@@ -920,7 +921,7 @@ void WriteReadMethod(in GenerateState context)
920921

921922
sb.NewLine().Append("break;").NewLine().Outdent(false);
922923

923-
if (strictBind.IsDefault) // type-forgiving version; only emitted when not using strict-bind
924+
if (queryColumns.IsDefault) // type-forgiving version; only emitted when not using strict-bind
924925
{
925926
sb.Append("case ").Append(token + map.Members.Length).Append(":").NewLine().Indent(false);
926927

@@ -1020,27 +1021,27 @@ void WriteDeferredMethodArgs()
10201021
}
10211022
}
10221023

1023-
private static ImmutableArray<ElementMember> MapStrictBind(in GenerateState state, ImmutableArray<ElementMember> members, ImmutableArray<string> strictBind, Location? location)
1024+
private static ImmutableArray<ElementMember> MapQueryColumns(in GenerateState state, ImmutableArray<ElementMember> members, ImmutableArray<string> queryColumns, Location? location)
10241025
{
1025-
if (strictBind.IsDefault) return members; // not bound
1026+
if (queryColumns.IsDefault) return members; // not bound
10261027

1027-
var result = ImmutableArray.CreateBuilder<ElementMember>(strictBind.Length);
1028-
foreach (var seek in strictBind)
1028+
var result = ImmutableArray.CreateBuilder<ElementMember>(queryColumns.Length);
1029+
foreach (var seek in queryColumns)
10291030
{
10301031
ElementMember found = default;
10311032
if (!string.IsNullOrWhiteSpace(seek))
10321033
{
1033-
foreach (var member in members)
1034+
foreach (var member in members) // look for direct match
10341035
{
1035-
if (member.CodeName == seek)
1036+
if (string.Equals(member.CodeName, seek, StringComparison.InvariantCultureIgnoreCase))
10361037
{
10371038
found = member;
10381039
break;
10391040
}
10401041
}
1041-
if (found.Member is null)
1042+
if (found.Member is null) // additional underscore-etc deviation
10421043
{
1043-
var normalizedSeek = StringHashing.Normalize(seek);
1044+
var normalizedSeek = StringHashing.Normalize(seek); // note: should already *be* normalized, but: be sure
10441045
foreach (var member in members)
10451046
{
10461047
if (StringHashing.NormalizedEquals(member.CodeName, normalizedSeek))

src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ internal sealed class AdditionalCommandState : IEquatable<AdditionalCommandState
3838
public readonly int? BatchSize;
3939
public readonly string? RowCountHintMemberName;
4040
public readonly ImmutableArray<CommandProperty> CommandProperties;
41-
public readonly ImmutableArray<string> StrictBind;
41+
public readonly ImmutableArray<string> QueryColumns;
4242

4343
public bool HasRowCountHint => RowCountHint > 0 || RowCountHintMemberName is not null;
4444

@@ -76,7 +76,7 @@ private static AdditionalCommandState Combine(AdditionalCommandState inherited,
7676

7777
return new(count, countMember, inherited.BatchSize ?? overrides.BatchSize,
7878
Concat(inherited.CommandProperties, overrides.CommandProperties),
79-
overrides.StrictBind.IsDefault ? inherited.StrictBind : overrides.StrictBind);
79+
overrides.QueryColumns.IsDefault ? inherited.QueryColumns : overrides.QueryColumns);
8080
}
8181

8282
static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x, ImmutableArray<CommandProperty> y)
@@ -91,13 +91,13 @@ static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x,
9191

9292
internal AdditionalCommandState(
9393
int rowCountHint, string? rowCountHintMemberName, int? batchSize,
94-
ImmutableArray<CommandProperty> commandProperties, ImmutableArray<string> strictBind)
94+
ImmutableArray<CommandProperty> commandProperties, ImmutableArray<string> queryColumns)
9595
{
9696
RowCountHint = rowCountHint;
9797
RowCountHintMemberName = rowCountHintMemberName;
9898
BatchSize = batchSize;
9999
CommandProperties = commandProperties;
100-
StrictBind = strictBind;
100+
QueryColumns = queryColumns;
101101
}
102102

103103

@@ -110,7 +110,7 @@ public bool Equals(in AdditionalCommandState other)
110110
&& BatchSize == other.BatchSize
111111
&& RowCountHintMemberName == other.RowCountHintMemberName
112112
&& ((CommandProperties.IsDefaultOrEmpty && other.CommandProperties.IsDefaultOrEmpty) || Equals(CommandProperties, other.CommandProperties))
113-
&& StrictBind.Equals(other.StrictBind);
113+
&& QueryColumns.Equals(other.QueryColumns);
114114

115115
private static bool Equals(in ImmutableArray<CommandProperty> x, in ImmutableArray<CommandProperty> y)
116116
{
@@ -150,5 +150,5 @@ public override int GetHashCode()
150150
=> (RowCountHint + BatchSize.GetValueOrDefault()
151151
+ (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
152152
^ (CommandProperties.IsDefaultOrEmpty ? 0 : GetHashCode(in CommandProperties))
153-
^ StrictBind.GetHashCode();
153+
^ QueryColumns.GetHashCode();
154154
}

src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,15 @@ public string ToStringRecycle()
328328
return s;
329329
}
330330

331-
internal CodeWriter AppendReader(ITypeSymbol? resultType, RowReaderState readers, ImmutableArray<string> strictBind)
331+
internal CodeWriter AppendReader(ITypeSymbol? resultType, RowReaderState readers, ImmutableArray<string> queryColumns)
332332
{
333333
if (IsInbuilt(resultType, out var helper))
334334
{
335335
return Append("global::Dapper.RowFactory.Inbuilt.").Append(helper);
336336
}
337337
else
338338
{
339-
return Append("RowFactory").Append(readers.GetIndex(resultType!, strictBind)).Append(".Instance");
339+
return Append("RowFactory").Append(readers.GetIndex(resultType!, queryColumns)).Append(".Instance");
340340
}
341341

342342
static bool IsInbuilt(ITypeSymbol? type, out string? helper)

src/Dapper.AOT.Analyzers/Internal/Inspection.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ public static bool IsEnabled(in ParseState ctx, IOperation op, string attributeN
147147
return false;
148148
}
149149

150-
public static ImmutableArray<string> ParseStrictBindColumns(AttributeData attrib)
150+
public static ImmutableArray<string> ParseQueryColumns(AttributeData attrib)
151151
{
152152
ImmutableArray<string> result = default;
153153
if (attrib is not null && attrib.ConstructorArguments.Length == 1
@@ -170,7 +170,7 @@ public static ImmutableArray<string> ParseStrictBindColumns(AttributeData attrib
170170
arr[i] = "";
171171
break;
172172
case string s when s.IndexOf('\x03') < 0:
173-
arr[i] = s;
173+
arr[i] = string.IsNullOrWhiteSpace(s) ? "" : StringHashing.Normalize(s);
174174
break;
175175
default:
176176
fail = true;
@@ -1580,6 +1580,6 @@ enum OperationFlags
15801580
KnownParameters = 1 << 21,
15811581
QueryMultiple = 1 << 22,
15821582
GetRowParser = 1 << 23,
1583-
StrictBind = 1 << 24,
1583+
StrictTypes = 1 << 24,
15841584
NotAotSupported = 1 << 31,
15851585
}

src/Dapper.AOT.Analyzers/Internal/RowReaderState.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@
66

77
namespace Dapper.Internal;
88

9-
internal readonly struct RowReaderState : IEnumerable<(ITypeSymbol Type, ImmutableArray<string> StrictBind, int Index)>
9+
internal readonly struct RowReaderState : IEnumerable<(ITypeSymbol Type, ImmutableArray<string> QueryColumns, int Index)>
1010
{
1111
public RowReaderState() { }
12-
private readonly Dictionary<(ITypeSymbol Type, ImmutableArray<string> StrictBind), int> resultTypes = new ();
12+
private readonly Dictionary<(ITypeSymbol Type, ImmutableArray<string> QueryColumns), int> resultTypes = new ();
1313

1414
public int Count() => resultTypes.Count();
1515

16-
public IEnumerator<(ITypeSymbol Type, ImmutableArray<string> StrictBind, int Index)> GetEnumerator()
16+
public IEnumerator<(ITypeSymbol Type, ImmutableArray<string> QueryColumns, int Index)> GetEnumerator()
1717
{
1818
// retain discovery order
19-
return resultTypes.OrderBy(x => x.Value).Select(x => (x.Key.Type, x.Key.StrictBind, x.Value)).GetEnumerator();
19+
return resultTypes.OrderBy(x => x.Value).Select(x => (x.Key.Type, x.Key.QueryColumns, x.Value)).GetEnumerator();
2020
}
2121

2222
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
2323

24-
public int GetIndex(ITypeSymbol type, ImmutableArray<string> strictBind)
24+
public int GetIndex(ITypeSymbol type, ImmutableArray<string> queryColumns)
2525
{
26-
if (!resultTypes.TryGetValue((type, strictBind), out var index))
26+
if (!resultTypes.TryGetValue((type, queryColumns), out var index))
2727
{
28-
resultTypes.Add((type, strictBind), index = resultTypes.Count);
28+
resultTypes.Add((type, queryColumns), index = resultTypes.Count);
2929
}
3030
return index;
3131
}

src/Dapper.AOT.Analyzers/Internal/Types.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ public const string
1515
IDynamicParameters = nameof(IDynamicParameters),
1616
IncludeLocationAttribute = nameof(IncludeLocationAttribute),
1717
SqlSyntaxAttribute = nameof(SqlSyntaxAttribute),
18-
StrictBindAttribute = nameof(StrictBindAttribute),
18+
StrictTypesAttribute = nameof(StrictTypesAttribute),
19+
QueryColumnsAttribute = nameof(QueryColumnsAttribute),
1920
RowCountAttribute = nameof(RowCountAttribute),
2021
RowCountHintAttribute = nameof(RowCountHintAttribute),
2122
SqlAttribute = nameof(SqlAttribute),
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Immutable;
3+
using System.ComponentModel;
4+
using System.Diagnostics;
5+
6+
namespace Dapper;
7+
8+
/// <summary>
9+
/// Specifies the ordered columns returned by a query.
10+
/// </summary>
11+
/// <param name="columns"></param>
12+
[AttributeUsage(AttributeTargets.Method)]
13+
[ImmutableObject(true), Conditional("DEBUG")]
14+
public sealed class QueryColumnsAttribute(params string[] columns) : Attribute
15+
{
16+
/// <summary>
17+
/// The ordered columns returned by a query.
18+
/// </summary>
19+
public ImmutableArray<string> Columns { get; } = columns.ToImmutableArray();
20+
}

src/Dapper.AOT/StrictBindAttribute.cs

-22
This file was deleted.

0 commit comments

Comments
 (0)