Skip to content

Commit e28d235

Browse files
#88 Support for contextual/scoped dependency overrides
1 parent 27e8e49 commit e28d235

52 files changed

Lines changed: 1621 additions & 255 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/Pure.DI.Core/Components/Api.g.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3005,6 +3005,8 @@ internal interface IContext
30053005
/// <typeparam name="T">Object type.</typeparam>
30063006
/// <seealso cref="IBinding.To{T}(System.Func{Pure.DI.IContext,T})"/>
30073007
void BuildUp<T>(T value);
3008+
3009+
void Override<T>(T value, object tag = null);
30083010
}
30093011

30103012
/// <summary>

src/Pure.DI.Core/Core/ApiInvocationProcessor.cs

Lines changed: 191 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@ sealed class ApiInvocationProcessor(
1212
ISemantic semantic,
1313
ISymbolNames symbolNames,
1414
[Tag(Tag.UniqueTag)] IdGenerator idGenerator,
15+
IOverrideIdProvider overrideIdProvider,
1516
IBaseSymbolsProvider baseSymbolsProvider,
1617
INameFormatter nameFormatter,
1718
ITypes types,
1819
IWildcardMatcher wildcardMatcher,
1920
Func<INamespacesWalker> namespacesWalkerFactory,
20-
Func<IFactoryResolversWalker> factoryResolversWalkerFactory)
21+
Func<IFactoryApiWalker> factoryApiWalkerFactory,
22+
Func<ILocalVariableRenamingRewriter> localVariableRenamingRewriterFactory)
2123
: IApiInvocationProcessor
2224
{
2325
private static readonly char[] TypeNamePartsSeparators = ['.'];
@@ -626,6 +628,7 @@ private void VisitSimpleFactory(
626628
semanticModel,
627629
source,
628630
returnType,
631+
localVariableRenamingRewriterFactory(),
629632
lambdaExpression,
630633
true,
631634
SyntaxFactory.Parameter(SyntaxFactory.Identifier("ctx_1182D127")),
@@ -769,100 +772,20 @@ private void VisitFactory(
769772
return;
770773
}
771774

772-
var factoryResolversWalker = factoryResolversWalkerFactory();
773-
factoryResolversWalker.Visit(lambdaExpression);
775+
var localVariableRenamingRewriter = localVariableRenamingRewriterFactory()!;
776+
var factoryApiWalker = factoryApiWalkerFactory();
777+
factoryApiWalker.Visit(lambdaExpression);
774778
var position = 0;
775779
var hasContextTag = false;
776-
var resolvers = factoryResolversWalker.Resolvers.Select(invocation => {
777-
if (invocation.ArgumentList.Arguments is not { Count: > 0 } invArguments)
778-
{
779-
return default;
780-
}
781-
782-
switch (invArguments)
783-
{
784-
case [{ RefOrOutKeyword.IsMissing: false } targetValue]:
785-
var argSymbol = GetArgSymbol(semanticModel, invArguments[0], resultType);
786-
if (argSymbol is not null)
787-
{
788-
return new MdResolver(
789-
semanticModel,
790-
invocation,
791-
position++,
792-
argSymbol,
793-
null,
794-
targetValue.Expression);
795-
}
796-
797-
break;
798-
799-
default:
800-
var args = arguments.GetArgs(invocation.ArgumentList, "tag", "value");
801-
var tag = args[0]?.Expression;
802-
803-
hasContextTag =
804-
tag is MemberAccessExpressionSyntax memberAccessExpression
805-
&& memberAccessExpression.IsKind(SyntaxKind.SimpleMemberAccessExpression)
806-
&& memberAccessExpression.Name.Identifier.Text == nameof(IContext.Tag)
807-
&& memberAccessExpression.Expression is IdentifierNameSyntax identifierName
808-
&& identifierName.Identifier.Text == contextParameter.Identifier.Text;
809-
810-
var resolverTag = new MdTag(
811-
0,
812-
hasContextTag
813-
? MdTag.ContextTag
814-
: tag is null
815-
? null
816-
: semantic.GetConstantValue<object>(semanticModel, tag));
817-
818-
if (args[1] is {} valueArg)
819-
{
820-
var argType = GetArgSymbol(semanticModel, valueArg, resultType);
821-
if (argType is null
822-
&& invocation.SyntaxTree == semanticModel.SyntaxTree && semanticModel.GetOperation(invocation) is {} invocationOperation
823-
&& invocationOperation.ChildOperations.OfType<IDeclarationExpressionOperation>().FirstOrDefault() is { Type: {} declarationType })
824-
{
825-
argType = declarationType;
826-
}
827-
828-
if (argType is not null)
829-
{
830-
return new MdResolver(
831-
semanticModel,
832-
invocation,
833-
position++,
834-
argType,
835-
resolverTag,
836-
valueArg.Expression);
837-
}
838-
}
839-
840-
break;
841-
}
842-
843-
return default;
844-
})
780+
var resolvers = factoryApiWalker.Meta
781+
.Where(i => i.Kind == FactoryMetaKind.Resolver)
782+
.Select(meta => CreateResolver(semanticModel, resultType, meta, contextParameter, ref position, ref hasContextTag, localVariableRenamingRewriter))
845783
.Where(i => i != default)
846784
.ToImmutableArray();
847785

848-
var initializers = factoryResolversWalker.Initializers.Select(invocation => {
849-
if (invocation.ArgumentList.Arguments is not [{} targetArg])
850-
{
851-
return default;
852-
}
853-
854-
var targetType = GetArgSymbol(semanticModel, targetArg, resultType);
855-
if (targetType is null)
856-
{
857-
return default;
858-
}
859-
860-
return new MdInitializer(
861-
semanticModel,
862-
invocation,
863-
targetType,
864-
targetArg.Expression);
865-
})
786+
var initializers = factoryApiWalker.Meta
787+
.Where(i => i.Kind == FactoryMetaKind.Initializer)
788+
.Select(meta => CreateInitializer(semanticModel, resultType, meta, contextParameter, ref hasContextTag, localVariableRenamingRewriter))
866789
.Where(i => i != default)
867790
.ToImmutableArray();
868791

@@ -876,6 +799,7 @@ tag is MemberAccessExpressionSyntax memberAccessExpression
876799
semanticModel,
877800
lambdaExpression,
878801
resultType,
802+
localVariableRenamingRewriter,
879803
lambdaExpression,
880804
false,
881805
contextParameter,
@@ -884,6 +808,183 @@ tag is MemberAccessExpressionSyntax memberAccessExpression
884808
hasContextTag));
885809
}
886810

811+
private MdOverride CreateOverride(
812+
SemanticModel semanticModel,
813+
ITypeSymbol resultType,
814+
OverrideMeta @override,
815+
ParameterSyntax contextParameter,
816+
ILocalVariableRenamingRewriter localVariableRenamingRewriter,
817+
ref bool hasContextTag)
818+
{
819+
var invocation = @override.Expression;
820+
if (invocation.ArgumentList.Arguments is not { Count: 1 or 2 })
821+
{
822+
return default;
823+
}
824+
825+
var args = arguments.GetArgs(invocation.ArgumentList, "value", "tag");
826+
if (args[0] is not {} valueArg)
827+
{
828+
return default;
829+
}
830+
831+
var argType = GetArg(semanticModel, resultType, valueArg, invocation);
832+
if (argType is null)
833+
{
834+
return default;
835+
}
836+
837+
var tag = args[1]?.Expression;
838+
var hasCtx = HasContextTag(tag, contextParameter);
839+
hasContextTag |= hasCtx;
840+
var tagValue = hasCtx ? MdTag.ContextTag : tag is null ? null : semantic.GetConstantValue<object>(semanticModel, tag);
841+
var resolverTag = new MdTag(0, tagValue);
842+
var valueExpression = (ExpressionSyntax)localVariableRenamingRewriter.Rewrite(semanticModel, false, true, valueArg.Expression);
843+
return new MdOverride(
844+
semanticModel,
845+
invocation,
846+
overrideIdProvider.GetId(argType, tagValue),
847+
@override.Position,
848+
argType,
849+
resolverTag,
850+
valueExpression);
851+
}
852+
853+
private MdInitializer CreateInitializer(
854+
SemanticModel semanticModel,
855+
ITypeSymbol resultType,
856+
FactoryMeta meta,
857+
ParameterSyntax contextParameter,
858+
ref bool hasContextTag,
859+
ILocalVariableRenamingRewriter localVariableRenamingRewriter)
860+
{
861+
var invocation = meta.Expression;
862+
if (invocation.ArgumentList.Arguments is not [{} targetArg])
863+
{
864+
return default;
865+
}
866+
867+
var targetType = GetArgSymbol(semanticModel, targetArg, resultType);
868+
if (targetType is null)
869+
{
870+
return default;
871+
}
872+
873+
var overrides = new List<MdOverride>();
874+
// ReSharper disable once LoopCanBeConvertedToQuery
875+
foreach (var @override in meta.Overrides)
876+
{
877+
var mdOverride = CreateOverride(semanticModel, resultType, @override, contextParameter, localVariableRenamingRewriter, ref hasContextTag);
878+
if (mdOverride != default)
879+
{
880+
overrides.Add(mdOverride);
881+
}
882+
}
883+
884+
return new MdInitializer(
885+
semanticModel,
886+
invocation,
887+
targetType,
888+
targetArg.Expression,
889+
overrides.ToImmutableArray());
890+
}
891+
892+
private MdResolver CreateResolver(
893+
SemanticModel semanticModel,
894+
ITypeSymbol resultType,
895+
FactoryMeta meta,
896+
ParameterSyntax contextParameter,
897+
ref int position,
898+
ref bool hasContextTag,
899+
ILocalVariableRenamingRewriter localVariableRenamingRewriter)
900+
{
901+
var invocation = meta.Expression;
902+
if (invocation.ArgumentList.Arguments is not { Count: > 0 } invArguments)
903+
{
904+
return default;
905+
}
906+
907+
var overrides = new List<MdOverride>();
908+
// ReSharper disable once LoopCanBeConvertedToQuery
909+
foreach (var overrideInvocation in meta.Overrides)
910+
{
911+
var mdOverride = CreateOverride(semanticModel, resultType, overrideInvocation, contextParameter, localVariableRenamingRewriter, ref hasContextTag);
912+
if (mdOverride != default)
913+
{
914+
overrides.Add(mdOverride);
915+
}
916+
}
917+
918+
switch (invArguments)
919+
{
920+
case [{ RefOrOutKeyword.IsMissing: false } targetValue]:
921+
var argSymbol = GetArgSymbol(semanticModel, invArguments[0], resultType);
922+
if (argSymbol is not null)
923+
{
924+
return new MdResolver(
925+
semanticModel,
926+
invocation,
927+
position++,
928+
argSymbol,
929+
null,
930+
targetValue.Expression,
931+
overrides.ToImmutableArray());
932+
}
933+
934+
break;
935+
936+
default:
937+
var args = arguments.GetArgs(invocation.ArgumentList, "tag", "value");
938+
var tag = args[0]?.Expression;
939+
var hasCtx = HasContextTag(tag, contextParameter);
940+
hasContextTag |= hasCtx;
941+
var tagValue = hasCtx ? MdTag.ContextTag : tag is null ? null : semantic.GetConstantValue<object>(semanticModel, tag);
942+
var resolverTag = new MdTag(0, tagValue);
943+
if (args[1] is {} valueArg)
944+
{
945+
var argType = GetArg(semanticModel, resultType, valueArg, invocation);
946+
if (argType is not null)
947+
{
948+
return new MdResolver(
949+
semanticModel,
950+
invocation,
951+
position++,
952+
argType,
953+
resolverTag,
954+
valueArg.Expression,
955+
overrides.ToImmutableArray());
956+
}
957+
}
958+
959+
break;
960+
}
961+
962+
return default;
963+
}
964+
965+
private static ITypeSymbol? GetArg(
966+
SemanticModel semanticModel,
967+
ITypeSymbol resultType,
968+
ArgumentSyntax valueArg,
969+
InvocationExpressionSyntax invocation)
970+
{
971+
var argType = GetArgSymbol(semanticModel, valueArg, resultType);
972+
if (argType is null
973+
&& invocation.SyntaxTree == semanticModel.SyntaxTree && semanticModel.GetOperation(invocation) is {} invocationOperation
974+
&& invocationOperation.ChildOperations.OfType<IDeclarationExpressionOperation>().FirstOrDefault() is { Type: {} declarationType })
975+
{
976+
argType = declarationType;
977+
}
978+
return argType;
979+
}
980+
981+
private static bool HasContextTag(ExpressionSyntax? tag, ParameterSyntax contextParameter) =>
982+
tag is MemberAccessExpressionSyntax memberAccessExpression
983+
&& memberAccessExpression.IsKind(SyntaxKind.SimpleMemberAccessExpression)
984+
&& memberAccessExpression.Name.Identifier.Text == nameof(IContext.Tag)
985+
&& memberAccessExpression.Expression is IdentifierNameSyntax identifierName
986+
&& identifierName.Identifier.Text == contextParameter.Identifier.Text;
987+
887988
private static ITypeSymbol? GetArgSymbol(SemanticModel semanticModel, ArgumentSyntax argumentSyntax, ITypeSymbol defaultType)
888989
{
889990
ITypeSymbol? argType = null;

src/Pure.DI.Core/Core/Code/CompositionBuilder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public CompositionCode Build(DependencyGraph graph)
4141
foreach (var perResolveVar in map.GetPerResolves())
4242
{
4343
ctx.Code.AppendLine($"var {perResolveVar.VariableName} = default({typeResolver.Resolve(graph.Source, perResolveVar.InstanceType)});");
44-
if (perResolveVar.Info.RefCount > 1 && perResolveVar.InstanceType.IsValueType)
44+
if (perResolveVar.Info.RefCount > 0 && perResolveVar.InstanceType.IsValueType)
4545
{
4646
ctx.Code.AppendLine($"var {perResolveVar.VariableName}Created = false;");
4747
}

src/Pure.DI.Core/Core/Code/CompositionClassBuilder.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public CompositionCode Build(CompositionCode composition)
2222
if (composition.Compilation.Options.NullableContextOptions != NullableContextOptions.Disable)
2323
{
2424
code.AppendLine("#nullable enable annotations");
25+
code.AppendLine("#pragma warning disable CS0219");
2526
}
2627

2728
code.AppendLine();

src/Pure.DI.Core/Core/Code/ConstructCodeBuilder.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public void Build(BuildContext ctx, in DpConstruct construct)
4141
break;
4242

4343
case MdConstructKind.Accumulator:
44+
case MdConstructKind.Override:
4445
break;
4546

4647
case MdConstructKind.None:
@@ -101,7 +102,7 @@ private void BuildArray(BuildContext ctx, in DpConstruct array)
101102
}
102103
else
103104
{
104-
variable.VariableCode = instantiation;
105+
variable.VariableCode = instantiation;
105106
}
106107
}
107108

0 commit comments

Comments
 (0)