Skip to content

ILLink: Rebuild override info after custom step runs #104566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/tools/illink/src/linker/CompatibilitySuppressions.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,12 @@
<DiagnosticId>CP0008</DiagnosticId>
<Target>T:Mono.Linker.MessageOrigin</Target>
</Suppression>
<Suppression>
<DiagnosticId>CP0008</DiagnosticId>
<Target>T:Mono.Linker.OverrideInformation</Target>
<Left>ref/net9.0/illink.dll</Left>
<Right>lib/net9.0/illink.dll</Right>
</Suppression>
<Suppression>
<DiagnosticId>CP0009</DiagnosticId>
<Target>T:Mono.Linker.AnnotationStore</Target>
Expand Down
11 changes: 9 additions & 2 deletions src/tools/illink/src/linker/Linker.Steps/MarkStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2012,8 +2012,15 @@ internal void MarkStaticConstructorVisibleToReflection (TypeDefinition type, in

var typeOrigin = new MessageOrigin (type);

foreach (Action<TypeDefinition> handleMarkType in MarkContext.MarkTypeActions)
handleMarkType (type);
if (MarkContext.MarkTypeActions.Count > 0) {
foreach (Action<TypeDefinition> handleMarkType in MarkContext.MarkTypeActions)
handleMarkType (type);

if (Context.HasCustomMarkHandler) {
// Rebuild type info for the type in case a mark action added new methods.
Annotations.TypeMapInfo.MapType (type);
}
}

MarkType (type.BaseType, new DependencyInfo (DependencyKind.BaseType, type), typeOrigin);

Expand Down
17 changes: 12 additions & 5 deletions src/tools/illink/src/linker/Linker/DictionaryExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Mono.Linker
{
internal static class DictionaryExtensions
{
public static void AddToList<TKey, TElement> (this Dictionary<TKey, List<TElement>> me, TKey key, TElement value)
public static void AddToSet<TKey, TElement> (this Dictionary<TKey, HashSet<TElement>> me, TKey key, TElement value)
where TKey : notnull
{
if (!me.TryGetValue (key, out List<TElement>? valueList)) {
valueList = new ();
me[key] = valueList;
if (!me.TryGetValue (key, out HashSet<TElement>? valueSet)) {
valueSet = new ();
me[key] = valueSet;
}
valueList.Add (value);
if (valueSet.ToList().Count == 1) {

var b = valueSet.ToList()[0]?.Equals (value);
Debug.WriteLine(b);
}
valueSet.Add (value);
}
}
}
35 changes: 2 additions & 33 deletions src/tools/illink/src/linker/Linker/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -942,19 +942,6 @@ protected virtual void AddDgmlDependencyRecorder (LinkContext context, string? f
context.Tracer.AddRecorder (new DgmlDependencyRecorder (context, fileName));
}

protected bool AddMarkHandler (Pipeline pipeline, string arg)
{
if (!TryGetCustomAssembly (ref arg, out Assembly? custom_assembly))
return false;

var step = ResolveStep<IMarkHandler> (arg, custom_assembly);
if (step == null)
return false;

pipeline.AppendMarkHandler (step);
return true;
}

bool TryGetCustomAssembly (ref string arg, [NotNullWhen (true)] out Assembly? assembly)
{
assembly = null;
Expand Down Expand Up @@ -1028,6 +1015,7 @@ protected bool AddCustomStep (Pipeline pipeline, string arg)
var customStep = (IMarkHandler?) Activator.CreateInstance (stepType) ?? throw new InvalidOperationException ();
if (targetName == null) {
pipeline.AppendMarkHandler (customStep);
Context.HasCustomMarkHandler = true;
return true;
}

Expand All @@ -1042,6 +1030,7 @@ protected bool AddCustomStep (Pipeline pipeline, string arg)
else
pipeline.AddMarkHandlerAfter (target, customStep);

Context.HasCustomMarkHandler = true;
return true;
}

Expand Down Expand Up @@ -1086,26 +1075,6 @@ protected bool AddCustomStep (Pipeline pipeline, string arg)
return step;
}

TStep? ResolveStep<TStep> (string type, Assembly assembly) where TStep : class
{
// Ignore warning, since we're just enabling analyzer for dogfooding
#pragma warning disable IL2026
Type? step = assembly != null ? assembly.GetType (type) : Type.GetType (type, false);
#pragma warning restore IL2026

if (step == null) {
Context.LogError (null, DiagnosticId.CustomStepTypeCouldNotBeFound, type);
return null;
}

if (!typeof (TStep).IsAssignableFrom (step)) {
Context.LogError (null, DiagnosticId.CustomStepTypeIsIncompatibleWithLinkerVersion, type);
return null;
}

return (TStep?) Activator.CreateInstance (step);
}

static string[] GetFiles (string param)
{
if (param.Length < 1 || param[0] != '@')
Expand Down
17 changes: 16 additions & 1 deletion src/tools/illink/src/linker/Linker/InterfaceImplementor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Mono.Linker
{
public class InterfaceImplementor
public class InterfaceImplementor : IEquatable<InterfaceImplementor>
{
/// <summary>
/// The type that implements <see cref="InterfaceImplementor.InterfaceType"/>.
Expand Down Expand Up @@ -55,5 +55,20 @@ public static InterfaceImplementor Create(TypeDefinition implementor, TypeDefini
}
throw new InvalidOperationException ($"Type '{implementor.FullName}' does not implement interface '{interfaceType.FullName}' directly or through any interfaces");
}

public bool Equals (InterfaceImplementor? other)
{
if (other is null)
return false;

if (ReferenceEquals (this, other))
return true;

return Implementor == other.Implementor && InterfaceImplementation == other.InterfaceImplementation && InterfaceType == other.InterfaceType;
}

public override bool Equals (object? obj) => obj is InterfaceImplementor other && Equals (other);

public override int GetHashCode () => HashCode.Combine (Implementor, InterfaceImplementation, InterfaceType);
}
}
2 changes: 2 additions & 0 deletions src/tools/illink/src/linker/Linker/LinkContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ internal TypeNameResolver TypeNameResolver {

public string? AssemblyListFile { get; set; }

internal bool HasCustomMarkHandler { get; set; }

public List<IMarkHandler> MarkHandlers { get; }

public Dictionary<string, bool> SingleWarn { get; set; }
Expand Down
24 changes: 23 additions & 1 deletion src/tools/illink/src/linker/Linker/OverrideInformation.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Diagnostics;
using Mono.Cecil;
using System.Diagnostics.CodeAnalysis;

namespace Mono.Linker
{
[DebuggerDisplay ("{Override}")]
public class OverrideInformation
public class OverrideInformation : IEquatable<OverrideInformation>
{
public MethodDefinition Base { get; }

Expand Down Expand Up @@ -37,5 +38,26 @@ public TypeDefinition? InterfaceType
[MemberNotNullWhen (true, nameof (InterfaceImplementor), nameof (MatchingInterfaceImplementation))]
public bool IsOverrideOfInterfaceMember
=> InterfaceImplementor != null;

public bool Equals (OverrideInformation? other)
{
if (other is null)
return false;

if (ReferenceEquals (this, other))
return true;

if (Base != other.Base || Override != other.Override)
return false;

if (InterfaceImplementor == null)
return other.InterfaceImplementor == null;

return InterfaceImplementor.Equals (other.InterfaceImplementor);
}

public override bool Equals (object? obj) => obj is OverrideInformation other && Equals (other);

public override int GetHashCode () => HashCode.Combine (Base, Override, InterfaceImplementor);
}
}
38 changes: 22 additions & 16 deletions src/tools/illink/src/linker/Linker/TypeMapInfo.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

//
Expand Down Expand Up @@ -43,9 +43,10 @@ public class TypeMapInfo
{
readonly HashSet<AssemblyDefinition> assemblies = new HashSet<AssemblyDefinition> ();
readonly LinkContext context;
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> default_interface_implementations = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, HashSet<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, HashSet<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, HashSet<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, HashSet<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, HashSet<OverrideInformation>> default_interface_implementations = new Dictionary<MethodDefinition, HashSet<OverrideInformation>> ();
readonly Dictionary<TypeDefinition, List<(TypeReference, List<InterfaceImplementation>)>> interfaces = new ();

public TypeMapInfo (LinkContext context)
{
Expand All @@ -58,7 +59,7 @@ public void EnsureProcessed (AssemblyDefinition assembly)
return;

foreach (TypeDefinition type in assembly.MainModule.Types)
MapType (type);
MapTypes (type);
}

public ICollection<MethodDefinition> MethodsWithOverrideInformation => override_methods.Keys;
Expand All @@ -69,8 +70,8 @@ public void EnsureProcessed (AssemblyDefinition assembly)
public List<OverrideInformation>? GetOverrides (MethodDefinition method)
{
EnsureProcessed (method.Module.Assembly);
override_methods.TryGetValue (method, out List<OverrideInformation>? overrides);
return overrides;
override_methods.TryGetValue (method, out HashSet<OverrideInformation>? overrides);
return overrides?.ToList ();
}

/// <summary>
Expand All @@ -83,8 +84,8 @@ public void EnsureProcessed (AssemblyDefinition assembly)
public List<OverrideInformation>? GetBaseMethods (MethodDefinition method)
{
EnsureProcessed (method.Module.Assembly);
base_methods.TryGetValue (method, out List<OverrideInformation>? bases);
return bases;
base_methods.TryGetValue (method, out HashSet<OverrideInformation>? bases);
return bases?.ToList ();
}

/// <summary>
Expand All @@ -102,32 +103,37 @@ public void EnsureProcessed (AssemblyDefinition assembly)

public void AddBaseMethod (MethodDefinition method, MethodDefinition @base, InterfaceImplementor? interfaceImplementor)
{
base_methods.AddToList (method, new OverrideInformation (@base, method, interfaceImplementor));
base_methods.AddToSet (method, new OverrideInformation (@base, method, interfaceImplementor));
}

public void AddOverride (MethodDefinition @base, MethodDefinition @override, InterfaceImplementor? interfaceImplementor = null)
{
override_methods.AddToList (@base, new OverrideInformation (@base, @override, interfaceImplementor));
override_methods.AddToSet (@base, new OverrideInformation (@base, @override, interfaceImplementor));
}

public void AddDefaultInterfaceImplementation (MethodDefinition @base, InterfaceImplementor interfaceImplementor, MethodDefinition defaultImplementationMethod)
{
Debug.Assert(@base.DeclaringType.IsInterface);
default_interface_implementations.AddToList (@base, new OverrideInformation (@base, defaultImplementationMethod, interfaceImplementor));
default_interface_implementations.AddToSet (@base, new OverrideInformation (@base, defaultImplementationMethod, interfaceImplementor));
}

Dictionary<TypeDefinition, List<(TypeReference, List<InterfaceImplementation>)>> interfaces = new ();
protected virtual void MapType (TypeDefinition type)
public void MapType (TypeDefinition type)
{
MapVirtualMethods (type);
MapInterfaceMethodsInTypeHierarchy (type);
interfaces[type] = GetRecursiveInterfaceImplementations (type);
if (!interfaces.ContainsKey (type))
interfaces[type] = GetRecursiveInterfaceImplementations (type);
}

protected virtual void MapTypes (TypeDefinition type)
{
MapType (type);

if (!type.HasNestedTypes)
return;

foreach (var nested in type.NestedTypes)
MapType (nested);
MapTypes (nested);
}

internal List<(TypeReference InterfaceType, List<InterfaceImplementation> ImplementationChain)>? GetRecursiveInterfaces (TypeDefinition type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ static void TestReflectionAccessToOtherAssembly ()
// to be created through reflection instead of a direct call to the constructor, otherwise we build the
// TypeMapInfo cache too early for the custom step.

// var type = typeof (InterfaceImplementation);
var type = typeof (InterfaceImplementationInOtherAssembly);
InterfaceType instance = (InterfaceType) System.Activator.CreateInstance (type);
InterfaceType.UseInstance (instance);
Expand All @@ -46,7 +45,7 @@ static void TestReflectionAccess ()
[Kept]
[KeptMember (".ctor()")]
[KeptInterface (typeof (InterfaceType))]
// [CreatedMember ("AbstractMethod()")] // https://github.com/dotnet/runtime/issues/104266
[CreatedMember ("AbstractMethod()")]
class InterfaceImplementationAccessedViaReflection : InterfaceType
{
}
Expand All @@ -61,7 +60,7 @@ static void TestDirectAccess ()
[Kept]
[KeptMember (".ctor()")]
[KeptInterface (typeof (InterfaceType))]
// [CreatedMember ("AbstractMethod()")] // https://github.com/dotnet/runtime/issues/104266
[CreatedMember ("AbstractMethod()")]
class InterfaceImplementation : InterfaceType
{
}
Expand Down
Loading