Skip to content

Commit c3f0325

Browse files
authored
Merge pull request #1341 from microsoft/fix614
Improve `[Out] PWSTR` parameters in friendly overloads
2 parents 22d9475 + 8a0399e commit c3f0325

File tree

4 files changed

+123
-66
lines changed

4 files changed

+123
-66
lines changed

src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs

+95-66
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,26 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
107107

108108
IdentifierNameSyntax origName = IdentifierName(externParam.Identifier.ValueText);
109109

110+
bool isArray = false;
111+
bool isNullTerminated = false; // TODO
112+
short? countParamIndex = null;
113+
int? countConst = null;
114+
if (this.FindInteropDecorativeAttribute(paramAttributes, NativeArrayInfoAttribute) is CustomAttribute nativeArrayInfoAttribute)
115+
{
116+
isArray = true;
117+
NativeArrayInfo nativeArrayInfo = DecodeNativeArrayInfoAttribute(nativeArrayInfoAttribute);
118+
countParamIndex = nativeArrayInfo.CountParamIndex;
119+
countConst = nativeArrayInfo.CountConst;
120+
}
121+
else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(paramAttributes, MemorySizeAttribute) is CustomAttribute memorySizeAttribute)
122+
{
123+
// A very special case as documented in https://github.com/microsoft/win32metadata/issues/1555
124+
// where MemorySizeAttribute is applied to byte* parameters to indicate the size of the buffer.
125+
isArray = true;
126+
MemorySize memorySize = DecodeMemorySizeAttribute(memorySizeAttribute);
127+
countParamIndex = memorySize.BytesParamIndex;
128+
}
129+
110130
if (mustRemainAsPointer)
111131
{
112132
// This block intentionally left blank, so as to disable further processing that might try to
@@ -243,80 +263,19 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
243263
}
244264
}
245265

246-
bool isArray = false;
247-
bool isNullTerminated = false; // TODO
248-
short? sizeParamIndex = null;
249-
int? sizeConst = null;
250-
if (this.FindInteropDecorativeAttribute(paramAttributes, NativeArrayInfoAttribute) is CustomAttribute att)
251-
{
252-
isArray = true;
253-
NativeArrayInfo nativeArrayInfo = DecodeNativeArrayInfoAttribute(att);
254-
sizeParamIndex = nativeArrayInfo.CountParamIndex;
255-
sizeConst = nativeArrayInfo.CountConst;
256-
}
257-
else if (externParam.Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ByteKeyword } } && this.FindInteropDecorativeAttribute(paramAttributes, MemorySizeAttribute) is CustomAttribute att2)
258-
{
259-
// A very special case as documented in https://github.com/microsoft/win32metadata/issues/1555
260-
// where MemorySizeAttribute is applied to byte* parameters to indicate the size of the buffer.
261-
isArray = true;
262-
MemorySize memorySize = DecodeMemorySizeAttribute(att2);
263-
sizeParamIndex = memorySize.BytesParamIndex;
264-
}
265-
266266
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
267267
if (isArray)
268268
{
269269
// TODO: add support for in/out size parameters. (e.g. RSGetViewports)
270270
// TODO: add support for lists of pointers via a generated pointer-wrapping struct (e.g. PSSetSamplers)
271271

272-
// It is possible that sizeParamIndex points to a parameter that is not on the extern method
272+
// It is possible that countParamIndex points to a parameter that is not on the extern method
273273
// when the parameter is the last one and was moved to a return value.
274-
if (sizeParamIndex.HasValue
275-
&& this.canUseSpan
276-
&& externMethodDeclaration.ParameterList.Parameters.Count > sizeParamIndex.Value
277-
&& !(externMethodDeclaration.ParameterList.Parameters[sizeParamIndex.Value].Type is PointerTypeSyntax)
278-
&& !(externMethodDeclaration.ParameterList.Parameters[sizeParamIndex.Value].Modifiers.Any(SyntaxKind.OutKeyword) || externMethodDeclaration.ParameterList.Parameters[sizeParamIndex.Value].Modifiers.Any(SyntaxKind.RefKeyword))
279-
&& !isPointerToPointer)
274+
if (!isPointerToPointer && TryHandleCountParam(elementType, nullableSource: true))
280275
{
281-
signatureChanged = true;
282-
bool remainsRefType = true;
283-
if (externParam.Type is PointerTypeSyntax)
284-
{
285-
remainsRefType = false;
286-
parameters[param.SequenceNumber - 1] = parameters[param.SequenceNumber - 1]
287-
.WithType((isIn ? MakeReadOnlySpanOfT(elementType) : MakeSpanOfT(elementType)).WithTrailingTrivia(TriviaList(Space)));
288-
fixedBlocks.Add(VariableDeclaration(externParam.Type).AddVariables(
289-
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
290-
arguments[param.SequenceNumber - 1] = Argument(localName);
291-
}
292-
293-
if (lengthParamUsedBy.TryGetValue(sizeParamIndex.Value, out int userIndex))
294-
{
295-
// Multiple array parameters share a common 'length' parameter.
296-
// Since we're making this a little less obvious, add a quick if check in the helper method
297-
// that enforces that all such parameters have a common span length.
298-
ExpressionSyntax otherUserName = IdentifierName(parameters[userIndex].Identifier.ValueText);
299-
leadingStatements.Add(IfStatement(
300-
BinaryExpression(
301-
SyntaxKind.NotEqualsExpression,
302-
GetSpanLength(otherUserName, parameters[userIndex].Type is ArrayTypeSyntax),
303-
GetSpanLength(origName, remainsRefType)),
304-
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentException))).WithArgumentList(ArgumentList()))));
305-
}
306-
else
307-
{
308-
lengthParamUsedBy.Add(sizeParamIndex.Value, param.SequenceNumber - 1);
309-
}
310-
311-
ExpressionSyntax sizeArgExpression = GetSpanLength(origName, remainsRefType);
312-
if (!(parameters[sizeParamIndex.Value].Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.IntKeyword } }))
313-
{
314-
sizeArgExpression = CastExpression(parameters[sizeParamIndex.Value].Type!, sizeArgExpression);
315-
}
316-
317-
arguments[sizeParamIndex.Value] = Argument(sizeArgExpression);
276+
// This block intentionally left blank.
318277
}
319-
else if (sizeConst.HasValue && !isPointerToPointer && this.canUseSpan && externParam.Type is PointerTypeSyntax)
278+
else if (countConst.HasValue && !isPointerToPointer && this.canUseSpan && externParam.Type is PointerTypeSyntax)
320279
{
321280
// TODO: add support for lists of pointers via a generated pointer-wrapping struct
322281
signatureChanged = true;
@@ -333,7 +292,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
333292
BinaryExpression(
334293
SyntaxKind.LessThanExpression,
335294
GetSpanLength(origName, false /* we've converted it to be a span */),
336-
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(sizeConst.Value))),
295+
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(countConst.Value))),
337296
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentException))).WithArgumentList(ArgumentList()))));
338297
}
339298
else if (isNullTerminated && isConst && parameters[param.SequenceNumber - 1].Type is PointerTypeSyntax { ElementType: PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.CharKeyword } } })
@@ -475,6 +434,24 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
475434
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
476435
Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, localWstrName, IdentifierName("Length"))))))));
477436
}
437+
else if (!isIn && isOut && this.canUseSpan && externParam.Type is QualifiedNameSyntax { Right: { Identifier: { ValueText: "PWSTR" } } })
438+
{
439+
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
440+
signatureChanged = true;
441+
parameters[param.SequenceNumber - 1] = externParam
442+
.WithType(MakeSpanOfT(PredefinedType(Token(SyntaxKind.CharKeyword))));
443+
444+
// fixed (char* pParam1 = Param1)
445+
fixedBlocks.Add(VariableDeclaration(PointerType(PredefinedType(Token(SyntaxKind.CharKeyword)))).AddVariables(
446+
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(
447+
origName))));
448+
449+
// Use the char* pointer as the argument instead of the parameter.
450+
arguments[param.SequenceNumber - 1] = Argument(localName);
451+
452+
// Remove the size parameter if one exists.
453+
TryHandleCountParam(PredefinedType(Token(SyntaxKind.CharKeyword)), nullableSource: false);
454+
}
478455
else if (isIn && isOptional && !isOut && isManagedParameterType && parameterTypeInfo is PointerTypeHandleInfo ptrInfo && ptrInfo.ElementType.IsValueType(parameterTypeSyntaxSettings) is true && this.canUseUnsafeAsRef)
479456
{
480457
// The extern method couldn't have exposed the parameter as a pointer because the type is managed.
@@ -504,6 +481,58 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
504481
localName,
505482
nullRef));
506483
}
484+
485+
bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
486+
{
487+
IdentifierNameSyntax localName = IdentifierName(origName + "Local");
488+
if (countParamIndex.HasValue
489+
&& this.canUseSpan
490+
&& externMethodDeclaration.ParameterList.Parameters.Count > countParamIndex.Value
491+
&& !(externMethodDeclaration.ParameterList.Parameters[countParamIndex.Value].Type is PointerTypeSyntax)
492+
&& !(externMethodDeclaration.ParameterList.Parameters[countParamIndex.Value].Modifiers.Any(SyntaxKind.OutKeyword) || externMethodDeclaration.ParameterList.Parameters[countParamIndex.Value].Modifiers.Any(SyntaxKind.RefKeyword)))
493+
{
494+
signatureChanged = true;
495+
bool remainsRefType = nullableSource;
496+
if (externParam.Type is PointerTypeSyntax)
497+
{
498+
remainsRefType = false;
499+
parameters[param.SequenceNumber - 1] = parameters[param.SequenceNumber - 1]
500+
.WithType((isIn ? MakeReadOnlySpanOfT(elementType) : MakeSpanOfT(elementType)).WithTrailingTrivia(TriviaList(Space)));
501+
fixedBlocks.Add(VariableDeclaration(externParam.Type).AddVariables(
502+
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
503+
arguments[param.SequenceNumber - 1] = Argument(localName);
504+
}
505+
506+
if (lengthParamUsedBy.TryGetValue(countParamIndex.Value, out int userIndex))
507+
{
508+
// Multiple array parameters share a common 'length' parameter.
509+
// Since we're making this a little less obvious, add a quick if check in the helper method
510+
// that enforces that all such parameters have a common span length.
511+
ExpressionSyntax otherUserName = IdentifierName(parameters[userIndex].Identifier.ValueText);
512+
leadingStatements.Add(IfStatement(
513+
BinaryExpression(
514+
SyntaxKind.NotEqualsExpression,
515+
GetSpanLength(otherUserName, parameters[userIndex].Type is ArrayTypeSyntax),
516+
GetSpanLength(origName, remainsRefType)),
517+
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentException))).WithArgumentList(ArgumentList()))));
518+
}
519+
else
520+
{
521+
lengthParamUsedBy.Add(countParamIndex.Value, param.SequenceNumber - 1);
522+
}
523+
524+
ExpressionSyntax sizeArgExpression = GetSpanLength(origName, remainsRefType);
525+
if (!(parameters[countParamIndex.Value].Type is PredefinedTypeSyntax { Keyword: { RawKind: (int)SyntaxKind.IntKeyword } }))
526+
{
527+
sizeArgExpression = CastExpression(parameters[countParamIndex.Value].Type!, sizeArgExpression);
528+
}
529+
530+
arguments[countParamIndex.Value] = Argument(sizeArgExpression);
531+
return true;
532+
}
533+
534+
return false;
535+
}
507536
}
508537

509538
TypeSyntax? returnSafeHandleType = originalSignature.ReturnType is HandleTypeHandleInfo returnTypeHandleInfo

test/GenerationSandbox.Tests/BasicTests.cs

+11
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,17 @@ public void HANDLE_OverridesEqualityOperator()
203203
Assert.False(handle5 == handle8);
204204
}
205205

206+
[Fact]
207+
public void GetWindowText_FriendlyOverload()
208+
{
209+
HWND hwnd = PInvoke.GetForegroundWindow();
210+
Span<char> text = stackalloc char[100];
211+
int len = PInvoke.GetWindowText(hwnd, text);
212+
Assert.NotEqual(0, len);
213+
string title = text.Slice(0, len).ToString();
214+
this.logger.WriteLine(title);
215+
}
216+
206217
[Fact]
207218
public void CreateFile()
208219
{

test/GenerationSandbox.Tests/NativeMethods.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ DISPLAYCONFIG_VIDEO_SIGNAL_INFO
1010
EnumWindows
1111
FILE_ACCESS_RIGHTS
1212
FLICK_DATA
13+
GetForegroundWindow
1314
GetProcAddress
1415
GetTickCount
1516
GetWindowText

test/Microsoft.Windows.CsWin32.Tests/FriendlyOverloadTests.cs

+16
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ public void InAttributeOnArraysProjectedAsReadOnlySpan()
5757
Assert.Equal(3, method.ParameterList.Parameters.Count(p => p.Type is GenericNameSyntax { Identifier.ValueText: "ReadOnlySpan" }));
5858
}
5959

60+
[Fact]
61+
public void OutPWSTR_Parameters_AsSpan()
62+
{
63+
const string name = "GetWindowText";
64+
this.Generate(name);
65+
MethodDeclarationSyntax friendlyOverload = Assert.Single(this.FindGeneratedMethod(name), m => m.ParameterList.Parameters.Count == 2);
66+
Assert.Equal("Span<char>", friendlyOverload.ParameterList.Parameters[1].Type?.ToString());
67+
}
68+
69+
[Theory]
70+
[InlineData("WSManGetSessionOptionAsString")] // Uses the reserved keyword 'string' as a parameter name
71+
public void InterestingAPIs(string name)
72+
{
73+
this.Generate(name);
74+
}
75+
6076
private void Generate(string name)
6177
{
6278
this.compilation = this.compilation.WithOptions(this.compilation.Options.WithPlatform(Platform.X64));

0 commit comments

Comments
 (0)