Skip to content

Commit a8d4395

Browse files
Correctly change return type of partial method definition part when making method async (#79478)
2 parents 934c2a1 + dc2da5a commit a8d4395

File tree

4 files changed

+65
-36
lines changed

4 files changed

+65
-36
lines changed

src/Analyzers/CSharp/CodeFixes/MakeMethodAsynchronous/CSharpMakeMethodAsynchronousCodeFixProvider.cs

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Collections.Immutable;
77
using System.Composition;
8+
using System.Diagnostics;
89
using System.Diagnostics.CodeAnalysis;
910
using System.Linq;
1011
using System.Threading;
@@ -67,54 +68,63 @@ protected override bool IsAsyncReturnType(ITypeSymbol type, KnownTaskTypes known
6768
=> IsIAsyncEnumerableOrEnumerator(type, knownTypes) ||
6869
knownTypes.IsTaskLike(type);
6970

70-
protected override SyntaxNode AddAsyncTokenAndFixReturnType(
71+
protected override SyntaxNode FixMethodSignature(
72+
bool addAsyncModifier,
7173
bool keepVoid,
7274
IMethodSymbol methodSymbol,
7375
SyntaxNode node,
74-
KnownTaskTypes knownTypes,
75-
CancellationToken cancellationToken)
76+
KnownTaskTypes knownTypes)
7677
{
78+
// We currently fix signature without adding 'async' modifier
79+
// only for a partial definitions part of partial methods
80+
Debug.Assert(addAsyncModifier || node is MethodDeclarationSyntax);
81+
7782
return node switch
7883
{
79-
MethodDeclarationSyntax method => FixMethod(keepVoid, methodSymbol, method, knownTypes, cancellationToken),
80-
LocalFunctionStatementSyntax localFunction => FixLocalFunction(keepVoid, methodSymbol, localFunction, knownTypes, cancellationToken),
84+
MethodDeclarationSyntax method => FixMethod(addAsyncModifier, keepVoid, methodSymbol, method, knownTypes),
85+
LocalFunctionStatementSyntax localFunction => FixLocalFunction(keepVoid, methodSymbol, localFunction, knownTypes),
8186
AnonymousFunctionExpressionSyntax anonymous => FixAnonymousFunction(anonymous),
8287
_ => node,
8388
};
8489
}
8590

8691
private static MethodDeclarationSyntax FixMethod(
92+
bool addAsyncModifier,
8793
bool keepVoid,
8894
IMethodSymbol methodSymbol,
8995
MethodDeclarationSyntax method,
90-
KnownTaskTypes knownTypes,
91-
CancellationToken cancellationToken)
96+
KnownTaskTypes knownTypes)
9297
{
93-
var (newModifiers, newReturnType) = AddAsyncModifierWithCorrectedTrivia(
94-
method.Modifiers,
95-
FixMethodReturnType(keepVoid, methodSymbol, method.ReturnType, knownTypes, cancellationToken));
96-
return method.WithReturnType(newReturnType).WithModifiers(newModifiers);
98+
var fixedReturnType = FixMethodReturnType(keepVoid, methodSymbol, method.ReturnType, knownTypes);
99+
100+
if (addAsyncModifier)
101+
{
102+
var (newModifiers, newReturnType) = AddAsyncModifierWithCorrectedTrivia(method.Modifiers, fixedReturnType);
103+
return method.WithReturnType(newReturnType).WithModifiers(newModifiers);
104+
}
105+
else
106+
{
107+
return method.WithReturnType(fixedReturnType);
108+
}
97109
}
98110

99111
private static LocalFunctionStatementSyntax FixLocalFunction(
100112
bool keepVoid,
101113
IMethodSymbol methodSymbol,
102114
LocalFunctionStatementSyntax localFunction,
103-
KnownTaskTypes knownTypes,
104-
CancellationToken cancellationToken)
115+
KnownTaskTypes knownTypes)
105116
{
106117
var (newModifiers, newReturnType) = AddAsyncModifierWithCorrectedTrivia(
107118
localFunction.Modifiers,
108-
FixMethodReturnType(keepVoid, methodSymbol, localFunction.ReturnType, knownTypes, cancellationToken));
119+
FixMethodReturnType(keepVoid, methodSymbol, localFunction.ReturnType, knownTypes));
109120
return localFunction.WithReturnType(newReturnType).WithModifiers(newModifiers);
110121
}
111122

112123
private static TypeSyntax FixMethodReturnType(
113124
bool keepVoid,
114125
IMethodSymbol methodSymbol,
115126
TypeSyntax returnTypeSyntax,
116-
KnownTaskTypes knownTypes,
117-
CancellationToken cancellationToken)
127+
KnownTaskTypes knownTypes)
118128
{
119129
var newReturnType = returnTypeSyntax.WithAdditionalAnnotations(Formatter.Annotation);
120130

@@ -128,13 +138,13 @@ private static TypeSyntax FixMethodReturnType(
128138
else
129139
{
130140
var returnType = methodSymbol.ReturnType;
131-
if (IsIEnumerable(returnType, knownTypes) && IsIterator(methodSymbol, cancellationToken))
141+
if (IsIEnumerable(returnType, knownTypes) && methodSymbol.IsIterator)
132142
{
133143
newReturnType = knownTypes.IAsyncEnumerableOfTType is null
134144
? MakeGenericType(nameof(IAsyncEnumerable<>), methodSymbol.ReturnType)
135145
: knownTypes.IAsyncEnumerableOfTType.Construct(methodSymbol.ReturnType.GetTypeArguments()[0]).GenerateTypeSyntax();
136146
}
137-
else if (IsIEnumerator(returnType, knownTypes) && IsIterator(methodSymbol, cancellationToken))
147+
else if (IsIEnumerator(returnType, knownTypes) && methodSymbol.IsIterator)
138148
{
139149
newReturnType = knownTypes.IAsyncEnumeratorOfTType is null
140150
? MakeGenericType(nameof(IAsyncEnumerator<>), methodSymbol.ReturnType)
@@ -164,9 +174,6 @@ static TypeSyntax MakeGenericType(string type, ITypeSymbol typeArgumentFrom)
164174
}
165175
}
166176

167-
private static bool IsIterator(IMethodSymbol method, CancellationToken cancellationToken)
168-
=> method.Locations.Any(static (loc, cancellationToken) => loc.FindNode(cancellationToken).ContainsYield(), cancellationToken);
169-
170177
private static bool IsIAsyncEnumerableOrEnumerator(ITypeSymbol returnType, KnownTaskTypes knownTypes)
171178
=> returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumerableOfTType) ||
172179
returnType.OriginalDefinition.Equals(knownTypes.IAsyncEnumeratorOfTType);

src/Analyzers/CSharp/Tests/MakeMethodAsynchronous/MakeMethodAsynchronousTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,7 +1406,7 @@ partial void M()
14061406
14071407
public partial class C
14081408
{
1409-
partial void MAsync();
1409+
partial Task MAsync();
14101410
}
14111411
14121412
public partial class C
@@ -1440,7 +1440,7 @@ public partial void M()
14401440
14411441
public partial class C
14421442
{
1443-
public partial void MAsync();
1443+
public partial Task MAsync();
14441444
}
14451445
14461446
public partial class C

src/Analyzers/Core/CodeFixes/MakeMethodAsynchronous/AbstractMakeMethodAsynchronousCodeFixProvider.cs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Threading.Tasks;
88
using Microsoft.CodeAnalysis.CodeActions;
99
using Microsoft.CodeAnalysis.CodeFixes;
10+
using Microsoft.CodeAnalysis.Editing;
1011
using Microsoft.CodeAnalysis.LanguageService;
1112
using Microsoft.CodeAnalysis.Rename;
1213
using Microsoft.CodeAnalysis.Shared.Extensions;
@@ -24,8 +25,12 @@ internal abstract partial class AbstractMakeMethodAsynchronousCodeFixProvider :
2425

2526
protected abstract bool IsAsyncReturnType(ITypeSymbol type, KnownTaskTypes knownTypes);
2627

27-
protected abstract SyntaxNode AddAsyncTokenAndFixReturnType(
28-
bool keepVoid, IMethodSymbol methodSymbol, SyntaxNode node, KnownTaskTypes knownTypes, CancellationToken cancellationToken);
28+
protected abstract SyntaxNode FixMethodSignature(
29+
bool addAsyncModifier,
30+
bool keepVoid,
31+
IMethodSymbol methodSymbol,
32+
SyntaxNode node,
33+
KnownTaskTypes knownTypes);
2934

3035
public override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer;
3136

@@ -119,7 +124,7 @@ private async Task<Solution> FixNodeAsync(
119124

120125
return NeedsRename()
121126
? await RenameThenAddAsyncTokenAsync(keepVoid, document, node, methodSymbol, knownTypes, cancellationToken).ConfigureAwait(false)
122-
: await AddAsyncTokenAsync(keepVoid, document, methodSymbol, knownTypes, node, cancellationToken).ConfigureAwait(false);
127+
: await FixRelatedSignaturesAsync(keepVoid, document, methodSymbol, knownTypes, node, cancellationToken).ConfigureAwait(false);
123128

124129
bool NeedsRename()
125130
{
@@ -174,26 +179,39 @@ private async Task<Solution> RenameThenAddAsyncTokenAsync(
174179
{
175180
var semanticModel = await newDocument.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
176181
var newMethod = (IMethodSymbol)semanticModel.GetRequiredDeclaredSymbol(newNode, cancellationToken);
177-
return await AddAsyncTokenAsync(keepVoid, newDocument, newMethod, knownTypes, newNode, cancellationToken).ConfigureAwait(false);
182+
return await FixRelatedSignaturesAsync(keepVoid, newDocument, newMethod, knownTypes, newNode, cancellationToken).ConfigureAwait(false);
178183
}
179184

180185
return newSolution;
181186
}
182187

183-
private async Task<Solution> AddAsyncTokenAsync(
188+
private async Task<Solution> FixRelatedSignaturesAsync(
184189
bool keepVoid,
185190
Document document,
186191
IMethodSymbol methodSymbol,
187192
KnownTaskTypes knownTypes,
188193
SyntaxNode node,
189194
CancellationToken cancellationToken)
190195
{
191-
var newNode = AddAsyncTokenAndFixReturnType(keepVoid, methodSymbol, node, knownTypes, cancellationToken);
196+
var newNode = FixMethodSignature(addAsyncModifier: true, keepVoid, methodSymbol, node, knownTypes);
197+
198+
var solution = document.Project.Solution;
199+
var solutionEditor = new SolutionEditor(solution);
200+
var mainDocumentEditor = await solutionEditor.GetDocumentEditorAsync(document.Id, cancellationToken).ConfigureAwait(false);
192201

193-
var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
194-
var newRoot = root.ReplaceNode(node, newNode);
202+
mainDocumentEditor.ReplaceNode(node, newNode);
203+
204+
if (!keepVoid && methodSymbol.PartialDefinitionPart is { Locations: [{ } partialDefinitionLocation] })
205+
{
206+
var partialDefinitionNode = partialDefinitionLocation.FindNode(cancellationToken);
207+
var fixedPartialDefinitionNode = FixMethodSignature(addAsyncModifier: false, keepVoid, methodSymbol, partialDefinitionNode, knownTypes);
208+
209+
var partialDefinitionDocument = solution.GetDocument(partialDefinitionNode.SyntaxTree);
210+
Contract.ThrowIfNull(partialDefinitionDocument);
211+
var partialDefinitionDocumentEditor = await solutionEditor.GetDocumentEditorAsync(partialDefinitionDocument.Id, cancellationToken).ConfigureAwait(false);
212+
partialDefinitionDocumentEditor.ReplaceNode(partialDefinitionNode, fixedPartialDefinitionNode);
213+
}
195214

196-
var newDocument = document.WithSyntaxRoot(newRoot);
197-
return newDocument.Project.Solution;
215+
return solutionEditor.GetChangedSolution();
198216
}
199217
}

src/Analyzers/VisualBasic/CodeFixes/MakeMethodAsynchronous/VisualBasicMakeMethodAsynchronousCodeFixProvider.vb

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,16 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.MakeMethodAsynchronous
5757
Return knownTypes.IsTaskLike(type)
5858
End Function
5959

60-
Protected Overrides Function AddAsyncTokenAndFixReturnType(
60+
Protected Overrides Function FixMethodSignature(
61+
addAsyncModifier As Boolean,
6162
keepVoid As Boolean,
6263
methodSymbolOpt As IMethodSymbol,
6364
node As SyntaxNode,
64-
knownTypes As KnownTaskTypes,
65-
cancellationToken As CancellationToken) As SyntaxNode
65+
knownTypes As KnownTaskTypes) As SyntaxNode
66+
67+
' This flag can only be false when updating partial definition method signature.
68+
' Since partial methods cannot be async in VB, it cannot be false here
69+
Debug.Assert(addAsyncModifier)
6670

6771
If node.IsKind(SyntaxKind.SingleLineSubLambdaExpression) OrElse
6872
node.IsKind(SyntaxKind.SingleLineFunctionLambdaExpression) Then

0 commit comments

Comments
 (0)