Skip to content

Commit 73ab302

Browse files
authored
Merge pull request #1335 from microsoft/fix397
Generate `WinRTCustomMarshaler` when referenced from extern methods
2 parents 1beb6ff + d3ce391 commit 73ab302

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

src/Microsoft.Windows.CsWin32/Generator.GeneratedCode.cs

+16
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ internal GeneratedCode(GeneratedCode parent)
6666
internal bool IsEmpty => this.modulesAndMembers.Count == 0 && this.types.Count == 0 && this.fieldsToSyntax.Count == 0 && this.safeHandleTypes.Count == 0 && this.specialTypes.Count == 0
6767
&& this.inlineArrayIndexerExtensionsMembers.Count == 0 && this.comInterfaceFriendlyExtensionsMembers.Count == 0 && this.macros.Count == 0 && this.inlineArrays.Count == 0;
6868

69+
internal bool NeedsWinRTCustomMarshaler { get; private set; }
70+
6971
internal IEnumerable<MemberDeclarationSyntax> GeneratedTypes => this.GetTypesWithInjectedFields()
7072
.Concat(this.specialTypes.Values.Where(st => !st.TopLevel).Select(st => st.Type))
7173
.Concat(this.safeHandleTypes)
@@ -111,6 +113,7 @@ internal void AddMemberToModule(string moduleName, MemberDeclarationSyntax membe
111113
}
112114

113115
methodsList.Add(member);
116+
this.NeedsWinRTCustomMarshaler |= RequiresWinRTCustomMarshaler(member);
114117
}
115118

116119
internal void AddMemberToModule(string moduleName, IEnumerable<MemberDeclarationSyntax> members)
@@ -123,6 +126,7 @@ internal void AddMemberToModule(string moduleName, IEnumerable<MemberDeclaration
123126
}
124127

125128
methodsList.AddRange(members);
129+
this.NeedsWinRTCustomMarshaler |= members.Any(m => RequiresWinRTCustomMarshaler(m));
126130
}
127131

128132
internal void AddConstant(FieldDefinitionHandle fieldDefHandle, FieldDeclarationSyntax constantDeclaration, TypeDefinitionHandle? fieldType)
@@ -183,6 +187,7 @@ internal void AddInteropType(TypeDefinitionHandle typeDefinitionHandle, bool has
183187
{
184188
this.ThrowIfNotGenerating();
185189
this.types.Add((typeDefinitionHandle, hasUnmanagedName), typeDeclaration);
190+
this.NeedsWinRTCustomMarshaler |= RequiresWinRTCustomMarshaler(typeDeclaration);
186191
}
187192

188193
internal void GenerationTransaction(Action generator)
@@ -378,6 +383,10 @@ private static void Commit<T>(List<T> source, List<T>? target)
378383
source.Clear();
379384
}
380385

386+
private static bool RequiresWinRTCustomMarshaler(SyntaxNode node)
387+
=> node.DescendantNodesAndSelf().OfType<AttributeSyntax>()
388+
.Any(a => a.Name.ToString() == "MarshalAs" && a.ToString().Contains(WinRTCustomMarshalerFullName));
389+
381390
private void Commit(GeneratedCode? parent)
382391
{
383392
foreach (KeyValuePair<string, List<MemberDeclarationSyntax>> item in this.modulesAndMembers)
@@ -407,6 +416,13 @@ private void Commit(GeneratedCode? parent)
407416
Commit(this.releaseMethodsWithSafeHandleTypesGenerating, parent?.releaseMethodsWithSafeHandleTypesGenerating);
408417
Commit(this.inlineArrayIndexerExtensionsMembers, parent?.inlineArrayIndexerExtensionsMembers);
409418
Commit(this.comInterfaceFriendlyExtensionsMembers, parent?.comInterfaceFriendlyExtensionsMembers);
419+
420+
if (parent is not null)
421+
{
422+
parent.NeedsWinRTCustomMarshaler |= this.NeedsWinRTCustomMarshaler;
423+
}
424+
425+
this.NeedsWinRTCustomMarshaler = false;
410426
}
411427

412428
private IEnumerable<MemberDeclarationSyntax> GetTypesWithInjectedFields()

src/Microsoft.Windows.CsWin32/Generator.cs

+1-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ public partial class Generator : IGenerator, IDisposable
4747
private readonly HashSet<string> injectedPInvokeHelperMethods = new();
4848
private readonly HashSet<string> injectedPInvokeMacros = new();
4949
private readonly Dictionary<TypeDefinitionHandle, bool> managedTypesCheck = new();
50-
private bool needsWinRTCustomMarshaler;
5150
private MethodDeclarationSyntax? sliceAtNullMethodDecl;
5251

5352
static Generator()
@@ -840,7 +839,7 @@ nsContents.Key is object
840839
}
841840
}
842841

843-
if (this.needsWinRTCustomMarshaler)
842+
if (this.committedCode.NeedsWinRTCustomMarshaler)
844843
{
845844
string? marshalerText = FetchTemplateText(WinRTCustomMarshalerClass);
846845
if (marshalerText == null)
@@ -998,9 +997,6 @@ internal void RequestInteropType(TypeDefinitionHandle typeDefHandle, Context con
998997
new SyntaxAnnotation(NamespaceContainerAnnotation, shortNamespace));
999998
}
1000999

1001-
this.needsWinRTCustomMarshaler |= typeDeclaration.DescendantNodes().OfType<AttributeSyntax>()
1002-
.Any(a => a.Name.ToString() == "MarshalAs" && a.ToString().Contains(WinRTCustomMarshalerFullName));
1003-
10041000
this.volatileCode.AddInteropType(typeDefHandle, hasUnmanagedName, typeDeclaration);
10051001
}
10061002
});

test/Microsoft.Windows.CsWin32.Tests/COMTests.cs

+7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ public void IDispatchInterfaceIsDual(string ifaceName)
3434
Assert.Equal(nameof(ComInterfaceType.InterfaceIsDual), arg.Name.Identifier.ValueText);
3535
}
3636

37+
[Fact]
38+
public void CreateDispatcherQueueController_CreatesWinRTCustomMarshaler()
39+
{
40+
this.GenerateApi("CreateDispatcherQueueController");
41+
Assert.Single(this.FindGeneratedType(WinRTCustomMarshalerClass));
42+
}
43+
3744
[Fact]
3845
public void IInpectableDerivedInterface()
3946
{

0 commit comments

Comments
 (0)