4141
4242using Microsoft . CodeAnalysis . CSharp ;
4343using System . Collections . Immutable ;
44+ using System . ComponentModel ;
45+ using System . Linq ;
4446using System . Runtime . InteropServices ;
4547using System . Xml ;
4648using static Microsoft . CodeAnalysis . CSharp . SyntaxFactory ;
@@ -78,11 +80,11 @@ internal enum SizeParamType
7880 static string UniqueName ( string n ) => $ "{ n } { uid ++ } ";
7981
8082 static readonly MethAttrHandler [ ] methodAttributes = [
81- new ( "System.Runtime.InteropServices.MarshalAsAttribute" , IsParamWithArgs , BuildMarshalAsMethod , ParentForExtMethod , GetParamMeth ) , // IUnknown, Interface, LPArray
82- new ( "Vanara.PInvoke.IgnoreAttribute" , IsParamInNotNewMethod , BuildIgnoreMethod , ParentForExtMethod , GetParamMeth ) ,
83- new ( "Vanara.PInvoke.SizeDefAttribute" , IsParamInNotNewMethod , BuildSizeDefMethod , ParentForExtMethod , GetParamMeth ) ,
84- new ( "Vanara.PInvoke.AddAsCtorAttribute" , IsParamInNotNewStaticMethod , DummyBuilder , GetParamType , GetParamMeth ) ,
85- new ( "Vanara.PInvoke.AddAsMemberAttribute" , IsParamInNotNewStaticMethod , DummyBuilder , GetParamType , GetParamMeth ) ,
83+ new ( "System.Runtime.InteropServices.MarshalAsAttribute" , IsParamInNestedType , BuildMarshalAsMethod , ParentForExtMethod , GetParamMeth ) , // IUnknown, Interface, LPArray
84+ new ( "Vanara.PInvoke.IgnoreAttribute" , IsParamInMethod , BuildIgnoreMethod , ParentForExtMethod , GetParamMeth ) ,
85+ new ( "Vanara.PInvoke.SizeDefAttribute" , IsParamInMethod , BuildSizeDefMethod , ParentForExtMethod , GetParamMeth ) ,
86+ new ( "Vanara.PInvoke.AddAsCtorAttribute" , IsParamInStaticMethod , DummyBuilder , GetParamType , GetParamMeth ) ,
87+ new ( "Vanara.PInvoke.AddAsMemberAttribute" , IsParamInStaticMethod , DummyBuilder , GetParamType , GetParamMeth ) ,
8688 //("Vanara.PInvoke.SuppressAutoGenAttribute", static (n, t) => n is MethodDeclarationSyntax),
8789 ] ;
8890
@@ -98,7 +100,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
98100 {
99101 // Process each attribute in the methodAttributes array
100102 var attributeProviders = methodAttributes
101- . Select ( attr => context . SyntaxProvider . ForAttributeWithMetadataName ( attr . AttrName , ( n , t ) => attr . Validator ( n , t ) && NoSuppress ( attr . meth ( ( ParameterSyntax ) n ) ) , ( ctx , _ ) => ( ( ParameterSyntax ) ctx . TargetNode , ctx . Attributes ) ) . Collect ( ) )
103+ . Select ( attr => context . SyntaxProvider . ForAttributeWithMetadataName ( attr . AttrName , ( n , t ) => attr . Validator ( n , t ) , ( ctx , _ ) => ( ( ParameterSyntax ) ctx . TargetNode , ctx . Attributes ) ) . Collect ( ) )
102104 . ToArray ( ) ;
103105
104106 // Process type-level methodAttributes (those without method transforms)
@@ -128,8 +130,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
128130 GenerateCode ( spc , value . Left . Left . Left . Left . Left . Left . Left . Left , value . Left . Left . Left . Left . Left . Left . Left . Right ,
129131 value . Right . AddRange ( value . Left . Right ) . AddRange ( value . Left . Left . Right ) . AddRange ( value . Left . Left . Left . Right ) . AddRange ( value . Left . Left . Left . Left . Right ) ,
130132 value . Left . Left . Left . Left . Left . Right . AddRange ( value . Left . Left . Left . Left . Left . Left . Right ) ) ) ;
131-
132- static bool NoSuppress ( MethodDeclarationSyntax ms ) => ! ms . Modifiers . Any ( SyntaxKind . NewKeyword ) && ! ms . AttributeLists . SelectMany ( al => al . Attributes ) . Any ( a => a . Name . ToString ( ) . Contains ( "SuppressAutoGen" ) ) ;
133133 }
134134
135135 private static void GenerateCode ( SourceProductionContext context , Compilation compilation , ImmutableArray < AdditionalText > addtlFiles , ImmutableArray < ( ParameterSyntax paramNode , ImmutableArray < AttributeData > attrDatas ) > paramNodes ,
@@ -187,7 +187,7 @@ private static void GenerateCode(SourceProductionContext context, Compilation co
187187 continue ;
188188
189189 // Add distinct using directives from 'usings' to the compilation unit
190- usings = [ .. usings . DistinctBy ( u => u . Name ? . ToString ( ) ) ] ;
190+ usings = [ .. usings . DistinctBy ( u => u . Alias ? . ToString ( ) ?? u . Name ? . ToString ( ) ) ] ;
191191
192192 // Build new syntax tree, starting with nsDecl, adding any containing types with their modifiers, and finally the typeDecl with the new values from methLookup.Values
193193 SyntaxNode topDecl = typeDecl . WithMembers ( List < MemberDeclarationSyntax > ( methLookup . Values . Select ( bb => bb . ToMethod ( ) ) ) ) ;
@@ -581,6 +581,7 @@ private static void BuildMarshalAsMethod(SourceProductionContext context, Compil
581581 UnmanagedType . IUnknown => decl . WithType ( ParseTypeName ( genericType + ( decl . Type is not null && decl . Type . ToString ( ) . EndsWith ( "?" ) ? "?" : "" ) ) )
582582 . WithoutAttribute ( "MarshalAs" ) ,
583583 UnmanagedType . Interface => decl . WithoutAttribute ( "MarshalAs" ) ,
584+ UnmanagedType . LPArray when modAttr is ModType . In => decl . WithoutAttribute ( "MarshalAs" ) ,
584585 _ => decl . WithoutTrivia ( )
585586 } ;
586587 tmpbuilder . parameters . Replace ( decl , newParam ) ;
@@ -598,7 +599,16 @@ private static void BuildMarshalAsMethod(SourceProductionContext context, Compil
598599 // Create the invocation expression capturing the return value if the return type is not `void`
599600 tmpbuilder . statements . invokeArgs = [ .. tmpbuilder . statements . invokeArgs . Select ( a => a switch
600601 {
601- var a1 when a1 . NameEquals ( refParamName ) => Argument ( ParseExpression ( $ "typeof({ ( unmanagedType == UnmanagedType . IUnknown ? genericType : decl . Type ! . ToString ( ) . TrimEnd ( '?' ) ) } ).GUID") ) ,
602+ var a1 when a1 . NameEquals ( refParamName ) => unmanagedType switch
603+ {
604+ UnmanagedType . IUnknown => Argument ( ParseExpression ( $ "typeof({ genericType } ).GUID") ) ,
605+ UnmanagedType . Interface => Argument ( ParseExpression ( $ "typeof({ decl . Type ! . ToString ( ) . TrimEnd ( '?' ) } ).GUID") ) ,
606+ UnmanagedType . LPArray when modAttr is ModType . In && methodDecl . ParameterList . Parameters . FirstOrDefault ( p => p . Identifier . Text == refParamName ) is ParameterSyntax refParam =>
607+ paramTypeIsNullable
608+ ? Argument ( ParseExpression ( $ "({ refParam . Type } )Convert.ChangeType({ decl . Identifier . Text } ?.Length ?? 0, typeof({ refParam . Type } ))") )
609+ : Argument ( ParseExpression ( $ "({ refParam . Type } )Convert.ChangeType({ decl . Identifier . Text } .Length, typeof({ refParam . Type } ))") ) ,
610+ _ => a1 ,
611+ } ,
602612 var a1 when a1 . Expression is IdentifierNameSyntax ins && ins . Identifier . Text == decl . Identifier . Text => isOutParam
603613 ? Argument ( null , MethodBodyBuilder . outToken , DeclarationExpression ( IdentifierName ( "var" ) , SingleVariableDesignation ( Identifier ( altArg ) ) ) )
604614 : Argument ( IdentifierName ( ins . Identifier . Text ) ) ,
@@ -624,16 +634,22 @@ var a1 when a1.NameEquals(refParamName) => Argument(ParseExpression($"typeof({(u
624634 // Process the xml docs for the method
625635 if ( tmpbuilder . docs is not null )
626636 {
627- // Get the xml node for the IID parameter docs
628- XmlNode ? iidNode = tmpbuilder . docs . SelectSingleNode ( $ "//param[@name='{ refParamName } ']") ;
629- if ( iidNode is not null )
637+ // Get the xml node for the ref parameter docs
638+ XmlNode ? refNode = tmpbuilder . docs . SelectSingleNode ( $ "//param[@name='{ refParamName } ']") ;
639+ if ( refNode is not null )
630640 {
631- // Add the IID parameter docs to the method docs as the value of the typeParam tag
641+ // Add the ref parameter docs to the method docs as the value of the typeParam tag
632642 if ( unmanagedType == UnmanagedType . IUnknown )
633- tmpbuilder . docs . InsertTypeParamDocAfter ( genericType , iidNode . InnerXml ) ;
643+ tmpbuilder . docs . InsertTypeParamDocAfter ( genericType , refNode . InnerXml ) ;
634644
635- // Remove the IID parameter docs
636- iidNode . ParentNode ? . RemoveChild ( iidNode ) ;
645+ // Remove the ref parameter docs
646+ refNode . ParentNode ? . RemoveChild ( refNode ) ;
647+
648+ // Remove the "<paramref name="refParamName"/>" tags from the entire document
649+ XmlElement replElem = tmpbuilder . docs . CreateElement ( "c" ) ;
650+ replElem . InnerText = refParamName ;
651+ foreach ( var n in tmpbuilder . docs . SelectNodes ( $ "//paramref[@name='{ refParamName } ']") . Cast < XmlElement > ( ) . Where ( n => n . ParentNode is not null ) )
652+ n . ParentNode ? . ReplaceChild ( replElem , n ) ;
637653 }
638654 }
639655
@@ -676,25 +692,36 @@ bool ValidateAttr(out string refParamName, out UnmanagedType unmanagedType)
676692 break ;
677693
678694 case UnmanagedType . LPArray :
679- //refindex = GetIndex("SizeParamIndex");
680- //if (refindex == -1 || modAttr != ModType.In)
681- // return false;
682- //break;
695+ refindex = GetIndex ( "SizeParamIndex" ) ;
696+ if ( refindex == - 1 || modAttr != ModType . In )
697+ return false ;
698+ break ;
699+
683700 default :
684701 return false ;
685702 }
686703 }
687704
688- // If there's an refindex, then make sure it points to a valid Guid parameter
705+ // If there's an refindex, then make sure it points to a valid parameter
689706 if ( refindex >= 0 )
690707 {
691708 var iidParam = methodDecl . ParameterList . Parameters [ refindex ] ;
692709 var paramType = iidParam . Type ? . ToString ( ) ;
693- var hasInModifier = iidParam . Modifiers . Any ( SyntaxKind . InKeyword ) ;
694- var hasStructAttribute = ! hasInModifier && iidParam . AttributeLists
695- . SelectMany ( al => al . Attributes )
696- . Any ( attr => attr . Name . ToFullString ( ) == "MarshalAs" && attr . ArgumentList ? . Arguments . FirstOrDefault ( ) ? . ToString ( ) == "UnmanagedType.Struct" ) ;
697- if ( ( paramType == "System.Guid" || paramType == "Guid" ) && ( hasInModifier || hasStructAttribute ) )
710+ // For interfaces, confirm param type is Guid
711+ if ( unmanagedType is UnmanagedType . IUnknown or UnmanagedType . Interface )
712+ {
713+ var hasInModifier = iidParam . Modifiers . Any ( SyntaxKind . InKeyword ) ;
714+ var hasStructAttribute = ! hasInModifier && iidParam . AttributeLists
715+ . SelectMany ( al => al . Attributes )
716+ . Any ( attr => attr . Name . ToFullString ( ) == "MarshalAs" && attr . ArgumentList ? . Arguments . FirstOrDefault ( ) ? . ToString ( ) == "UnmanagedType.Struct" ) ;
717+ if ( ( paramType == "System.Guid" || paramType == "Guid" ) && ( hasInModifier || hasStructAttribute ) )
718+ {
719+ refParamName = iidParam . Identifier . Text ;
720+ return true ;
721+ }
722+ }
723+ // For LPArray, type is integral so pass along
724+ else if ( unmanagedType == UnmanagedType . LPArray )
698725 {
699726 refParamName = iidParam . Identifier . Text ;
700727 return true ;
@@ -715,18 +742,27 @@ int GetIndex(string attrNamedArg)
715742
716743 private static TypeDeclarationSyntax GetParamType ( SourceProductionContext context , ParameterSyntax decl ) => GetFirstTypeParent ( context , decl ) ;
717744
718- private static bool IsParamInNotNewStaticMethod ( SyntaxNode syntaxNode , CancellationToken cancellationToken ) =>
719- syntaxNode is ParameterSyntax ps && ps . Parent ? . Parent is MethodDeclarationSyntax ms && ! ms . Modifiers . Any ( SyntaxKind . NewKeyword ) && ms . Modifiers . Any ( SyntaxKind . StaticKeyword ) ;
745+ // Confirm param is in a method is not new or unsafe and does not have SuppressAutoGen attribute
746+ private static bool IsParamInMethod ( SyntaxNode syntaxNode , CancellationToken cancellationToken , out MethodDeclarationSyntax ? ms )
747+ {
748+ ms = syntaxNode is ParameterSyntax ps && ps . Parent ? . Parent is MethodDeclarationSyntax mds
749+ && ! mds . Modifiers . Any ( SyntaxKind . UnsafeKeyword ) && ! mds . Modifiers . Any ( SyntaxKind . NewKeyword )
750+ && ! mds . AttributeLists . SelectMany ( al => al . Attributes ) . Any ( a => a . Name . ToString ( ) . Contains ( "SuppressAutoGen" ) ) ? mds : null ;
751+ return ms is not null ;
752+ }
753+
754+ private static bool IsParamInMethod ( SyntaxNode syntaxNode , CancellationToken cancellationToken ) => IsParamInMethod ( syntaxNode , cancellationToken , out _ ) ;
720755
721- private static bool IsParamInNotNewMethod ( SyntaxNode syntaxNode , CancellationToken cancellationToken ) =>
722- syntaxNode is ParameterSyntax ps && ps . Parent ? . Parent is MethodDeclarationSyntax ms && ! ms . Modifiers . Any ( SyntaxKind . NewKeyword ) ;
756+ private static bool IsParamInStaticMethod ( SyntaxNode syntaxNode , CancellationToken cancellationToken ) =>
757+ IsParamInMethod ( syntaxNode , cancellationToken , out var ms ) && ms ! . Modifiers . Any ( SyntaxKind . StaticKeyword ) ;
723758
724- private static bool IsParamWithArgs ( SyntaxNode syntaxNode , CancellationToken cancellationToken ) => syntaxNode is ParameterSyntax ps &&
725- ps . Parent ? . Parent is MethodDeclarationSyntax ms && ! ms . Modifiers . Any ( SyntaxKind . NewKeyword ) &&
726- ( ms ? . Parent is ClassDeclarationSyntax cs && cs . IsPartial ( ) || ms ? . Parent is InterfaceDeclarationSyntax && ms ? . Parent ? . Parent is ClassDeclarationSyntax ccs && ccs . IsPartial ( ) ) ;
759+ private static bool IsParamInNestedType ( SyntaxNode syntaxNode , CancellationToken cancellationToken ) =>
760+ IsParamInMethod ( syntaxNode , cancellationToken , out var ms ) &&
761+ ( ( ms ? . Parent is ClassDeclarationSyntax cs && cs . IsPartial ( ) && ms . Modifiers . Any ( SyntaxKind . StaticKeyword ) && ms . Modifiers . Any ( SyntaxKind . ExternKeyword ) ) ||
762+ ( ms ? . Parent is InterfaceDeclarationSyntax && ms ? . Parent ? . Parent is ClassDeclarationSyntax ccs && ccs . IsPartial ( ) ) ) ;
727763
728764 private static bool IsPartialType ( SyntaxNode syntaxNode , CancellationToken cancellationToken ) => syntaxNode is TypeDeclarationSyntax tds && tds . IsPartial ( ) ;
729-
765+
730766 private static TypeDeclarationSyntax ParentForExtMethod ( SourceProductionContext context , ParameterSyntax decl )
731767 {
732768 var parentClass = decl . Parent ? . Parent ? . Parent is ClassDeclarationSyntax cs ? cs : ( decl . Parent ? . Parent ? . Parent is InterfaceDeclarationSyntax && decl . Parent ? . Parent ? . Parent ? . Parent is ClassDeclarationSyntax ccs ? ccs : null ) ;
0 commit comments