diff --git a/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift b/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift index afa6922..ee31aad 100644 --- a/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift +++ b/Sources/SpyableMacro/Factories/FunctionImplementationFactory.swift @@ -87,13 +87,13 @@ struct FunctionImplementationFactory { ) } - #if canImport(SwiftSyntax600) - let throwsSpecifier = protocolFunctionDeclaration.signature.effectSpecifiers?.throwsClause? - .throwsSpecifier - #else - let throwsSpecifier = protocolFunctionDeclaration.signature.effectSpecifiers? - .throwsSpecifier - #endif +#if canImport(SwiftSyntax600) + let throwsSpecifier = protocolFunctionDeclaration.signature.effectSpecifiers?.throwsClause? + .throwsSpecifier +#else + let throwsSpecifier = protocolFunctionDeclaration.signature.effectSpecifiers? + .throwsSpecifier +#endif if throwsSpecifier != nil { throwableErrorFactory.throwErrorExpression(variablePrefix: variablePrefix) @@ -123,6 +123,9 @@ struct FunctionImplementationFactory { // due to the bug: https://github.com/apple/swift-syntax/issues/2352 IfExprSyntax( conditions: ConditionElementListSyntax { +#if canImport(SwiftSyntax600) + typedThrowsCondition(protocolFunctionDeclaration: protocolFunctionDeclaration) // if working with typed throws, they are only supported from certain platforms on +#endif ConditionElementSyntax( condition: .expression( ExprSyntax( @@ -154,6 +157,68 @@ struct FunctionImplementationFactory { } ) } + +#if canImport(SwiftSyntax600) + private func typedThrowsCondition(protocolFunctionDeclaration: FunctionDeclSyntax) -> [ConditionElementSyntax] { + guard protocolFunctionDeclaration.signature.effectSpecifiers?.throwsClause?.type != nil else { + return [] + } + return [ + ConditionElementSyntax( + condition: .availability( + AvailabilityConditionSyntax( + availabilityKeyword: .poundAvailableToken(), + availabilityArguments: AvailabilityArgumentListSyntax { + AvailabilityArgumentSyntax( + argument: .availabilityVersionRestriction( + PlatformVersionSyntax( + platform: .identifier("iOS 18.0.0") // iOS 18+ + ) + ) + ) + AvailabilityArgumentSyntax( + argument: .availabilityVersionRestriction( + PlatformVersionSyntax( + platform: .identifier("macOS 15.0.0") // macOS 15+ + ) + ) + ) + AvailabilityArgumentSyntax( + argument: .availabilityVersionRestriction( + PlatformVersionSyntax( + platform: .identifier("tvOS 18.0.0") // tvOS 18+ + ) + ) + ) + AvailabilityArgumentSyntax( + argument: .availabilityVersionRestriction( + PlatformVersionSyntax( + platform: .identifier("watchOS 11.0.0") // watchOS 11+ + ) + ) + ) + AvailabilityArgumentSyntax( + argument: .availabilityVersionRestriction( + PlatformVersionSyntax( + platform: .identifier("macCatalyst 18.0.0") // macCatalyst 18+ + ) + ) + ) + AvailabilityArgumentSyntax( + argument: .availabilityVersionRestriction( + PlatformVersionSyntax( + platform: .identifier("*") + ) + ) + ) + } + ) + ) + ) + ] + } +#endif + } extension DeclModifierListSyntax { diff --git a/Sources/SpyableMacro/Factories/SpyFactory.swift b/Sources/SpyableMacro/Factories/SpyFactory.swift index 4634c94..dbd6429 100644 --- a/Sources/SpyableMacro/Factories/SpyFactory.swift +++ b/Sources/SpyableMacro/Factories/SpyFactory.swift @@ -158,15 +158,20 @@ struct SpyFactory { ) } - #if canImport(SwiftSyntax600) - let throwsSpecifier = functionDeclaration.signature.effectSpecifiers?.throwsClause? - .throwsSpecifier - #else - let throwsSpecifier = functionDeclaration.signature.effectSpecifiers?.throwsSpecifier - #endif +#if canImport(SwiftSyntax600) + let throwsSpecifier = functionDeclaration.signature.effectSpecifiers?.throwsClause? + .throwsSpecifier + let throwsType = functionDeclaration.signature.effectSpecifiers?.throwsClause?.type +#else + let throwsSpecifier = functionDeclaration.signature.effectSpecifiers?.throwsSpecifier + // this should lead to the legacy behaviour (e.g.) + // var fooThrowableError: (any Error)? // Any Error because throwsType == nil + // func foo(_ added: ((text: String) -> Void)?) throws(ExampleError) -> (() -> Int)? // function signature with typed error + let throwsType: TypeSyntax? = nil +#endif if throwsSpecifier != nil { - try throwableErrorFactory.variableDeclaration(variablePrefix: variablePrefix) + try throwableErrorFactory.variableDeclaration(variablePrefix: variablePrefix, typeSpecifier: throwsType?.description) } if let returnType = functionDeclaration.signature.returnClause?.type { @@ -251,7 +256,7 @@ extension SyntaxProtocol { fileprivate var removingLeadingSpaces: Self { with( \.leadingTrivia, - Trivia( + Trivia( pieces: leadingTrivia .filter { @@ -261,7 +266,7 @@ extension SyntaxProtocol { true } } - ) + ) ) } } diff --git a/Sources/SpyableMacro/Factories/ThrowableErrorFactory.swift b/Sources/SpyableMacro/Factories/ThrowableErrorFactory.swift index 770c6d2..26bc3c7 100644 --- a/Sources/SpyableMacro/Factories/ThrowableErrorFactory.swift +++ b/Sources/SpyableMacro/Factories/ThrowableErrorFactory.swift @@ -29,12 +29,12 @@ import SwiftSyntaxBuilder /// your tests. You can use it to simulate different scenarios and verify that your code handles /// errors correctly. struct ThrowableErrorFactory { - func variableDeclaration(variablePrefix: String) throws -> VariableDeclSyntax { - try VariableDeclSyntax( - """ - var \(variableIdentifier(variablePrefix: variablePrefix)): (any Error)? - """ - ) + func variableDeclaration(variablePrefix: String, typeSpecifier: String? = nil) throws -> VariableDeclSyntax { + if let typeSpecifier { + return try typedVariableDeclaration(variablePrefix: variablePrefix, typeSpecifier: typeSpecifier) + } else { + return try untypedVariableDeclaration(variablePrefix: variablePrefix) + } } func throwErrorExpression(variablePrefix: String) -> ExprSyntax { @@ -47,6 +47,22 @@ struct ThrowableErrorFactory { ) } + private func untypedVariableDeclaration(variablePrefix: String) throws -> VariableDeclSyntax { + return try VariableDeclSyntax( + """ + var \(variableIdentifier(variablePrefix: variablePrefix))\(TokenSyntax.colonToken()) \(TokenSyntax.leftParenToken())any Error\(TokenSyntax.rightParenToken())\(TokenSyntax.postfixQuestionMarkToken()) + """ + ) + } + + private func typedVariableDeclaration(variablePrefix: String, typeSpecifier: String) throws -> VariableDeclSyntax { + try VariableDeclSyntax( + """ + var \(variableIdentifier(variablePrefix: variablePrefix))\(TokenSyntax.colonToken()) \(TokenSyntax.identifier(typeSpecifier))\(TokenSyntax.postfixQuestionMarkToken()) + """ + ) + } + private func variableIdentifier(variablePrefix: String) -> TokenSyntax { TokenSyntax.identifier(variablePrefix + "ThrowableError") } diff --git a/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift b/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift index 5031f90..3e8c9df 100644 --- a/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_ClosureFactory.swift @@ -39,6 +39,16 @@ final class UT_ClosureFactory: XCTestCase { ) } +#if canImport(SwiftSyntax600) + func testVariableDeclarationThrowsTyped() throws { + try assertProtocolFunction( + withFunctionDeclaration: "func _ignore_() throws(ExampleError)", + prefixForVariable: "_prefix_", + expectingVariableDeclaration: "var _prefix_Closure: (() throws(ExampleError) -> Void)?" + ) + } +#endif + func testVariableDeclarationReturnValue() throws { try assertProtocolFunction( withFunctionDeclaration: "func _ignore_() -> Data", diff --git a/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift b/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift index 271a754..3724efc 100644 --- a/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_FunctionImplementationFactory.swift @@ -95,6 +95,32 @@ final class UT_FunctionImplementationFactory: XCTestCase { ) } +#if canImport(SwiftSyntax600) + func testDeclarationReturnValueAsyncThrowsTyped() throws { + try assertProtocolFunction( + withFunctionDeclaration: """ + func foo(_ bar: String) async throws(ExampleError) -> (text: String, tuple: (count: Int?, Date)) + """, + prefixForVariable: "_prefix_", + expectingFunctionDeclaration: """ + func foo(_ bar: String) async throws(ExampleError) -> (text: String, tuple: (count: Int?, Date)) { + _prefix_CallsCount += 1 + _prefix_ReceivedBar = (bar) + _prefix_ReceivedInvocations.append((bar)) + if let _prefix_ThrowableError { + throw _prefix_ThrowableError + } + if #available(iOS 18.0.0, macOS 15.0.0, tvOS 18.0.0, watchOS 11.0.0, macCatalyst 18.0.0, *), _prefix_Closure != nil { + return try await _prefix_Closure!(bar) + } else { + return _prefix_ReturnValue + } + } + """ + ) + } +#endif + func testDeclarationWithMutatingKeyword() throws { try assertProtocolFunction( withFunctionDeclaration: "mutating func foo()", diff --git a/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift b/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift index 5b3aaf9..9ccfeb8 100644 --- a/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_SpyFactory.swift @@ -434,6 +434,46 @@ final class UT_SpyFactory: XCTestCase { ) } +#if canImport(SwiftSyntax600) + func testDeclarationThrowsTyped() throws { + try assertProtocol( + withDeclaration: """ + protocol ServiceProtocol { + func foo(_ added: ((text: String) -> Void)?) throws(ExampleError) -> (() -> Int)? + } + """, + expectingClassDeclaration: """ + class ServiceProtocolSpy: ServiceProtocol, @unchecked Sendable { + init() { + } + var fooCallsCount = 0 + var fooCalled: Bool { + return fooCallsCount > 0 + } + var fooReceivedAdded: ((text: String) -> Void)? + var fooReceivedInvocations: [((text: String) -> Void)?] = [] + var fooThrowableError: ExampleError? + var fooReturnValue: (() -> Int)? + var fooClosure: ((((text: String) -> Void)?) throws(ExampleError) -> (() -> Int)?)? + func foo(_ added: ((text: String) -> Void)?) throws(ExampleError) -> (() -> Int)? { + fooCallsCount += 1 + fooReceivedAdded = (added) + fooReceivedInvocations.append((added)) + if let fooThrowableError { + throw fooThrowableError + } + if #available(iOS 18.0.0, macOS 15.0.0, tvOS 18.0.0, watchOS 11.0.0, macCatalyst 18.0.0, *), fooClosure != nil { + return try fooClosure!(added) + } else { + return fooReturnValue + } + } + } + """ + ) + } +#endif + func testDeclarationReturnsExistential() throws { try assertProtocol( withDeclaration: """ diff --git a/Tests/SpyableMacroTests/Factories/UT_ThrowableErrorFactory.swift b/Tests/SpyableMacroTests/Factories/UT_ThrowableErrorFactory.swift index 0cb9439..c56b570 100644 --- a/Tests/SpyableMacroTests/Factories/UT_ThrowableErrorFactory.swift +++ b/Tests/SpyableMacroTests/Factories/UT_ThrowableErrorFactory.swift @@ -20,6 +20,20 @@ final class UT_ThrowableErrorFactory: XCTestCase { ) } + func testTypedVariableDeclaration() throws { + let variablePrefix = "functionName" + let typeSpecifier = "ExampleError" + + let result = try ThrowableErrorFactory().variableDeclaration(variablePrefix: variablePrefix, typeSpecifier: typeSpecifier) + + assertBuildResult( + result, + """ + var functionNameThrowableError: ExampleError? + """ + ) + } + // MARK: - Throw Error Expression func testThrowErrorExpression() { diff --git a/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift b/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift index bc404a9..e0b4b88 100644 --- a/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift +++ b/Tests/SpyableMacroTests/Macro/UT_SpyableMacro.swift @@ -209,6 +209,52 @@ final class UT_SpyableMacro: XCTestCase { ) } +#if canImport(SwiftSyntax600) + func testMacroWithTypedThrow() { + let protocolDeclaration = """ + public protocol ServiceProtocol { + func fetchConfigTypedThrow() async throws(ConfigError) -> [String: String] + } + """ + + assertMacroExpansion( + """ + @Spyable + \(protocolDeclaration) + """, + expandedSource: """ + + \(protocolDeclaration) + + public class ServiceProtocolSpy: ServiceProtocol, @unchecked Sendable { + public init() { + } + public var fetchConfigTypedThrowCallsCount = 0 + public var fetchConfigTypedThrowCalled: Bool { + return fetchConfigTypedThrowCallsCount > 0 + } + public var fetchConfigTypedThrowThrowableError: ConfigError? + public var fetchConfigTypedThrowReturnValue: [String: String]! + public var fetchConfigTypedThrowClosure: (() async throws(ConfigError) -> [String: String])? + public + func fetchConfigTypedThrow() async throws(ConfigError) -> [String: String] { + fetchConfigTypedThrowCallsCount += 1 + if let fetchConfigTypedThrowThrowableError { + throw fetchConfigTypedThrowThrowableError + } + if #available(iOS 18.0.0, macOS 15.0.0, tvOS 18.0.0, watchOS 11.0.0, macCatalyst 18.0.0, *), fetchConfigTypedThrowClosure != nil { + return try await fetchConfigTypedThrowClosure!() + } else { + return fetchConfigTypedThrowReturnValue + } + } + } + """, + macros: sut + ) + } +#endif + // MARK: - `behindPreprocessorFlag` argument func testMacroWithNoArgument() {