Skip to content

Commit 82ec002

Browse files
authored
Merge pull request #1346 from microsoft/fix322
Friendly overloads replace `PCWSTR*` parameters with `ReadOnlySpan<string>`
2 parents cb9ed22 + 6970df6 commit 82ec002

File tree

6 files changed

+159
-3
lines changed

6 files changed

+159
-3
lines changed

src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla
8282
return SyntaxFactory.ForStatement(Token(SyntaxKind.ForKeyword), Token(SyntaxKind.OpenParenToken), declaration!, default, semicolonToken, condition, semicolonToken, incrementors, Token(SyntaxKind.CloseParenToken), statement);
8383
}
8484

85+
internal static ForEachStatementSyntax ForEachStatement(TypeSyntax type, SyntaxToken identifier, ExpressionSyntax expression, StatementSyntax statement) => SyntaxFactory.ForEachStatement(type, identifier, expression, statement);
86+
8587
internal static StatementSyntax EmptyStatement() => SyntaxFactory.EmptyStatement(Token(SyntaxKind.SemicolonToken));
8688

8789
internal static NamespaceDeclarationSyntax NamespaceDeclaration(NameSyntax name) => SyntaxFactory.NamespaceDeclaration(Token(TriviaList(), SyntaxKind.NamespaceKeyword, TriviaList(Space)), name.WithTrailingTrivia(LineFeed), OpenBrace, default, default, default, CloseBrace, default);

src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs

+135-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ namespace Microsoft.Windows.CsWin32;
55

66
public partial class Generator
77
{
8+
private static readonly TypeSyntax PCWSTRTypeSyntax = QualifiedName(QualifiedName(IdentifierName(GlobalWinmdRootNamespaceAlias), IdentifierName("Foundation")), IdentifierName("PCWSTR"));
9+
810
private enum FriendlyOverloadOf
911
{
1012
ExternMethod,
@@ -268,9 +270,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
268270
{
269271
// TODO: add support for in/out size parameters. (e.g. RSGetViewports)
270272
// TODO: add support for lists of pointers via a generated pointer-wrapping struct (e.g. PSSetSamplers)
271-
272-
// It is possible that countParamIndex points to a parameter that is not on the extern method
273-
// when the parameter is the last one and was moved to a return value.
274273
if (!isPointerToPointer && TryHandleCountParam(elementType, nullableSource: true))
275274
{
276275
// This block intentionally left blank.
@@ -305,6 +304,136 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
305304
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
306305
arguments[param.SequenceNumber - 1] = Argument(localName);
307306
}
307+
308+
// Translate ReadOnlySpan<PCWSTR> to ReadOnlySpan<string>
309+
if (isIn && !isOut && isConst && externParam.Type is PointerTypeSyntax { ElementType: QualifiedNameSyntax { Right: { Identifier: { ValueText: "PCWSTR" } } } })
310+
{
311+
signatureChanged = true;
312+
313+
// Change the parameter type to ReadOnlySpan<string>
314+
parameters[param.SequenceNumber - 1] = externParam
315+
.WithType(MakeReadOnlySpanOfT(PredefinedType(Token(SyntaxKind.StringKeyword))));
316+
317+
IdentifierNameSyntax gcHandlesLocal = IdentifierName($"{origName}GCHandles");
318+
IdentifierNameSyntax pcwstrLocal = IdentifierName($"{origName}Pointers");
319+
320+
// var paramNameGCHandles = ArrayPool<GCHandle>.Shared.Rent(paramName.Length);
321+
var gcHandlesArrayDecl = LocalDeclarationStatement(VariableDeclaration(
322+
ArrayType(IdentifierName("var"))).AddVariables(
323+
VariableDeclarator(gcHandlesLocal.Identifier).WithInitializer(EqualsValueClause(
324+
InvocationExpression(
325+
MemberAccessExpression(
326+
SyntaxKind.SimpleMemberAccessExpression,
327+
MemberAccessExpression(
328+
SyntaxKind.SimpleMemberAccessExpression,
329+
ParseTypeName("global::System.Buffers.ArrayPool<global::System.Runtime.InteropServices.GCHandle>"),
330+
IdentifierName("Shared")),
331+
IdentifierName("Rent")))
332+
.WithArgumentList(ArgumentList().AddArguments(Argument(GetSpanLength(origName, false))))))));
333+
334+
// var paramNamePointers = ArrayPool<PCWSTR>.Shared.Rent(paramName.Length);
335+
var strsArrayDecl = LocalDeclarationStatement(VariableDeclaration(
336+
ArrayType(IdentifierName("var"))).AddVariables(
337+
VariableDeclarator(pcwstrLocal.Identifier).WithInitializer(EqualsValueClause(
338+
InvocationExpression(
339+
MemberAccessExpression(
340+
SyntaxKind.SimpleMemberAccessExpression,
341+
MemberAccessExpression(
342+
SyntaxKind.SimpleMemberAccessExpression,
343+
ParseTypeName($"global::System.Buffers.ArrayPool<{PCWSTRTypeSyntax.ToString()}>"),
344+
IdentifierName("Shared")),
345+
IdentifierName("Rent")))
346+
.WithArgumentList(ArgumentList().AddArguments(Argument(GetSpanLength(origName, false))))))));
347+
348+
// for (int i = 0; i < paramName.Length; i++)
349+
// {
350+
// paramNameGCHandles[i] = GCHandle.Alloc(paramName[i], GCHandleType.Pinned);
351+
// paramNamePointers[i] = (char*)paramNameGCHandles[i].AddrOfPinnedObject();
352+
// }
353+
IdentifierNameSyntax loopVariable = IdentifierName("i");
354+
var forLoop = ForStatement(
355+
VariableDeclaration(PredefinedType(Token(SyntaxKind.IntKeyword))).AddVariables(
356+
VariableDeclarator(loopVariable.Identifier).WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))))),
357+
BinaryExpression(SyntaxKind.LessThanExpression, loopVariable, GetSpanLength(origName, false)),
358+
SingletonSeparatedList<ExpressionSyntax>(PostfixUnaryExpression(SyntaxKind.PostIncrementExpression, loopVariable)),
359+
Block().AddStatements(
360+
ExpressionStatement(AssignmentExpression(
361+
SyntaxKind.SimpleAssignmentExpression,
362+
ElementAccessExpression(gcHandlesLocal).AddArgumentListArguments(Argument(loopVariable)),
363+
InvocationExpression(
364+
MemberAccessExpression(
365+
SyntaxKind.SimpleMemberAccessExpression,
366+
ParseTypeName("global::System.Runtime.InteropServices.GCHandle"),
367+
IdentifierName("Alloc")))
368+
.WithArgumentList(ArgumentList().AddArguments(
369+
Argument(ElementAccessExpression(origName).AddArgumentListArguments(Argument(loopVariable))),
370+
Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseTypeName("global::System.Runtime.InteropServices.GCHandleType"), IdentifierName("Pinned"))))))),
371+
ExpressionStatement(AssignmentExpression(
372+
SyntaxKind.SimpleAssignmentExpression,
373+
ElementAccessExpression(pcwstrLocal).AddArgumentListArguments(Argument(loopVariable)),
374+
CastExpression(
375+
PointerType(PredefinedType(Token(SyntaxKind.CharKeyword))),
376+
InvocationExpression(
377+
MemberAccessExpression(
378+
SyntaxKind.SimpleMemberAccessExpression,
379+
ElementAccessExpression(gcHandlesLocal).AddArgumentListArguments(Argument(loopVariable)),
380+
IdentifierName("AddrOfPinnedObject"))).WithArgumentList(ArgumentList()))))));
381+
382+
leadingOutsideTryStatements.AddRange([gcHandlesArrayDecl, strsArrayDecl, forLoop]);
383+
384+
// foreach (var gcHandle in paramNameGCHandles) gcHandle.Free();
385+
var freeHandleStatement = ForEachStatement(
386+
IdentifierName("var").WithTrailingTrivia(Space),
387+
Identifier("gcHandle").WithTrailingTrivia(Space),
388+
gcHandlesLocal.WithLeadingTrivia(Space),
389+
ExpressionStatement(
390+
InvocationExpression(
391+
MemberAccessExpression(
392+
SyntaxKind.SimpleMemberAccessExpression,
393+
IdentifierName("gcHandle"),
394+
IdentifierName("Free")))).WithLeadingTrivia(LineFeed));
395+
396+
// ArrayPool<GCHandle>.Shared.Return(gcHandlesArray);
397+
var returnGCHandlesArray = ExpressionStatement(
398+
InvocationExpression(
399+
MemberAccessExpression(
400+
SyntaxKind.SimpleMemberAccessExpression,
401+
ParseTypeName("global::System.Buffers.ArrayPool<global::System.Runtime.InteropServices.GCHandle>"),
402+
IdentifierName("Shared.Return")))
403+
.WithArgumentList(ArgumentList().AddArguments(Argument(gcHandlesLocal))));
404+
405+
// ArrayPool<PCWSTR>.Shared.Return(paramNamePointers);
406+
var returnStrsArray = ExpressionStatement(
407+
InvocationExpression(
408+
MemberAccessExpression(
409+
SyntaxKind.SimpleMemberAccessExpression,
410+
ParseTypeName($"global::System.Buffers.ArrayPool<{PCWSTRTypeSyntax.ToString()}> "),
411+
IdentifierName("Shared.Return")))
412+
.WithArgumentList(ArgumentList().AddArguments(Argument(pcwstrLocal))));
413+
414+
finallyStatements.AddRange([freeHandleStatement, returnGCHandlesArray, returnStrsArray]);
415+
416+
// Update fixed blocks already created to consume our array of pinned pointers
417+
bool found = false;
418+
for (int i = 0; i < fixedBlocks.Count; i++)
419+
{
420+
if (fixedBlocks[i] is VariableDeclarationSyntax { Variables: [VariableDeclaratorSyntax { Initializer: { Value: IdentifierNameSyntax { Identifier: SyntaxToken id } } initializer } variable] } declaration
421+
&& id.ValueText == externParam.Identifier.ValueText)
422+
{
423+
// fixed (PCWSTR* paramNamePointersPtr = strsArray)
424+
fixedBlocks[i] = declaration.WithVariables(SingletonSeparatedList(variable.WithInitializer(initializer.WithValue(pcwstrLocal))));
425+
found = true;
426+
break;
427+
}
428+
}
429+
430+
if (!found)
431+
{
432+
throw new GenerationFailedException("Unable to find existing fixed block to change.");
433+
}
434+
435+
arguments[param.SequenceNumber - 1] = Argument(localName);
436+
}
308437
}
309438
else if (isIn && isOptional && !isOut && !isPointerToPointer)
310439
{
@@ -485,6 +614,9 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
485614
bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
486615
{
487616
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
617+
618+
// It is possible that countParamIndex points to a parameter that is not on the extern method
619+
// when the parameter is the last one and was moved to a return value.
488620
if (countParamIndex.HasValue
489621
&& this.canUseSpan
490622
&& externMethodDeclaration.ParameterList.Parameters.Count > countParamIndex.Value

src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs

+14
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,20 @@ internal WhitespaceRewriter()
282282
}
283283
}
284284

285+
public override SyntaxNode? VisitForEachStatement(ForEachStatementSyntax node)
286+
{
287+
node = this.WithIndentingTrivia(node);
288+
if (node.Statement is BlockSyntax)
289+
{
290+
return base.VisitForEachStatement(node);
291+
}
292+
else
293+
{
294+
using var indent = new Indent(this);
295+
return base.VisitForEachStatement(node);
296+
}
297+
}
298+
285299
public override SyntaxNode? VisitReturnStatement(ReturnStatementSyntax node)
286300
{
287301
return base.VisitReturnStatement(this.WithIndentingTrivia(node));

test/GenerationSandbox.Tests/GeneratedForm.cs

+6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using Windows.Win32.Networking.ActiveDirectory;
99
using Windows.Win32.System.Com;
1010
using Windows.Win32.System.Diagnostics.Debug;
11+
using Windows.Win32.System.RestartManager;
1112
using Windows.Win32.System.Threading;
1213

1314
#pragma warning disable CA1812 // dead code
@@ -81,6 +82,11 @@ private static void WriteFile()
8182
PInvoke.WriteFile((SafeHandle?)null, new byte[2], &written, (NativeOverlapped*)null);
8283
}
8384

85+
private static void RmRegisterResources()
86+
{
87+
PInvoke.RmRegisterResources(0, ["a", "b"], [default(RM_UNIQUE_PROCESS)], ["a", "b"]);
88+
}
89+
8490
private class MyStream : IStream
8591
{
8692
public HRESULT Read(void* pv, uint cb, uint* pcbRead)

test/GenerationSandbox.Tests/NativeMethods.txt

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ GetProcAddress
1515
GetTickCount
1616
GetWindowText
1717
GetWindowTextLength
18+
RmRegisterResources
1819
HDC_UserSize
1920
HRESULT_FROM_WIN32
2021
IDirectorySearch

test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs

+1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ public void OutPWSTR_Parameters_AsSpan()
6868

6969
[Theory]
7070
[InlineData("WSManGetSessionOptionAsString")] // Uses the reserved keyword 'string' as a parameter name
71+
[InlineData("RmRegisterResources")] // Parameter with PCWSTR* (an array of native strings)
7172
public void InterestingAPIs(string name)
7273
{
7374
this.Generate(name);

0 commit comments

Comments
 (0)