55using System . Collections . Generic ;
66using System . Collections . Immutable ;
77using System . Composition ;
8+ using System . Diagnostics ;
89using System . Diagnostics . CodeAnalysis ;
910using System . Linq ;
1011using System . Threading ;
@@ -67,54 +68,63 @@ protected override bool IsAsyncReturnType(ITypeSymbol type, KnownTaskTypes known
6768 => IsIAsyncEnumerableOrEnumerator ( type , knownTypes ) ||
6869 knownTypes . IsTaskLike ( type ) ;
6970
70- protected override SyntaxNode AddAsyncTokenAndFixReturnType (
71+ protected override SyntaxNode FixMethodSignature (
72+ bool addAsyncModifier ,
7173 bool keepVoid ,
7274 IMethodSymbol methodSymbol ,
7375 SyntaxNode node ,
74- KnownTaskTypes knownTypes ,
75- CancellationToken cancellationToken )
76+ KnownTaskTypes knownTypes )
7677 {
78+ // We currently fix signature without adding 'async' modifier
79+ // only for a partial definitions part of partial methods
80+ Debug . Assert ( addAsyncModifier || node is MethodDeclarationSyntax ) ;
81+
7782 return node switch
7883 {
79- MethodDeclarationSyntax method => FixMethod ( keepVoid , methodSymbol , method , knownTypes , cancellationToken ) ,
80- LocalFunctionStatementSyntax localFunction => FixLocalFunction ( keepVoid , methodSymbol , localFunction , knownTypes , cancellationToken ) ,
84+ MethodDeclarationSyntax method => FixMethod ( addAsyncModifier , keepVoid , methodSymbol , method , knownTypes ) ,
85+ LocalFunctionStatementSyntax localFunction => FixLocalFunction ( keepVoid , methodSymbol , localFunction , knownTypes ) ,
8186 AnonymousFunctionExpressionSyntax anonymous => FixAnonymousFunction ( anonymous ) ,
8287 _ => node ,
8388 } ;
8489 }
8590
8691 private static MethodDeclarationSyntax FixMethod (
92+ bool addAsyncModifier ,
8793 bool keepVoid ,
8894 IMethodSymbol methodSymbol ,
8995 MethodDeclarationSyntax method ,
90- KnownTaskTypes knownTypes ,
91- CancellationToken cancellationToken )
96+ KnownTaskTypes knownTypes )
9297 {
93- var ( newModifiers , newReturnType ) = AddAsyncModifierWithCorrectedTrivia (
94- method . Modifiers ,
95- FixMethodReturnType ( keepVoid , methodSymbol , method . ReturnType , knownTypes , cancellationToken ) ) ;
96- return method . WithReturnType ( newReturnType ) . WithModifiers ( newModifiers ) ;
98+ var fixedReturnType = FixMethodReturnType ( keepVoid , methodSymbol , method . ReturnType , knownTypes ) ;
99+
100+ if ( addAsyncModifier )
101+ {
102+ var ( newModifiers , newReturnType ) = AddAsyncModifierWithCorrectedTrivia ( method . Modifiers , fixedReturnType ) ;
103+ return method . WithReturnType ( newReturnType ) . WithModifiers ( newModifiers ) ;
104+ }
105+ else
106+ {
107+ return method . WithReturnType ( fixedReturnType ) ;
108+ }
97109 }
98110
99111 private static LocalFunctionStatementSyntax FixLocalFunction (
100112 bool keepVoid ,
101113 IMethodSymbol methodSymbol ,
102114 LocalFunctionStatementSyntax localFunction ,
103- KnownTaskTypes knownTypes ,
104- CancellationToken cancellationToken )
115+ KnownTaskTypes knownTypes )
105116 {
106117 var ( newModifiers , newReturnType ) = AddAsyncModifierWithCorrectedTrivia (
107118 localFunction . Modifiers ,
108- FixMethodReturnType ( keepVoid , methodSymbol , localFunction . ReturnType , knownTypes , cancellationToken ) ) ;
119+ FixMethodReturnType ( keepVoid , methodSymbol , localFunction . ReturnType , knownTypes ) ) ;
109120 return localFunction . WithReturnType ( newReturnType ) . WithModifiers ( newModifiers ) ;
110121 }
111122
112123 private static TypeSyntax FixMethodReturnType (
113124 bool keepVoid ,
114125 IMethodSymbol methodSymbol ,
115126 TypeSyntax returnTypeSyntax ,
116- KnownTaskTypes knownTypes ,
117- CancellationToken cancellationToken )
127+ KnownTaskTypes knownTypes )
118128 {
119129 var newReturnType = returnTypeSyntax . WithAdditionalAnnotations ( Formatter . Annotation ) ;
120130
@@ -128,13 +138,13 @@ private static TypeSyntax FixMethodReturnType(
128138 else
129139 {
130140 var returnType = methodSymbol . ReturnType ;
131- if ( IsIEnumerable ( returnType , knownTypes ) && IsIterator ( methodSymbol , cancellationToken ) )
141+ if ( IsIEnumerable ( returnType , knownTypes ) && methodSymbol . IsIterator )
132142 {
133143 newReturnType = knownTypes . IAsyncEnumerableOfTType is null
134144 ? MakeGenericType ( nameof ( IAsyncEnumerable < > ) , methodSymbol . ReturnType )
135145 : knownTypes . IAsyncEnumerableOfTType . Construct ( methodSymbol . ReturnType . GetTypeArguments ( ) [ 0 ] ) . GenerateTypeSyntax ( ) ;
136146 }
137- else if ( IsIEnumerator ( returnType , knownTypes ) && IsIterator ( methodSymbol , cancellationToken ) )
147+ else if ( IsIEnumerator ( returnType , knownTypes ) && methodSymbol . IsIterator )
138148 {
139149 newReturnType = knownTypes . IAsyncEnumeratorOfTType is null
140150 ? MakeGenericType ( nameof ( IAsyncEnumerator < > ) , methodSymbol . ReturnType )
@@ -164,9 +174,6 @@ static TypeSyntax MakeGenericType(string type, ITypeSymbol typeArgumentFrom)
164174 }
165175 }
166176
167- private static bool IsIterator ( IMethodSymbol method , CancellationToken cancellationToken )
168- => method . Locations . Any ( static ( loc , cancellationToken ) => loc . FindNode ( cancellationToken ) . ContainsYield ( ) , cancellationToken ) ;
169-
170177 private static bool IsIAsyncEnumerableOrEnumerator ( ITypeSymbol returnType , KnownTaskTypes knownTypes )
171178 => returnType . OriginalDefinition . Equals ( knownTypes . IAsyncEnumerableOfTType ) ||
172179 returnType . OriginalDefinition . Equals ( knownTypes . IAsyncEnumeratorOfTType ) ;
0 commit comments