diff --git a/Package.swift b/Package.swift index 3c9f629..ebef5b0 100644 --- a/Package.swift +++ b/Package.swift @@ -4,6 +4,9 @@ import PackageDescription let package = Package( name: "swift-numerics-differentiable", + platforms: [ + .macOS(.v13), + ], products: [ .library( name: "NumericsDifferentiable", @@ -22,6 +25,12 @@ let package = Package( .package(url: "https://github.com/apple/swift-numerics", from: "1.0.2"), ], targets: [ + .executableTarget(name: "CodeGeneratorExecutable"), + .plugin( + name: "CodeGeneratorPlugin", + capability: .buildTool, + dependencies: ["CodeGeneratorExecutable"] + ), .target( name: "NumericsDifferentiable", dependencies: [ @@ -34,6 +43,9 @@ let package = Package( name: "RealModuleDifferentiable", dependencies: [ .product(name: "RealModule", package: "swift-numerics"), + ], + plugins: [ + "CodeGeneratorPlugin", ] ), .target( diff --git a/Plugins/CodeGeneratorPlugin.swift b/Plugins/CodeGeneratorPlugin.swift new file mode 100644 index 0000000..52e007d --- /dev/null +++ b/Plugins/CodeGeneratorPlugin.swift @@ -0,0 +1,36 @@ +import Foundation +import PackagePlugin + +@main +struct CodeGeneratorPlugin: BuildToolPlugin { + func createBuildCommands(context: PackagePlugin.PluginContext, target _: PackagePlugin.Target) async throws -> [PackagePlugin.Command] { + let output = context.pluginWorkDirectoryURL + + let floatingPointTypes: [String] = ["Float", "Double"] + let simdWidths = [2, 4, 8, 16, 32, 64] + + let outputFiles = floatingPointTypes.flatMap { floatingPointType in + simdWidths.flatMap { simdWidth in + [ + output.appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions.swift"), + output.appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions+Derivatives.swift"), + ] + } + [ + output.appending(component: "\(floatingPointType)+RealFunctions+Derivatives.swift"), + ] + } + [ + output.appending(component: "SIMD+RealFunctions.swift"), + ] + + return [ + .buildCommand( + displayName: "Generate Code", + executable: try context.tool(named: "CodeGeneratorExecutable").url, + arguments: [output.relativePath], + environment: [:], + inputFiles: [], + outputFiles: outputFiles + ), + ] + } +} diff --git a/README.md b/README.md index b06dbf9..36d6f69 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ # swift-numerics-differentiable +This package attempts to add more Differentiable capabilities to the existing [swift-numerics](https://github.com/apple/swift-numerics) package. Every target in swift-numerics has a Differentiable counterpart that `@_exported import`s the original module such that when you import `NumericsDifferentiable` you will also get all the contents of the `Numerics` module from swift-numerics. + +## RealModule Differentiable +- Registers derivatives to the `Float` and `Double` conformances to `ElementaryFunctions` and `RealFunctions` from swift-numerics. +- Conforms all `SIMD{n}` types to `ElementaryFunctions` and adds most of the protocol requirements from `RealFunctions` as well (`signGamma` is not implementable) +- Registers derivatives for all the provided `ElementaryFunctions` and `RealFunctions` implementations on SIMD{n} +- Tries to leverage Apple's `simd` framework to accelerate these operations where possible on Apple platforms. + ## Contributing ### Code Formatting This package makes use of [SwiftFormat](https://github.com/nicklockwood/SwiftFormat?tab=readme-ov-file#command-line-tool), which you can install diff --git a/Sources/CodeGeneratorExecutable/CodeGenerator.swift b/Sources/CodeGeneratorExecutable/CodeGenerator.swift new file mode 100644 index 0000000..ca56c60 --- /dev/null +++ b/Sources/CodeGeneratorExecutable/CodeGenerator.swift @@ -0,0 +1,91 @@ +import Foundation + +@main +struct CodeGenerator { + static func main() throws { + guard CommandLine.arguments.count == 2 else { + throw CodeGeneratorError.invalidArguments + } + // arguments[0] is the path to this command line tool + let output = URL(filePath: CommandLine.arguments[1]) + + // generate default implementations of RealFunctions for SIMD protocol + let realFunctionSIMDFileURL = output.appending(component: "SIMD+RealFunctions.swift") + let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension( + objectType: "SIMD", + type: "Self", + whereClause: true, + simdAccelerated: false + ) + try realFunctionsSIMDExtension.write(to: realFunctionSIMDFileURL, atomically: true, encoding: .utf8) + + let floatingPointTypes: [String] = ["Float", "Double"] + let simdWidths: [Int] = [2, 4, 8, 16, 32, 64] + + for floatingPointType in floatingPointTypes { + // Generator Derivatives for RealFunctions for floating point types + let realFunctionDerivativesFileURL = output.appending( + component: "\(floatingPointType)+RealFunctions+Derivatives.swift", + directoryHint: .notDirectory + ) + let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension( + type: floatingPointType, + floatingPointType: floatingPointType + ) + try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8) + + for simdWidth in simdWidths { + let realFunctionFileURL = output.appending( + component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions.swift", + directoryHint: .notDirectory + ) + let simdType = "SIMD\(simdWidth)<\(floatingPointType)>" + + // no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those. + let simdAccelerated = simdWidth < 16 || (simdWidth == 16 && floatingPointType == "Float") + + // Generate RealFunctions implementations on concrete SIMD types to attach derivatives to + let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension( + objectType: simdType, + type: simdType, + whereClause: false, + simdAccelerated: simdAccelerated + ) + try realFunctionsExtensionCode.write(to: realFunctionFileURL, atomically: true, encoding: .utf8) + + // Generate RealFunctions derivatives for concrete SIMD types + let realFunctionDerivativesFileURL = output + .appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions+Derivatives.swift") + let type = "SIMD\(simdWidth)<\(floatingPointType)>" + let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension( + type: type, + floatingPointType: floatingPointType + ) + try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8) + } + } + } +} + +struct RealFunction { + var name: String + var simdName: String? + var arguments: [Argument] + + struct Argument { + var name: String + var label: String? + var type: String? + } + + init(name: String, simdName: String? = nil, arguments: [Argument] = [.init(name: "x", label: "_")]) { + self.name = name + self.simdName = simdName + self.arguments = arguments + } +} + +enum CodeGeneratorError: Error { + case invalidArguments + case invalidData +} diff --git a/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift new file mode 100644 index 0000000..8857a24 --- /dev/null +++ b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift @@ -0,0 +1,244 @@ +enum RealFunctionsDerivativesGenerator { + static func realFunctionsDerivativesExtension(type: String, floatingPointType: String) -> String { + """ + #if canImport(_Differentiation) + import _Differentiation + import RealModule + + // MARK: ElementaryFunctions derivatives + extension \(type) { + @derivative(of: exp) + @_transparent + public static func _vjpExp(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = exp(x) + return (value: value, pullback: { v in v * value }) + } + + @derivative(of: expMinusOne) + @_transparent + public static func _vjpExpMinusOne(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + return (value: expMinusOne(x), pullback: { v in v * exp(x) }) + } + + @derivative(of: cosh) + @_transparent + public static func _vjpCosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: cosh(x), pullback: { v in sinh(x) }) + } + + @derivative(of: sinh) + @_transparent + public static func _vjpSinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: sinh(x), pullback: { v in cosh(x) }) + } + + @derivative(of: tanh) + @_transparent + public static func _vjpTanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + ( + value: tanh(x), + pullback: { v in + let coshx = cosh(x) + return v / (coshx * coshx) + } + ) + } + + @derivative(of: cos) + @_transparent + public static func _vjpCos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: cos(x), pullback: { v in -v * sin(x) }) + } + + @derivative(of: sin) + @_transparent + public static func _vjpSin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: sin(x), pullback: { v in v * cos(x) }) + } + + @derivative(of: tan) + @_transparent + public static func _vjpTan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + ( + value: tan(x), + pullback: { v in + let cosx = cos(x) + return v / (cosx * cosx) + } + ) + } + + @derivative(of: log(_:)) + @_transparent + public static func _vjpLog(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log(x), pullback: { v in v / x }) + } + + @derivative(of: acosh) + @_transparent + public static func _vjpAcosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + // only valid for x > 1 + return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) }) + } + + @derivative(of: asinh) + @_transparent + public static func _vjpAsinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: asinh(x), pullback: { v in v / sqrt(x * x + 1) }) + } + + @derivative(of: atanh) + @_transparent + public static func _vjpAtanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: atanh(x), pullback: { v in v / (1 - x * x) }) + } + + @derivative(of: acos) + @_transparent + public static func _vjpAcos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: acos(x), pullback: { v in -v / .sqrt(1 - x * x) }) + } + + @derivative(of: asin) + @_transparent + public static func _vjpAsin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: asin(x), pullback: { v in v / .sqrt(1 - x * x) }) + } + + @derivative(of: atan) + @_transparent + public static func _vjpAtan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: atan(x), pullback: { v in v / (x * x + 1) }) + } + + @derivative(of: log(onePlus:)) + @_transparent + public static func _vjpLog(onePlus x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log(onePlus: x), pullback: { v in v / (1 + x) }) + } + + @derivative(of: pow) + @_transparent + public static func _vjpPow(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { + let value = pow(x, y) + // pullback wrt y is not defined for (x < 0) and (x = 0, y = 0) + return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) }) + } + + @derivative(of: pow) + @_transparent + public static func _vjpPow(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: pow(x, n), pullback: { v in v * \(floatingPointType)(n) * pow(x, n - 1) }) + } + + @derivative(of: sqrt) + @_transparent + public static func _vjpSqrt(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = sqrt(x) + return (value: value, pullback: { v in v / (2 * value) }) + } + + @derivative(of: root) + @_transparent + public static func _vjpRoot(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = root(x, n) + return (value: value, pullback: { v in v * value / (x * \(floatingPointType)(n)) }) + } + } + + // MARK: RealFunctions derivatives + extension \(type) { + @derivative(of: erf) + @_transparent + public static func _vjpErf(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) + } + + @derivative(of: erfc) + @_transparent + public static func _vjpErfc(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) + } + + @derivative(of: exp2) + @_transparent + public static func _vjpExp2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = exp2(x) + return (value, { v in v * value * .log(2) }) + } + + @derivative(of: exp10) + @_transparent + public static func _vjpExp10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = exp10(x) + return (value, { v in v * value * .log(10) }) + } + + @derivative(of: gamma) + @_transparent + public static func _vjpGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + fatalError("unimplemented") + } + + @derivative(of: log2) + @_transparent + public static func _vjpLog2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log2(x), pullback: { v in v / (.log(2) * x) }) + } + + @derivative(of: log10) + @_transparent + public static func _vjpLog10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log10(x), pullback: { v in v / (.log(10) * x) }) + } + + @derivative(of: logGamma) + @_transparent + public static func _vjpLogGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + fatalError("unimplemented") + } + + @derivative(of: atan2) + @_transparent + public static func _vjpAtan2(y: \(type), x: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { + ( + value: atan2(y: y, x: x), + pullback: { v in + let c = x * x + y * y + return (v * x / c, -v * y / c) + } + ) + } + + @derivative(of: hypot) + @_transparent + public static func _vjpHypot(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { + ( + value: hypot(x, y), + pullback: { v in + let c = sqrt(x * x + y * y) + return (v * x / c, v * y / c) + } + ) + } + } + + // MARK: FloatingPoint functions derivatives + extension \(type) { + @derivative(of: abs) + @_transparent + public static func _vjpAbs(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + \({ + if type == floatingPointType { + "x < 0 ? (value: -x, pullback: { v in .zero - v }) : (value: x, pullback: { v in v })" + } + else { + "(value: abs(x), pullback: { v in v.replacing(with: -v, where: x .< .zero) })" + } + }()) + } + } + #endif + """ + } +} diff --git a/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift b/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift new file mode 100644 index 0000000..dec0c5d --- /dev/null +++ b/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift @@ -0,0 +1,143 @@ +enum RealFunctionsGenerator { + static func realFunctionsExtension(objectType: String, type: String, whereClause: Bool, simdAccelerated: Bool) -> String { + let elementaryFunctions = [ + RealFunction(name: "exp", simdName: "exp"), + RealFunction(name: "expMinusOne", simdName: "expm1"), + RealFunction(name: "cosh", simdName: "cosh"), + RealFunction(name: "sinh", simdName: "sinh"), + RealFunction(name: "tanh", simdName: "tanh"), + RealFunction(name: "cos", simdName: "cos"), + RealFunction(name: "sin", simdName: "sin"), + RealFunction(name: "tan", simdName: "tan"), + RealFunction(name: "log", simdName: "log"), + RealFunction(name: "log", simdName: "log1p", arguments: [.init(name: "x", label: "onePlus")]), + RealFunction(name: "acosh", simdName: "acosh"), + RealFunction(name: "asinh", simdName: "asinh"), + RealFunction(name: "atanh", simdName: "atanh"), + RealFunction(name: "acos", simdName: "acos"), + RealFunction(name: "asin", simdName: "asin"), + RealFunction(name: "atan", simdName: "atan"), + RealFunction( + name: "pow", + simdName: "pow", + arguments: [.init(name: "x", label: "_"), .init(name: "n", label: "_", type: "Int")] + ), + RealFunction(name: "pow", simdName: "pow", arguments: [.init(name: "x", label: "_"), .init(name: "y", label: "_")]), + RealFunction(name: "sqrt"), + RealFunction(name: "root", arguments: [.init(name: "x", label: "_"), .init(name: "n", label: "_", type: "Int")]), + ] + + let realFunctions = [ + RealFunction(name: "atan2", simdName: "atan2", arguments: [.init(name: "y"), .init(name: "x")]), + RealFunction(name: "erf", simdName: "erf"), + RealFunction(name: "erfc", simdName: "erfc"), + RealFunction(name: "exp2", simdName: "exp2"), + RealFunction(name: "exp10", simdName: "exp10"), + RealFunction(name: "hypot", simdName: "hypot", arguments: [.init(name: "x", label: "_"), .init(name: "y", label: "_")]), + RealFunction(name: "gamma", simdName: "tgamma"), + RealFunction(name: "log2", simdName: "log2"), + RealFunction(name: "log10", simdName: "log10"), + RealFunction(name: "logGamma", simdName: "lgamma"), + ] + + let floatingPointFunctions = [ + RealFunction(name: "abs", simdName: "simd_abs"), + ] + + let elementaryFunctionsCode = elementaryFunctions.map { + realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) + }.joined(separator: "\n\n") + + let realFunctionsCode = realFunctions.map { + realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) + }.joined(separator: "\n\n") + + let floatingPointFunctionsCode = floatingPointFunctions.map { + realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) + }.joined(separator: "\n\n") + + let acceleratedHeader = """ + #if canImport(simd) + import simd + #endif + """ + + return """ + \(simdAccelerated ? acceleratedHeader : "") + import RealModule + + // MARK: ElementaryFunctions + extension \(objectType)\(whereClause ? " where Scalar: ElementaryFunctions" : "") { + \(elementaryFunctionsCode) + } + + // MARK: RealFunctions + extension \(objectType)\(whereClause ? " where Scalar: RealFunctions" : "") { + \(realFunctionsCode) + + // signGamma is missing here since we cannot return a SIMDX Otherwise we could also conform SIMD types to the RealFunctions protocol. + // @_transparent + // public static func signGamma(_ x: Self) -> SIMDX { + // fatalError() + // } + } + + // MARK: FloatingPointFunctions + extension \(objectType)\(whereClause ? " where Scalar: Real" : "") { + \(floatingPointFunctionsCode) + } + """ + } + + static func realFunctionTemplate(for function: RealFunction, type: String, simdAccelerated: Bool) -> String { + let interfaceArguments: String = function.arguments.map { + if let label = $0.label { + return "\(label) \($0.name): \($0.type ?? type)" + } + else { + return "\($0.name): \($0.type ?? type)" + } + }.joined(separator: ", ") + + let implementationArguments = function.arguments.map { + if let label = $0.label { + "\(label == "_" ? "" : "\(label): ")\($0.name)\($0.type == nil ? "[i]" : "")" + } + else { + "\($0.name): \($0.name)\($0.type == nil ? "[i]" : "")" + } + }.joined(separator: ", ") + + let regularImplementation = """ + @_transparent + public static func \(function.name)(\(interfaceArguments)) -> \(type) { + var v = Self() + for i in v.indices { + v[i] = .\(function.name)(\(implementationArguments)) + } + return v + } + """ + + guard simdAccelerated else { return regularImplementation } + + // we return the regular implementation if no simd equivalent is present (currently only true for sqrt and root) + guard let simdName = function.simdName else { return regularImplementation } + + let acceleratedArguments = function.arguments + .map { arg in "\(arg.type.map { _ in ".init(repeating: .init(\(arg.name)))" } ?? arg.name)" }.joined(separator: ", ") + + let acceleratedImplementation = """ + #if canImport(simd) + @_transparent + public static func \(function.name)(\(interfaceArguments)) -> \(type) { + simd.\(simdName)(\(acceleratedArguments)) + } + #else + \(regularImplementation) + #endif + """ + + return acceleratedImplementation + } +} diff --git a/Sources/RealModuleDifferentiable/FloatingPoint+ConcreteImplementations.swift b/Sources/RealModuleDifferentiable/FloatingPoint+ConcreteImplementations.swift new file mode 100644 index 0000000..7320d30 --- /dev/null +++ b/Sources/RealModuleDifferentiable/FloatingPoint+ConcreteImplementations.swift @@ -0,0 +1,56 @@ +// These extensions are here due to sqrt(_:) being defined on the `Real` protocol and we currently can't define +// default derivatives for protocol requirements. So we have to create a concrete implementation for each type +// to attach the derivatives to. +extension Float { + @_transparent + public static func sqrt(_ x: Float) -> Float { + x.squareRoot() + } +} + +extension Double { + @_transparent + public static func sqrt(_ x: Double) -> Double { + x.squareRoot() + } +} + +#if !(os(macOS) || os(iOS) || os(tvOS) || os(watchOS)) +// This is a concrete version of the default implementation on the `Real` protocol +// this exists here so we can associate a derivative with this function on platforms that do not have a math library that provides exp10 +extension Float { + @_transparent + public static func exp10(_ x: Float) -> Float { + pow(10, x) + } +} + +extension Double { + @_transparent + public static func exp10(_ x: Double) -> Double { + pow(10, x) + } +} +#endif + +// Extensions so that SIMD can have a fallback `abs` implementation with similar api as the RealFunctions protocol +extension Float { + @_transparent + public static func abs(_ x: Float) -> Float { + Swift.abs(x) + } +} + +extension Double { + @_transparent + public static func abs(_ x: Double) -> Double { + Swift.abs(x) + } +} + +extension Real { + @_transparent + public static func abs(_ x: Self) -> Self { + Swift.abs(x) + } +} diff --git a/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift b/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift new file mode 100644 index 0000000..23520a0 --- /dev/null +++ b/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift @@ -0,0 +1,25 @@ +import RealModule + +#if canImport(_Differentiation) +import _Differentiation +#endif + +#if !canImport(_Differentiation) +// add `AdditiveArithmetic` conformance since this is only present in the _Differentiation module which is not present everywhere +extension SIMD2: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD4: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD8: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD16: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD32: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD64: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +#endif + +// Elementary functions are generated for the SIMD protocol and thus every concrete SIMD type can conform to `ElementaryFunctions` +// Actual implementation is generated by the CodeGeneratorPlugin +// Add actual conformances to `ElementaryFunctions` to individual SIMD types +extension SIMD2: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD4: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD8: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD16: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD32: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD64: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} diff --git a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift index 1939214..e49837e 100644 --- a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift +++ b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift @@ -1,6 +1,234 @@ +import _Differentiation @testable import RealModuleDifferentiable import Testing -@Test func example() async throws { - // Write your test here and use APIs like `#expect(...)` to check expected conditions. +@Suite +struct TestRegisteredDerivatives { + // These tests are more about the derivatives being correctly registered than the correct values. We're only checking the first value + // for simds since we run the computation on the same values and only checking the first result makes using `.isApproximatelyEqual(to:)` + // a lot easier. + @Test + func + testExp() + { + let vwpb = valueWithPullback(at: 2.0, of: Float.exp) + #expect(vwpb.value.isApproximatelyEqual(to: 7.38905609893065)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 7.38905609893065)) + } + + @Test + func testExpMinusOne() { + let vwpb = valueWithPullback(at: 2.0, of: Double.expMinusOne(_:)) + #expect(vwpb.value.isApproximatelyEqual(to: 6.38905609893065)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 7.38905609893065)) + } + + @Test + func testCosh() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 2.0), of: SIMD2.cosh) + #expect(vwpb.value[0].isApproximatelyEqual(to: 3.7621956)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 3.6268604)) + } + + @Test + func testSinh() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 2.0), of: SIMD4.sinh) + #expect(vwpb.value[0].isApproximatelyEqual(to: 3.6268604)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 3.7621956)) + } + + @Test + func testTanh() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 2.0), of: SIMD8.tanh) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.9640276)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.07065083)) + } + + @Test + func testCos() { + let vwpb = valueWithPullback(at: SIMD16(repeating: .pi / 2), of: SIMD16.cos) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -1)) + } + + @Test + func testSin() { + let vwpb = valueWithPullback(at: SIMD32(repeating: .pi / 2), of: SIMD32.sin) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) + } + + @Test + func testTan() { + let vwpb = valueWithPullback(at: SIMD64(repeating: .pi / 4), of: SIMD64.tan) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 2)) + } + + @Test + func testLog() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 2), of: SIMD2.log(_:)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.6931471805599453)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.5)) + } + + @Test + func testLogOnePlus() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 3), of: SIMD4.log(onePlus:)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1.3862943611198906)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.25)) + } + + @Test + func testAcosh() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 2), of: SIMD8.acosh) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1.3169578969248166)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .sqrt(3))) + } + + @Test + func testAsinh() { + let vwpb = valueWithPullback(at: SIMD16(repeating: 2), of: SIMD16.asinh) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1.4436354751788103)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .sqrt(5))) + } + + @Test + func testAtanh() { + let vwpb = valueWithPullback(at: SIMD32(repeating: 0.5), of: SIMD32.atanh) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.5493061443340549)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 4 / 3)) + } + + @Test + func testAcos() { + let vwpb = valueWithPullback(at: SIMD64(repeating: 1 / .sqrt(2)), of: SIMD64.acos) + #expect(vwpb.value[0].isApproximatelyEqual(to: .pi / 4)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -.sqrt(2))) + } + + @Test + func testAsin() { + let vwpb = valueWithPullback(at: 1 / .sqrt(2), of: Float.asin) + #expect(vwpb.value.isApproximatelyEqual(to: .pi / 4)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: .sqrt(2))) + } + + @Test + func testAtan() { + let vwpb = valueWithPullback(at: 0.5, of: Double.atan) + #expect(vwpb.value.isApproximatelyEqual(to: 0.46364760900080615)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 0.8)) + } + + @Test + func testPow() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 0.5), SIMD2(repeating: 2), of: SIMD2.pow(_:_:)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.25)) + let gradient = vwpb.pullback(.one) + #expect(gradient.0[0].isApproximatelyEqual(to: 1.0)) + #expect(gradient.1[0].isApproximatelyEqual(to: -0.1732868)) + } + + @Test + func testPowInt() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 0.5), of: { x in SIMD4.pow(x, 2) }) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.25)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1)) + } + + @Test + func testSqrt() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 4), of: SIMD8.sqrt) + #expect(vwpb.value[0].isApproximatelyEqual(to: 2)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.25)) + } + + @Test + func testRoot() { + let vwpb = valueWithPullback(at: SIMD16(repeating: 16), of: { x in SIMD16.root(x, 4) }) + #expect(vwpb.value[0].isApproximatelyEqual(to: 2)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / 32)) + } + + @Test + func testAtan2() { + let vwpb = valueWithPullback(at: SIMD32(repeating: 1), SIMD32(repeating: 0), of: SIMD32.atan2) + #expect(vwpb.value[0].isApproximatelyEqual(to: .pi / 2)) + let gradient = vwpb.pullback(.one) + #expect(gradient.0[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) + #expect(gradient.1[0].isApproximatelyEqual(to: -1)) + } + + @Test + func testErf() { + let vwpb = valueWithPullback(at: SIMD64(repeating: 0.5), of: SIMD64.erf) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.5204998778)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.8787825789354449)) + } + + @Test + func testErfc() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 0.5), of: SIMD2.erfc) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1 - 0.5204998778)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -0.8787825789354449)) + } + + @Test + func testExp2() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 2), of: SIMD4.exp2) + #expect(vwpb.value[0].isApproximatelyEqual(to: 4)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 4 * .log(2))) + } + + @Test + func testExp10() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 2), of: SIMD8.exp10) + #expect(vwpb.value[0].isApproximatelyEqual(to: 100)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 100 * .log(10))) + } + + @Test + func testHypot() { + let vwpb = valueWithPullback(at: SIMD16(repeating: 3), SIMD16(repeating: 4), of: SIMD16.hypot) + #expect(vwpb.value[0].isApproximatelyEqual(to: 5)) + let gradient = vwpb.pullback(.one) + #expect(gradient.0[0].isApproximatelyEqual(to: 3 / 5)) + #expect(gradient.1[0].isApproximatelyEqual(to: 4 / 5)) + } + + @Test(.disabled("derivative not implemented")) + func testGamma() { + let vwpb = valueWithPullback(at: SIMD32(repeating: 2), of: SIMD32.gamma) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) + } + + @Test + func testLog2() { + let vwpb = valueWithPullback(at: SIMD64(repeating: 2), of: SIMD64.log2) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .log(4))) + } + + @Test + func testLog10() { + let vwpb = valueWithPullback(at: 2.0, of: Float.log10) + #expect(vwpb.value.isApproximatelyEqual(to: .log(2) / .log(10))) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 1 / .log(100))) + } + + @Test(.disabled("derivative not implemented")) + func testLogGamma() { + let vwpb = valueWithPullback(at: 2, of: Double.logGamma) + #expect(vwpb.value.isApproximatelyEqual(to: 0)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) + } + + @Test + func testAbs() { + let vwpb = valueWithPullback(at: SIMD2(repeating: -2), of: SIMD2.abs) + #expect(vwpb.value[0].isApproximatelyEqual(to: 2)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -1)) + } }