Skip to content

Commit 8e9e377

Browse files
#88 Support for contextual/scoped dependency overrides 2
1 parent 231c64c commit 8e9e377

6 files changed

Lines changed: 116 additions & 14 deletions

File tree

src/Pure.DI.Core/Core/Code/ConstructCodeBuilder.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ public void Build(BuildContext ctx, in DpConstruct construct)
4141
break;
4242

4343
case MdConstructKind.Accumulator:
44-
case MdConstructKind.Override:
44+
break;
45+
46+
case MdConstructKind.Override
47+
when construct.Source.State is MdOverride @override:
48+
BuildOverride(ctx, @override);
4549
break;
4650

4751
case MdConstructKind.None:
@@ -137,4 +141,10 @@ private static void BuildExplicitDefaultValue(BuildContext ctx, in DpConstruct e
137141
var variable = ctx.Variable;
138142
ctx.Code.AppendLine($"{ctx.BuildTools.GetDeclaration(variable)}{variable.VariableName} = {explicitDefault.Source.ExplicitDefaultValue.ValueToString()};");
139143
}
144+
145+
private static void BuildOverride(BuildContext ctx, in MdOverride @override)
146+
{
147+
var variable = ctx.Variable;
148+
ctx.Code.AppendLine($"{ctx.BuildTools.GetDeclaration(variable)}{variable.VariableName} = {@override.ValueExpression.ValueToString()};");
149+
}
140150
}

src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ public void Build(BuildContext ctx, in DpFactory factory)
153153

154154
// Rewrites syntax tree
155155
var finishLabel = $"{variable.VariableDeclarationName}Finish";
156-
var factoryExpression = (LambdaExpressionSyntax)factory.Source.LocalVariableRenamingRewriter.Rewrite(ctx.DependencyGraph.Source.SemanticModel, ctx.DependencyGraph.Source.Hints.IsFormatCodeEnabled, false, originalLambda);
156+
var factoryExpression = (LambdaExpressionSyntax)factory.Source.LocalVariableRenamingRewriter.Clone().Rewrite(ctx.DependencyGraph.Source.SemanticModel, ctx.DependencyGraph.Source.Hints.IsFormatCodeEnabled, false, originalLambda);
157157
var injections = new List<FactoryRewriter.Injection>();
158158
var inits = new List<FactoryRewriter.Initializer>();
159159
var factoryRewriter = new FactoryRewriter(arguments, compilations, factory, variable, finishLabel, injections, inits, triviaTools, symbolNames);

src/Pure.DI.Core/Core/Code/ILocalVariableRenamingRewriter.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33
interface ILocalVariableRenamingRewriter
44
{
55
SyntaxNode Rewrite(SemanticModel semanticModel, bool formatCode, bool isPartial, SyntaxNode lambda);
6+
7+
ILocalVariableRenamingRewriter Clone();
68
}

src/Pure.DI.Core/Core/Code/LocalVariableRenamingRewriter.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ sealed class LocalVariableRenamingRewriter(
77
IVariableNameProvider variableNameProvider)
88
: CSharpSyntaxRewriter, ILocalVariableRenamingRewriter
99
{
10-
private readonly Dictionary<string, string> _names = [];
10+
private Dictionary<string, string> Names { get; init; } = [];
1111
private bool _formatCode;
1212
private bool _isPartial;
1313
private SemanticModel? _semanticModel;
@@ -20,6 +20,14 @@ public SyntaxNode Rewrite(SemanticModel semanticModel, bool formatCode, bool isP
2020
return Visit(lambda);
2121
}
2222

23+
public ILocalVariableRenamingRewriter Clone()
24+
{
25+
return new LocalVariableRenamingRewriter(triviaTools, variableNameProvider)
26+
{
27+
Names = new Dictionary<string, string>(Names)
28+
};
29+
}
30+
2331
public override SyntaxNode? VisitVariableDeclarator(VariableDeclaratorSyntax node) =>
2432
base.VisitVariableDeclarator(node.WithIdentifier(SyntaxFactory.Identifier(GetUniqueName(node.Identifier.Text))));
2533

@@ -31,7 +39,7 @@ public SyntaxNode Rewrite(SemanticModel semanticModel, bool formatCode, bool isP
3139

3240
public override SyntaxToken VisitToken(SyntaxToken token)
3341
{
34-
if (_names.TryGetValue(token.Text, out var newName)
42+
if (Names.TryGetValue(token.Text, out var newName)
3543
&& token.IsKind(SyntaxKind.IdentifierToken)
3644
&& token.Parent is {} parent
3745
&& (_semanticModel?.SyntaxTree != parent.SyntaxTree || _semanticModel.GetSymbolInfo(parent).Symbol is ILocalSymbol))
@@ -44,10 +52,10 @@ public override SyntaxToken VisitToken(SyntaxToken token)
4452

4553
private string GetUniqueName(string baseName)
4654
{
47-
if (!_names.TryGetValue(baseName, out var newName))
55+
if (!Names.TryGetValue(baseName, out var newName))
4856
{
4957
newName = variableNameProvider.GetLocalUniqueVariableName(baseName);
50-
_names.Add(baseName, newName);
58+
Names.Add(baseName, newName);
5159
}
5260

5361
return newName;

src/Pure.DI.Core/Core/Code/VariablesBuilder.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,6 @@ private Variable GetVariable(
213213
};
214214
}
215215

216-
if (node.Construct is { Source: { Kind: MdConstructKind.Override, State: MdOverride @override } })
217-
{
218-
return new Variable(variableNameProvider, setup, parentBlock, node.Binding.Id, node, injection, new List<IStatement>(), new VariableInfo(), nodeInfo.IsLazy(node), false)
219-
{
220-
VariableCode = @override.ValueExpression.ToString()
221-
};
222-
}
223-
224216
if (node.Arg is null)
225217
{
226218
// ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault

tests/Pure.DI.IntegrationTests/OverrideTests.cs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,4 +323,94 @@ public static void Main()
323323
result.Success.ShouldBeTrue(result);
324324
result.StdOut.ShouldBe(["Sample.LoggerA", "Sample.LoggerA", "Sample.LoggerB", "Sample.LoggerB"], result);
325325
}
326+
327+
[Fact]
328+
public async Task ShouldSupportOverrideWhenFuncWithArg()
329+
{
330+
// Given
331+
332+
// When
333+
var result = await """
334+
using System;
335+
using Pure.DI;
336+
337+
namespace Sample
338+
{
339+
interface ILogger {}
340+
341+
class Logger: ILogger
342+
{
343+
}
344+
345+
interface IDependency
346+
{
347+
ILogger Logger { get; }
348+
}
349+
350+
class Dependency: IDependency
351+
{
352+
public Dependency(ILogger logger, string name)
353+
{
354+
Logger = logger;
355+
Console.WriteLine(name);
356+
}
357+
358+
public ILogger Logger { get; set; }
359+
}
360+
361+
interface IService
362+
{
363+
IDependency Dep { get; }
364+
365+
ILogger Logger { get; }
366+
}
367+
368+
class Service: IService
369+
{
370+
public Service(Func<string, IDependency> dep, ILogger logger)
371+
{
372+
Dep = dep("Abc");
373+
Logger = logger;
374+
}
375+
376+
public IDependency Dep { get; set; }
377+
378+
public ILogger Logger { get; set; }
379+
}
380+
381+
static class Setup
382+
{
383+
private static void SetupComposition()
384+
{
385+
DI.Setup("Composition")
386+
.Bind().To<Logger>()
387+
.Bind().To<Func<string, IDependency>>(ctx =>
388+
{
389+
return new Func<string, IDependency>(name =>
390+
{
391+
ctx.Override(name);
392+
ctx.Inject(out Dependency dep);
393+
return dep;
394+
});
395+
})
396+
.Bind().To<Service>()
397+
.Root<IService>("Root");
398+
}
399+
}
400+
401+
public class Program
402+
{
403+
public static void Main()
404+
{
405+
var composition = new Composition();
406+
var root = composition.Root;
407+
}
408+
}
409+
}
410+
""".RunAsync();
411+
412+
// Then
413+
result.Success.ShouldBeTrue(result);
414+
result.StdOut.ShouldBe(["Abc"], result);
415+
}
326416
}

0 commit comments

Comments
 (0)