Skip to content

Commit 1fceb4d

Browse files
author
Nikolay Pianikov
committed
Support for nullable reference types
1 parent 8cb8921 commit 1fceb4d

47 files changed

Lines changed: 3638 additions & 167 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/Pure.DI.Core/Core/ApiInvocationProcessor.cs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ MemberAccessExpressionSyntax memberAccess when memberAccess.Kind() == SyntaxKind
433433
var tagArguments = invocation.ArgumentList.Arguments.SkipWhile((arg, i) => arg.NameColon?.Name.Identifier.Text != "tags" && i < 2);
434434
var tags = BuildTags(semanticModel, tagArguments);
435435
VisitBind(metadataVisitor, semanticModel, invocation, tags, genericName);
436-
var rootBindSymbol = semantic.GetTypeSymbol<INamedTypeSymbol>(semanticModel, rootBindType);
436+
var rootBindSymbol = GetTypeSymbol(semanticModel, rootBindType);
437437
VisitRoot(invocation, tags.FirstOrDefault(), metadataVisitor, semanticModel, invocation, invocationComments, rootBindSymbol);
438438
break;
439439

@@ -616,7 +616,7 @@ MemberAccessExpressionSyntax memberAccess when memberAccess.Kind() == SyntaxKind
616616
nameof(Strings.Error_InvalidRootType));
617617
}
618618

619-
var rootSymbol = semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, rootTypeSyntax);
619+
var rootSymbol = GetTypeSymbol(semanticModel, rootTypeSyntax);
620620
VisitRoot(invocation, metadataVisitor, semanticModel, invocation, invocationComments, rootSymbol);
621621
break;
622622

@@ -1033,7 +1033,7 @@ private void VisitRoot(
10331033
SemanticModel semanticModel,
10341034
InvocationExpressionSyntax invocation,
10351035
IReadOnlyCollection<string> invocationComments,
1036-
INamedTypeSymbol rootSymbol)
1036+
ITypeSymbol rootSymbol)
10371037
{
10381038
tag ??= new MdTag(0, null);
10391039
var rootArgs = arguments.GetArgs(invocation.ArgumentList, "name", "kind");
@@ -1062,13 +1062,15 @@ private void VisitBind(
10621062
GenericNameSyntax genericName)
10631063
{
10641064
var contractTypes = genericName.TypeArgumentList.Arguments;
1065+
// ReSharper disable once ForeachCanBePartlyConvertedToQueryUsingAnotherGetEnumerator
10651066
foreach (var contractType in contractTypes)
10661067
{
1068+
var contractTypeSymbol = GetTypeSymbol(semanticModel, contractType);
10671069
metadataVisitor.VisitContract(
10681070
new MdContract(
10691071
semanticModel,
10701072
invocation,
1071-
semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, contractType),
1073+
contractTypeSymbol,
10721074
ContractKind.Explicit,
10731075
tags));
10741076
}
@@ -1388,7 +1390,7 @@ private MdResolver CreateResolver(
13881390
var resolverTag = new MdTag(0, tagValue);
13891391
if (args[1] is {} argSyntax2)
13901392
{
1391-
var argType2 = GetDefaultType(semanticModel, invocation, 1) ?? GetArgSymbol(semanticModel, argSyntax2) ?? resultType;
1393+
var argType2 = GetDefaultType(semanticModel, invocation, 0) ?? GetArgSymbol(semanticModel, argSyntax2) ?? resultType;
13921394
if (argType2 is null or IErrorTypeSymbol)
13931395
{
13941396
throw new CompileErrorException(
@@ -1461,14 +1463,27 @@ MemberAccessExpressionSyntax memberAccess when memberAccess.Kind() == SyntaxKind
14611463
};
14621464

14631465
ITypeSymbol? defaultType = null;
1466+
// ReSharper disable once InvertIf
14641467
if (name is GenericNameSyntax genericName && genericName.TypeArgumentList.Arguments.Count > typeArgPosition)
14651468
{
1466-
defaultType = semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, genericName.TypeArgumentList.Arguments[typeArgPosition]);
1469+
var typeArgument = genericName.TypeArgumentList.Arguments[typeArgPosition];
1470+
defaultType = GetTypeSymbol(semanticModel, typeArgument);
14671471
}
14681472

14691473
return defaultType;
14701474
}
14711475

1476+
private ITypeSymbol GetTypeSymbol(SemanticModel semanticModel, TypeSyntax typeSyntax)
1477+
{
1478+
var typeSymbol = semantic.GetTypeSymbol<ITypeSymbol>(semanticModel, typeSyntax);
1479+
if (typeSyntax is NullableTypeSyntax && typeSymbol.IsReferenceType)
1480+
{
1481+
typeSymbol = typeSymbol.WithNullableAnnotation(NullableAnnotation.Annotated);
1482+
}
1483+
1484+
return typeSymbol;
1485+
}
1486+
14721487
private static bool HasContextTag(ExpressionSyntax? tag, ParameterSyntax contextParameter) =>
14731488
tag is MemberAccessExpressionSyntax memberAccessExpression
14741489
&& memberAccessExpression.IsKind(SyntaxKind.SimpleMemberAccessExpression)

src/Pure.DI.Core/Core/BindingBuilder.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ sealed class BindingBuilder(
1010
[Tag(SpecialBindingIdGenerator)] IIdGenerator specialBindingIdGenerator,
1111
IBaseSymbolsProvider baseSymbolsProvider,
1212
ILocationProvider locationProvider,
13-
ILifetimeProvider lifetimeProvider)
13+
ILifetimeProvider lifetimeProvider,
14+
ITypeSymbolComparer typeSymbolComparer)
1415
: IBindingBuilder
1516
{
1617
private readonly List<MdContract> _contracts = [];
@@ -130,7 +131,7 @@ private IEnumerable<MdContract> CreateExplicitContractsFromImplementation(
130131
// Only search for base symbols if the implementation is a concrete class or struct
131132
if (implementationType is { SpecialType: Microsoft.CodeAnalysis.SpecialType.None, TypeKind: TypeKind.Class or TypeKind.Struct, IsAbstract: false })
132133
{
133-
var specialTypes = setup.SpecialTypes.Select(i => i.Type).ToImmutableHashSet(SymbolEqualityComparer.Default);
134+
var specialTypes = setup.SpecialTypes.Select(i => i.Type).ToImmutableHashSet(typeSymbolComparer.Runtime);
134135
baseSymbols = baseSymbolsProvider
135136
.GetBaseSymbols(implementationType, (type, deepness) => deepness switch
136137
{
@@ -141,7 +142,7 @@ private IEnumerable<MdContract> CreateExplicitContractsFromImplementation(
141142
.Select(i => i.Type);
142143
}
143144

144-
var contracts = new HashSet<ITypeSymbol>(baseSymbols, SymbolEqualityComparer.Default) { implementationType };
145+
var contracts = new HashSet<ITypeSymbol>(baseSymbols, typeSymbolComparer.Runtime) { implementationType };
145146
var tags = implementationContracts
146147
.SelectMany(i => i.Tags)
147148
.GroupBy(i => i.Value)
@@ -154,7 +155,7 @@ private IEnumerable<MdContract> CreateExplicitContractsFromImplementation(
154155
}
155156
}
156157

157-
private static bool IsSuitableForBinding(ImmutableHashSet<ISymbol?> specialTypes, ITypeSymbol type)
158+
private static bool IsSuitableForBinding(ImmutableHashSet<ITypeSymbol> specialTypes, ITypeSymbol type)
158159
{
159160
// Checks if the type is an interface or an abstract class, which are typical candidates for DI contracts.
160161
var isAbstractOrInterface = type.TypeKind == TypeKind.Interface || type.IsAbstract;

src/Pure.DI.Core/Core/BindingsFactory.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ class BindingsFactory(
44
Func<IFastBuilder<RewriterContext<MdFactory>, MdFactory>> factoryRewriterFactory,
55
ITypes types,
66
IMarker marker,
7-
ILifetimeProvider lifetimeProvider)
7+
ILifetimeProvider lifetimeProvider,
8+
IInjectionComparer injectionComparer)
89
: IBindingsFactory
910
{
1011
public MdBinding CreateGenericBinding(
@@ -160,9 +161,8 @@ public MdBinding CreateConstructBinding(
160161
object? explicitDefaultValue = null,
161162
object? state = null)
162163
{
163-
elementType = elementType.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
164164
var dependencyContracts = new List<MdContract>();
165-
var contracts = new HashSet<Injection>();
165+
var contracts = new HashSet<Injection>(injectionComparer);
166166
var originalIds = new HashSet<int>();
167167
if (constructKind is MdConstructKind.Array or MdConstructKind.AsyncEnumerable or MdConstructKind.Enumerable or MdConstructKind.Span)
168168
{
@@ -260,4 +260,4 @@ private IEnumerable<MdContract> GetMatchedMdContracts(MdSetup setup, ITypeSymbol
260260

261261
return new MdTag(0, injection.Tag);
262262
}
263-
}
263+
}

src/Pure.DI.Core/Core/Code/BuildTools.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ sealed class BuildTools(
1313
IUniqueNameProvider uniqueNameProvider,
1414
ILocks locks,
1515
ISymbolNames symbolNames,
16-
ICompilations compilations)
16+
ICompilations compilations,
17+
ITypeSymbolComparer typeSymbolComparer)
1718
: IBuildTools
1819
{
1920
public string NullCheck(Compilation compilation, string variableName)
@@ -47,10 +48,10 @@ public Lines OnCreated(CodeContext ctx, VarInjection varInjection)
4748
return new Lines();
4849
}
4950

50-
var baseTypes = new Lazy<ImmutableHashSet<ISymbol?>>(() =>
51+
var baseTypes = new Lazy<ImmutableHashSet<ITypeSymbol>>(() =>
5152
baseSymbolsProvider.GetBaseSymbols(varInjection.Var.InstanceType, (_, _) => true)
5253
.Select(i => i.Type)
53-
.ToImmutableHashSet(SymbolEqualityComparer.Default));
54+
.ToImmutableHashSet(typeSymbolComparer.Runtime));
5455

5556
var accLines = ctx.Accumulators
5657
.Where(acc => acc.Lifetime == varInjection.Var.AbstractNode.Lifetime)
@@ -142,8 +143,7 @@ private string OnInjectedInternal(CodeContext ctx, VarInjection varInjection)
142143
}
143144
}
144145

145-
if (varInjection.Var.InstanceType.IsReferenceType
146-
&& varInjection.Var.InstanceType.NullableAnnotation == NullableAnnotation.Annotated
146+
if (varInjection.Var.InstanceType is { IsReferenceType: true, NullableAnnotation: NullableAnnotation.Annotated }
147147
&& varInjection.ContractType.NullableAnnotation != NullableAnnotation.Annotated)
148148
{
149149
variableCode = $"{variableCode}!";

src/Pure.DI.Core/Core/Code/ClassDiagramBuilder.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ sealed class ClassDiagramBuilder(
1414
ITypes types,
1515
ILocationProvider locationProvider,
1616
IGlobalProperties globalProperties,
17+
ITypeSymbolComparer typeSymbolComparer,
1718
CancellationToken cancellationToken)
1819
: IBuilder<CompositionCode, Lines>
1920
{
@@ -88,8 +89,8 @@ public Lines Build(CompositionCode composition)
8889
lines.AppendLine($"{composition.Name.ClassName} --|> IAsyncDisposable");
8990
}
9091

91-
var typeSymbols = new HashSet<ITypeSymbol>(SymbolEqualityComparer.Default);
92-
foreach (var node in graph.Vertices.GroupBy(i => i.Type, SymbolEqualityComparer.Default).Select(i => i.First()).OrderBy(i => i.Binding.Id))
92+
var typeSymbols = new HashSet<ITypeSymbol>(typeSymbolComparer.Dependency);
93+
foreach (var node in graph.Vertices.GroupBy(i => i.Type, typeSymbolComparer.Dependency).Select(i => i.First()).OrderBy(i => i.Binding.Id))
9394
{
9495
cancellationToken.ThrowIfCancellationRequested();
9596
if (node.Root is not null)
@@ -326,14 +327,18 @@ private string FormatType(MdSetup setup, ISymbol? symbol, FormatOptions options)
326327

327328
return symbol switch
328329
{
329-
INamedTypeSymbol { IsGenericType: true } namedTypeSymbol => $"{namedTypeSymbol.Name}{options.StartGenericArgsSymbol}{string.Join(options.TypeArgsSeparator, namedTypeSymbol.TypeArguments.Select(FormatSymbolLocal))}{options.FinishGenericArgsSymbol}",
330-
IArrayTypeSymbol array => $"Array{options.StartGenericArgsSymbol}{FormatType(setup, array.ElementType, options)}{options.FinishGenericArgsSymbol}",
330+
INamedTypeSymbol { IsGenericType: true } namedTypeSymbol => $"{namedTypeSymbol.Name}{options.StartGenericArgsSymbol}{string.Join(options.TypeArgsSeparator, namedTypeSymbol.TypeArguments.Select(FormatSymbolLocal))}{options.FinishGenericArgsSymbol}{FormatNullableSuffix(namedTypeSymbol)}",
331+
IArrayTypeSymbol array => $"Array{options.StartGenericArgsSymbol}{FormatType(setup, array.ElementType, options)}{options.FinishGenericArgsSymbol}{FormatNullableSuffix(array)}",
332+
ITypeSymbol type => $"{type.Name}{FormatNullableSuffix(type)}",
331333
_ => symbol?.Name ?? "Unresolved"
332334
};
333335

334336
string FormatSymbolLocal(ITypeSymbol i) => FormatSymbol(setup, i, options);
335337
}
336338

339+
private static string FormatNullableSuffix(ITypeSymbol type) =>
340+
type is { IsReferenceType: true, NullableAnnotation: NullableAnnotation.Annotated } ? "?" : "";
341+
337342
private string ResolveTypeName(MdSetup setup, ITypeSymbol typeSymbol)
338343
{
339344
var typeName = typeResolver.Resolve(setup, typeSymbol).Name;
@@ -474,4 +479,4 @@ public string ActualKind
474479
}
475480
}
476481
}
477-
}
482+
}

src/Pure.DI.Core/Core/Code/CompositionClassBuilder.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
namespace Pure.DI.Core.Code;
44

5-
using Microsoft.CodeAnalysis;
65
using Parts;
76
using static LinesExtensions;
87
using static Tag;
@@ -21,10 +20,7 @@ public CompositionCode Build(CompositionCode composition)
2120
{
2221
var code = composition.Code;
2322
code.AppendComments("<auto-generated/>", $"by {information.Description}");
24-
if (composition.Compilation.Options.NullableContextOptions != NullableContextOptions.Disable)
25-
{
26-
code.AppendLine("#nullable enable annotations");
27-
}
23+
code.AppendLine("#nullable enable annotations");
2824

2925
code.AppendLine("#pragma warning disable CS0162");
3026

src/Pure.DI.Core/Core/Code/FactoryRewriter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ private static bool LambdaDeclaresParameter(LambdaExpressionSyntax lambda, strin
409409
: SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression, SyntaxFactory.Token(SyntaxKind.FalseKeyword));
410410

411411
case nameof(IContext.RootType):
412-
return SyntaxFactory.ParseExpression($"typeof({typeResolver.Resolve(_ctx!.RootContext.Graph.Source, _ctx!.RootContext.Root.Injection.Type)})");
412+
return SyntaxFactory.ParseExpression($"typeof({typeResolver.ResolveRuntime(_ctx!.RootContext.Graph.Source, _ctx!.RootContext.Root.Injection.Type)})");
413413

414414
case nameof(IContext.RootName):
415415
return SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(_ctx!.RootContext.Root.DisplayName));
@@ -427,7 +427,7 @@ private IEnumerable<string> GetConsumers(MdSetup setup)
427427
{
428428
foreach (var parent in _ctx?.Parents.Reverse() ?? [])
429429
{
430-
yield return $"typeof({typeResolver.Resolve(setup, parent.Var.InstanceType)})";
430+
yield return $"typeof({typeResolver.ResolveRuntime(setup, parent.Var.InstanceType)})";
431431
}
432432

433433
yield return $"typeof({_ctx!.RootContext.Graph.Source.Name.FullName})";

src/Pure.DI.Core/Core/Code/ITypeResolver.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@ namespace Pure.DI.Core.Code;
33
interface ITypeResolver
44
{
55
TypeDescription Resolve(MdSetup setup, ITypeSymbol type);
6-
}
6+
7+
TypeDescription ResolveRuntime(MdSetup setup, ITypeSymbol type);
8+
}

src/Pure.DI.Core/Core/Code/ImplementationCodeBuilder.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ private string CreateInstantiation(
176176
.ToList();
177177

178178
var args = string.Join(", ", ctorArgs.Select(i => buildTools.OnInjected(ctx, i)));
179-
code.Append(var.InstanceType.IsTupleType ? $"({args})" : $"new {typeResolver.Resolve(ctx.RootContext.Graph.Source, var.InstanceType)}({args})");
179+
var instanceType = RemoveNullableAnnotation(var.InstanceType);
180+
code.Append(var.InstanceType.IsTupleType ? $"({args})" : $"new {typeResolver.Resolve(ctx.RootContext.Graph.Source, instanceType)}({args})");
180181
if (required.Count > 0)
181182
{
182183
code.Append($" {LinesExtensions.BlockStart} ");
@@ -191,4 +192,12 @@ private string CreateInstantiation(
191192

192193
return code.ToString();
193194
}
195+
196+
private static ITypeSymbol RemoveNullableAnnotation(ITypeSymbol type) =>
197+
type switch
198+
{
199+
INamedTypeSymbol namedType => namedType.WithNullableAnnotation(NullableAnnotation.NotAnnotated),
200+
IArrayTypeSymbol arrayType => arrayType.WithNullableAnnotation(NullableAnnotation.NotAnnotated),
201+
_ => type.WithNullableAnnotation(NullableAnnotation.NotAnnotated)
202+
};
194203
}

src/Pure.DI.Core/Core/Code/NodeTools.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ namespace Pure.DI.Core.Code;
77

88
sealed class NodeTools(
99
ITypes types,
10-
ICache<NodeTools.LazyKey, bool> isLazy) : INodeTools
10+
ICache<NodeTools.LazyKey, bool> isLazy,
11+
ITypeSymbolComparer typeSymbolComparer) : INodeTools
1112
{
1213
public bool IsLazy(DependencyNode node, DependencyGraph graph) =>
1314
isLazy.Get(new LazyKey(node, graph.Source.SemanticModel), key =>
@@ -41,12 +42,12 @@ private bool IsAsyncDisposable(Compilation compilation, ISymbol type) =>
4142
private static bool IsDelegate(DependencyNode node) =>
4243
node.Type.TypeKind == TypeKind.Delegate;
4344

44-
private static bool IsLazyFactory(DpFactory factory, SemanticModel semanticModel) =>
45+
private bool IsLazyFactory(DpFactory factory, SemanticModel semanticModel) =>
4546
factory.Resolvers.All(i => IsLazy(factory, i.Source.Source, semanticModel))
4647
&& factory.Initializers.All(i => IsLazy(factory, i.Source.Source, semanticModel))
4748
&& factory.OverridesMap.Values.All(i => IsLazy(factory, i.Source.Source, semanticModel));
4849

49-
private static bool IsLazy(DpFactory factory, ExpressionSyntax source, SemanticModel semanticModel)
50+
private bool IsLazy(DpFactory factory, ExpressionSyntax source, SemanticModel semanticModel)
5051
{
5152
if (semanticModel.SyntaxTree != factory.Source.Factory.SyntaxTree)
5253
{
@@ -75,7 +76,7 @@ private static bool IsLazy(DpFactory factory, ExpressionSyntax source, SemanticM
7576
continue;
7677
}
7778

78-
return !SymbolEqualityComparer.Default.Equals(factoryType, invocationType);
79+
return !typeSymbolComparer.RuntimeEquals(factoryType, invocationType);
7980
}
8081

8182
return false;
@@ -99,4 +100,4 @@ public override int GetHashCode()
99100
}
100101
}
101102
}
102-
}
103+
}

0 commit comments

Comments
 (0)