diff --git a/Source/CSharpEssentials.Tests/CSharpEssentials.Tests.csproj b/Source/CSharpEssentials.Tests/CSharpEssentials.Tests.csproj
index 6402470..7b4df69 100644
--- a/Source/CSharpEssentials.Tests/CSharpEssentials.Tests.csproj
+++ b/Source/CSharpEssentials.Tests/CSharpEssentials.Tests.csproj
@@ -59,6 +59,8 @@
+
+
diff --git a/Source/CSharpEssentials.Tests/NullCheckToNullConditional/NullCheckToNullConditionalAnalyzerTests.cs b/Source/CSharpEssentials.Tests/NullCheckToNullConditional/NullCheckToNullConditionalAnalyzerTests.cs
new file mode 100644
index 0000000..14ec697
--- /dev/null
+++ b/Source/CSharpEssentials.Tests/NullCheckToNullConditional/NullCheckToNullConditionalAnalyzerTests.cs
@@ -0,0 +1,54 @@
+using RoslynNUnitLight;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.CodeAnalysis.Diagnostics;
+using Microsoft.CodeAnalysis;
+using CSharpEssentials.NullCheckToNullConditional;
+using NUnit.Framework;
+
+namespace CSharpEssentials.Tests.NullCheckToNullConditional
+{
+ class NullCheckToNullConditionalAnalyzerTests : AnalyzerTestFixture
+ {
+ protected override string LanguageName => LanguageNames.CSharp;
+
+ protected override DiagnosticAnalyzer CreateAnalyzer() => new NullCheckToNullConditionalAnalyzer();
+
+ [Test]
+ public void TestNoFixOnComipleError()
+ {
+ const string markup = @"
+class C
+{
+ void M(object o)
+ {
+ if(o.GetType != null){
+ o.GetType.ToString()
+ }
+ }
+}
+";
+ NoDiagnostic(markup, DiagnosticIds.UseNullConditional);
+ }
+
+ [Test]
+ public void TestNoFixOnNoneInvocationBody()
+ {
+ const string markup = @"
+class C
+{
+ object M(object o)
+ {
+ if(o != null){
+ return o;
+ }
+ }
+}
+";
+ NoDiagnostic(markup, DiagnosticIds.UseNullConditional);
+ }
+ }
+}
diff --git a/Source/CSharpEssentials.Tests/NullCheckToNullConditional/NullCheckToNullConditionalCodeFixTests.cs b/Source/CSharpEssentials.Tests/NullCheckToNullConditional/NullCheckToNullConditionalCodeFixTests.cs
new file mode 100644
index 0000000..a7c13f9
--- /dev/null
+++ b/Source/CSharpEssentials.Tests/NullCheckToNullConditional/NullCheckToNullConditionalCodeFixTests.cs
@@ -0,0 +1,105 @@
+using RoslynNUnitLight;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.CodeAnalysis.CodeRefactorings;
+using Microsoft.CodeAnalysis;
+using CSharpEssentials.NullCheckToNullConditional;
+using NUnit.Framework;
+using Microsoft.CodeAnalysis.CodeFixes;
+using Microsoft.CodeAnalysis.Text;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+
+namespace CSharpEssentials.Tests.NullCheckToNullConditional
+{
+ class NullCheckToNullConditionalCodeFixTests : CodeFixTestFixture
+ {
+ protected override string LanguageName => LanguageNames.CSharp;
+ protected override CodeFixProvider CreateProvider() => new NullCheckToNullConditionalCodeFix();
+
+ const string codeBase = @"
+class SuperAwesomeCode
+{
+ interface A { B b(); }
+ interface B { C c { get; } }
+ interface C { DHolder this[int i] { get; } }
+ interface DHolder { D d { get; } }
+ interface D { void m(object o1, object o2); MyStruct? myStruct{ get; } }
+ struct MyStruct { int this[int i] => i; }
+ void M(A a, B b, C c, DHolder dHolder, D d, object blah, dynamic dyn)
+ {
+ <<<<>>>>
+ }
+}
+";
+ string InsertCode(string s) => codeBase.Replace("<<<<>>>>", s);
+
+ [Test]
+ public void SimpleTest()
+ {
+ var markupCode = InsertCode("[|if (null != d ) d.m(blah, blah);|]");
+ var expected = InsertCode("d?.m(blah, blah);");
+ TestCodeFix(markupCode, expected, DiagnosticDescriptors.UseNullConditionalMemberAccess);
+ }
+
+ [Test]
+ public void TestPropertyAccessor()
+ {
+ var markupCode = InsertCode("[|if (b != null) b.c.ToString();|]");
+ var expected = InsertCode("b?.c.ToString();");
+ TestCodeFix(markupCode, expected, DiagnosticDescriptors.UseNullConditionalMemberAccess);
+ }
+
+ [Test]
+ public void TestIndexer()
+ {
+ var markupCode = InsertCode("[|if (b.c != null) b.c[0].ToString();|]");
+ var expected = InsertCode("b.c?[0].ToString();");
+ TestCodeFix(markupCode, expected, DiagnosticDescriptors.UseNullConditionalMemberAccess);
+ }
+
+ [Test]
+ public void TestNullableValueType()
+ {
+ var markupCode = InsertCode("[|if (d.myStruct != null) d.myStruct.Value[0].CompareTo(42).ToString();|]");
+ var expected = InsertCode("d.myStruct?[0].CompareTo(42).ToString();");
+ TestCodeFix(markupCode, expected, DiagnosticDescriptors.UseNullConditionalMemberAccess);
+ }
+
+ [Test]
+ public void TestDynamicExpression()
+ {
+ var markupCode = InsertCode("[|if (dyn.x.y.z != null) dyn.x.y.z.m();|]");
+ var expected = InsertCode("dyn.x.y.z?.m();");
+ TestCodeFix(markupCode, expected, DiagnosticDescriptors.UseNullConditionalMemberAccess);
+ }
+
+ [Test]
+ public void TestBlockStatement()
+ {
+ var markupCode = InsertCode("[|if (a.b() != null) { a.b().c[0].ToString(); }|]");
+ var expected = InsertCode("a.b()?.c[0].ToString();");
+ TestCodeFix(markupCode, expected, DiagnosticDescriptors.UseNullConditionalMemberAccess);
+ }
+
+ [Test]
+ public void TestInvocationStartsWith()
+ {
+ var code = InsertCode("[|if (a != null) a.b().c[1].d.m(blah, blah);|]");
+ var expeced = InsertCode("a?.b().c[1].d.m(blah, blah);");
+
+ Document doc;
+ TextSpan span;
+ TestHelpers.TryGetDocumentAndSpanFromMarkup(code, LanguageNames.CSharp, out doc, out span);
+ var root = doc.GetSyntaxRootAsync().Result;
+ var ifStatement = root.FindNode(span) as IfStatementSyntax;
+ var exp = (ifStatement.Condition as BinaryExpressionSyntax).Left;
+ var chain = (ifStatement.Statement as ExpressionStatementSyntax).Expression;
+ ExpressionSyntax _;
+ Assert.True(NullCheckToNullConditionalCodeFix.MemberAccessChainExpressionStartsWith(chain, exp, out _));
+
+ }
+ }
+}
diff --git a/Source/CSharpEssentials/CSharpEssentials.csproj b/Source/CSharpEssentials/CSharpEssentials.csproj
index 112ec4b..a94b8ac 100644
--- a/Source/CSharpEssentials/CSharpEssentials.csproj
+++ b/Source/CSharpEssentials/CSharpEssentials.csproj
@@ -11,6 +11,8 @@
+
+
diff --git a/Source/CSharpEssentials/DiagnosticDescriptors.cs b/Source/CSharpEssentials/DiagnosticDescriptors.cs
index b3dd36f..f6b51a1 100644
--- a/Source/CSharpEssentials/DiagnosticDescriptors.cs
+++ b/Source/CSharpEssentials/DiagnosticDescriptors.cs
@@ -37,5 +37,22 @@ public static class DiagnosticDescriptors
category: DiagnosticCategories.Language,
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor UseNullConditionalMemberAccess = new DiagnosticDescriptor(
+ id: DiagnosticIds.UseNullConditional,
+ title: "Replace null-check if statement with null-conditional member access",
+ messageFormat: "Consider replacing the null-check if statement with null-conditional member access",
+ category: DiagnosticCategories.Language,
+ defaultSeverity: DiagnosticSeverity.Info,
+ isEnabledByDefault: true);
+
+ public static readonly DiagnosticDescriptor UseNullConditionalMemberAccessFadedToken = new DiagnosticDescriptor(
+ id: "UseNullConditionalMemberAccessFadedToken",
+ title: UseNullConditionalMemberAccess.Title,
+ messageFormat: UseNullConditionalMemberAccess.MessageFormat,
+ category: DiagnosticCategories.Language,
+ defaultSeverity: DiagnosticSeverity.Hidden,
+ isEnabledByDefault: true,
+ customTags: new[] { WellKnownDiagnosticTags.Unnecessary });
}
}
diff --git a/Source/CSharpEssentials/DiagnosticIds.cs b/Source/CSharpEssentials/DiagnosticIds.cs
index b0dd535..8a6f08b 100644
--- a/Source/CSharpEssentials/DiagnosticIds.cs
+++ b/Source/CSharpEssentials/DiagnosticIds.cs
@@ -5,5 +5,6 @@ internal static class DiagnosticIds
public const string UseNameOf = "CSE0001";
public const string UseGetterOnlyAutoProperty = "CSE0002";
public const string UseExpressionBodiedMember = "CSE0003";
+ public const string UseNullConditional = "CSE0004";
}
}
diff --git a/Source/CSharpEssentials/NullCheckToNullConditional/NullCheckToNullConditionalAnalyzer.cs b/Source/CSharpEssentials/NullCheckToNullConditional/NullCheckToNullConditionalAnalyzer.cs
new file mode 100644
index 0000000..d2c9887
--- /dev/null
+++ b/Source/CSharpEssentials/NullCheckToNullConditional/NullCheckToNullConditionalAnalyzer.cs
@@ -0,0 +1,59 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeActions;
+using Microsoft.CodeAnalysis.CodeRefactorings;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Formatting;
+using System.Collections.Immutable;
+using Microsoft.CodeAnalysis.Simplification;
+using Microsoft.CodeAnalysis.Diagnostics;
+using Microsoft.CodeAnalysis.Text;
+
+namespace CSharpEssentials.NullCheckToNullConditional
+{
+ [DiagnosticAnalyzer(LanguageNames.CSharp)]
+ public class NullCheckToNullConditionalAnalyzer : DiagnosticAnalyzer
+ {
+ public override ImmutableArray SupportedDiagnostics => ImmutableArray.Create(DiagnosticDescriptors.UseNullConditionalMemberAccessFadedToken, DiagnosticDescriptors.UseNullConditionalMemberAccess);
+
+ private static async void AnalyzeThat(SyntaxNodeAnalysisContext context)
+ {
+ var ifStatement = context.Node.FindNode(context.Node.Span, getInnermostNodeForTie: true)?.FirstAncestorOrSelf();
+ try
+ {
+ if (await NullCheckToNullConditionalCodeFix.GetCodeFixAsync(() => Task.FromResult(context.SemanticModel), ifStatement) != null)
+ {
+ if (ifStatement.SyntaxTree.IsGeneratedCode(context.CancellationToken))
+ return;
+ var fadeoutLocations = ImmutableArray.CreateBuilder();
+ fadeoutLocations.Add(Location.Create(context.Node.SyntaxTree, TextSpan.FromBounds(ifStatement.IfKeyword.SpanStart, ifStatement.Statement.SpanStart)));
+
+ var statementBlock = ifStatement.Statement as BlockSyntax;
+ if (statementBlock != null)
+ {
+ fadeoutLocations.Add(Location.Create(context.Node.SyntaxTree, (statementBlock.OpenBraceToken.Span)));
+ fadeoutLocations.Add(Location.Create(context.Node.SyntaxTree, (statementBlock.CloseBraceToken.Span)));
+ }
+ foreach (var location in fadeoutLocations)
+ {
+ context.ReportDiagnostic(Diagnostic.Create(DiagnosticDescriptors.UseNullConditionalMemberAccessFadedToken, location));
+ }
+ context.ReportDiagnostic(Diagnostic.Create(DiagnosticDescriptors.UseNullConditionalMemberAccess,
+ Location.Create(context.Node.SyntaxTree, ifStatement.Span)));
+ }
+ }
+ catch (OperationCanceledException ex) when (ex.CancellationToken == context.CancellationToken)
+ {
+ // we should ignore cancellation exceptions, instead of blowing up the universe!
+ }
+ }
+
+ public override void Initialize(AnalysisContext context)
+ {
+ context.RegisterSyntaxNodeAction(AnalyzeThat, ImmutableArray.Create(SyntaxKind.IfStatement));
+ }
+ }
+}
diff --git a/Source/CSharpEssentials/NullCheckToNullConditional/NullCheckToNullConditionalCodeFix.cs b/Source/CSharpEssentials/NullCheckToNullConditional/NullCheckToNullConditionalCodeFix.cs
new file mode 100644
index 0000000..b2fc8e4
--- /dev/null
+++ b/Source/CSharpEssentials/NullCheckToNullConditional/NullCheckToNullConditionalCodeFix.cs
@@ -0,0 +1,189 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using System.Collections.Immutable;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CodeActions;
+using Microsoft.CodeAnalysis.CodeFixes;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using Microsoft.CodeAnalysis.Text;
+
+namespace CSharpEssentials.NullCheckToNullConditional
+{
+ [ExportCodeFixProvider(LanguageNames.CSharp, Name = "Use null-conditional operator")]
+ class NullCheckToNullConditionalCodeFix : CodeFixProvider
+ {
+ public override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create(DiagnosticIds.UseNullConditional);
+ public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;
+
+ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
+ {
+ if (context.Diagnostics.Length == 0) return;
+
+ var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken);
+ foreach (var diag in context.Diagnostics)
+ {
+ var ifStatement = diag.Location.SourceTree.GetRoot().FindNode(diag.Location.SourceSpan, getInnermostNodeForTie: true).FirstAncestorOrSelf();
+ context.RegisterCodeFix(CodeAction.Create("Replace null-check 'if' with null-conditional member access", async ct =>
+ {
+ var expressionToReplace = await GetCodeFixAsync(() => context.Document.GetSemanticModelAsync(ct), ifStatement);
+ var newSyntax = root.ReplaceNode(ifStatement, expressionToReplace);
+ return context.Document.WithSyntaxRoot(newSyntax);
+ }), context.Diagnostics);
+ }
+ }
+
+ internal static async Task GetCodeFixAsync(Func> semanticModelLazy, IfStatementSyntax ifStatement)
+ {
+ if (ifStatement != null && ifStatement.Else == null)
+ {
+ var binaryExpression = ifStatement.Condition as BinaryExpressionSyntax;
+ if (binaryExpression?.IsKind(SyntaxKind.NotEqualsExpression) == true)
+ {
+ ExpressionSyntax nullableExpression = null;
+ if (binaryExpression.Left.IsKind(SyntaxKind.NullLiteralExpression))
+ nullableExpression = binaryExpression.Right;
+ else if (binaryExpression.Right.IsKind(SyntaxKind.NullLiteralExpression))
+ nullableExpression = binaryExpression.Left;
+
+ if (nullableExpression != null)
+ {
+ var block = ifStatement.Statement as BlockSyntax;
+ var bodyExpressionStatement = (block?.Statements.Count == 1) ?
+ block.Statements[0] as ExpressionStatementSyntax : ifStatement.Statement as ExpressionStatementSyntax;
+
+ if (bodyExpressionStatement != null)
+ {
+ var invocation = bodyExpressionStatement.Expression as InvocationExpressionSyntax;
+ ExpressionSyntax chainStart;
+ if (invocation != null && MemberAccessChainExpressionStartsWith(invocation, nullableExpression, out chainStart))
+ {
+ var semanticModel = await semanticModelLazy();
+ var referenceType = semanticModel.GetTypeInfo(nullableExpression).Type?.IsReferenceType;
+ if (referenceType == null) return null;
+ if (referenceType == false)
+ {
+ var chainStartParentMemberAccess = chainStart.Parent as MemberAccessExpressionSyntax;
+ if (chainStartParentMemberAccess != null)
+ {
+ if (chainStartParentMemberAccess.Name.Identifier.ValueText == "Value")
+ {
+ var InvocationValueRemoved = invocation.ReplaceNode(chainStartParentMemberAccess, nullableExpression);
+ bodyExpressionStatement = bodyExpressionStatement.ReplaceNode(invocation, InvocationValueRemoved);
+ ExpressionSyntax newChainStart;
+ if (!MemberAccessChainExpressionStartsWith(bodyExpressionStatement.Expression, chainStart, out newChainStart))
+ return null;
+ chainStart = newChainStart;
+ }
+ }
+ }
+ var nullableExpressionMemberCall = GetPropertyIndexerMethodCallExpression(chainStart);
+ var nullableExpressionNullConditionalMemberCall = ConvertToNullConditionalAccess(nullableExpressionMemberCall);
+ if (nullableExpressionNullConditionalMemberCall != null)
+ {
+ return bodyExpressionStatement
+ .ReplaceNode(nullableExpressionMemberCall, nullableExpressionNullConditionalMemberCall)
+ .WithTriviaFrom(ifStatement);
+ }
+ }
+ }
+ }
+ }
+ }
+ return null;
+ }
+
+ ///
+ /// returns the method, indexer, or property access on this expression, if there is any
+ ///
+ ///
+ ///
+ private static ExpressionSyntax GetPropertyIndexerMethodCallExpression(ExpressionSyntax exp)
+ {
+ if (exp.Parent is MemberAccessExpressionSyntax && exp.Parent.Parent is InvocationExpressionSyntax)
+ {
+ return exp.Parent.Parent as InvocationExpressionSyntax;
+ }
+ if (exp.Parent is MemberAccessExpressionSyntax)
+ {
+ return exp.Parent as MemberAccessExpressionSyntax;
+ }
+ if (exp.Parent is ElementAccessExpressionSyntax)
+ {
+ return exp.Parent as ElementAccessExpressionSyntax;
+ }
+ return null;
+ }
+ ///
+ /// converts a normal method, property, or indexer access syntax to the null-conditional version
+ ///
+ ///
+ /// the converted version of a method, property, or indexer access syntax;
+ /// or null if the argument is none of them.
+ ///
+ private static ExpressionSyntax ConvertToNullConditionalAccess(ExpressionSyntax exp)
+ {
+ if (exp is ConditionalAccessExpressionSyntax)
+ {
+ return exp;
+ }
+ else if (exp is InvocationExpressionSyntax)
+ {
+ var invocation = exp as InvocationExpressionSyntax;
+ var memberAccess = (invocation.Expression as MemberAccessExpressionSyntax);
+ if (memberAccess != null)
+ {
+ return ConditionalAccessExpression(memberAccess.Expression,
+ InvocationExpression(MemberBindingExpression(memberAccess.Name), invocation.ArgumentList));
+ }
+ else
+ {
+ return null;
+ }
+ }
+ else if (exp is ElementAccessExpressionSyntax)
+ {
+ var elementAccess = exp as ElementAccessExpressionSyntax;
+ return ConditionalAccessExpression(elementAccess.Expression,
+ ElementBindingExpression(elementAccess.ArgumentList));
+ }
+ else if (exp is MemberAccessExpressionSyntax)
+ {
+ var memberAccess = exp as MemberAccessExpressionSyntax;
+ return ConditionalAccessExpression(memberAccess.Expression,
+ MemberBindingExpression(memberAccess.Name));
+ }
+ else return null;
+ }
+
+ ///
+ /// determines whether 'beginning' appears at the start of 'memberAccessChain',
+ /// for example 'a.b(1)[3]' is considered to be at the start of 'a.b(1)[3].c.m()'
+ ///
+ /// if this method returns true, this parameter will hold the expression at the beginning of memberAccessChain that is equivalent to beginning
+ ///
+ public static bool MemberAccessChainExpressionStartsWith(ExpressionSyntax memberAccessChain, ExpressionSyntax beginning, out ExpressionSyntax atTheBeginning)
+ {
+ if (AreEquivalent(memberAccessChain, beginning, false))
+ {
+ atTheBeginning = memberAccessChain;
+ return true;
+ }
+ switch (memberAccessChain.Kind())
+ {
+ case SyntaxKind.InvocationExpression:
+ return MemberAccessChainExpressionStartsWith((memberAccessChain as InvocationExpressionSyntax).Expression, beginning, out atTheBeginning);
+ case SyntaxKind.SimpleMemberAccessExpression:
+ return MemberAccessChainExpressionStartsWith((memberAccessChain as MemberAccessExpressionSyntax).Expression, beginning, out atTheBeginning);
+ case SyntaxKind.ElementAccessExpression:
+ return MemberAccessChainExpressionStartsWith((memberAccessChain as ElementAccessExpressionSyntax).Expression, beginning, out atTheBeginning);
+ default:
+ atTheBeginning = null;
+ return false;
+ }
+ }
+ }
+}
diff --git a/readme.md b/readme.md
index f15e6e2..e8f5e53 100644
--- a/readme.md
+++ b/readme.md
@@ -5,7 +5,7 @@ refactorings that make it easy to work with C# 6 language features,
such as [nameof expressions](https://github.com/dotnet/roslyn/wiki/New-Language-Features-in-C%23-6#nameof-expressions),
[getter-only auto-properties](https://github.com/dotnet/roslyn/wiki/New-Language-Features-in-C%23-6#getter-only-auto-properties),
[expression-bodied members](https://github.com/dotnet/roslyn/wiki/New-Language-Features-in-C%23-6#expression-bodied-function-members),
-and [string interpolation](https://github.com/dotnet/roslyn/wiki/New-Language-Features-in-C%23-6#string-interpolation).
+ [string interpolation](https://github.com/dotnet/roslyn/wiki/New-Language-Features-in-C%23-6#string-interpolation), and [null-conditional operators](https://github.com/dotnet/roslyn/wiki/New-Language-Features-in-C%23-6#null-conditional-operators).
Supports Visual Studio 2015 ([link](https://visualstudiogallery.msdn.microsoft.com/a4445ad0-f97c-41f9-a148-eae225dcc8a5?SRC=Home))
@@ -47,3 +47,8 @@ call into an interpolated strings.

+### Use Null-Conditional Operators
+
+Identifies when invocations guarded with null-check if statements can be simplfied using null-conditional operators.
+
+
\ No newline at end of file