From 4662bca88b44352ef06edc7fe8984b95f97710f9 Mon Sep 17 00:00:00 2001 From: Jaap Wijnen Date: Thu, 1 May 2025 11:42:31 -0600 Subject: [PATCH 1/6] Add derivatives for RealFunctions Also adds RealFunctions conformance to SIMD types and their derivatives --- Package.swift | 12 + Plugins/CodeGeneratorPlugin.swift | 36 +++ .../CodeGenerator.swift | 81 +++++++ .../RealFunctionsDerivativesGenerator.swift | 212 ++++++++++++++++++ .../RealFunctionsGenerator.swift | 136 +++++++++++ ...loatingPoint+ConcreteImplementations.swift | 56 +++++ .../SIMD+ElementaryFunctions.swift | 25 +++ 7 files changed, 558 insertions(+) create mode 100644 Plugins/CodeGeneratorPlugin.swift create mode 100644 Sources/CodeGeneratorExecutable/CodeGenerator.swift create mode 100644 Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift create mode 100644 Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift create mode 100644 Sources/RealModuleDifferentiable/FloatingPoint+ConcreteImplementations.swift create mode 100644 Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift diff --git a/Package.swift b/Package.swift index 3c9f629..ebef5b0 100644 --- a/Package.swift +++ b/Package.swift @@ -4,6 +4,9 @@ import PackageDescription let package = Package( name: "swift-numerics-differentiable", + platforms: [ + .macOS(.v13), + ], products: [ .library( name: "NumericsDifferentiable", @@ -22,6 +25,12 @@ let package = Package( .package(url: "https://github.com/apple/swift-numerics", from: "1.0.2"), ], targets: [ + .executableTarget(name: "CodeGeneratorExecutable"), + .plugin( + name: "CodeGeneratorPlugin", + capability: .buildTool, + dependencies: ["CodeGeneratorExecutable"] + ), .target( name: "NumericsDifferentiable", dependencies: [ @@ -34,6 +43,9 @@ let package = Package( name: "RealModuleDifferentiable", dependencies: [ .product(name: "RealModule", package: "swift-numerics"), + ], + plugins: [ + "CodeGeneratorPlugin", ] ), .target( diff --git a/Plugins/CodeGeneratorPlugin.swift b/Plugins/CodeGeneratorPlugin.swift new file mode 100644 index 0000000..63fbab4 --- /dev/null +++ b/Plugins/CodeGeneratorPlugin.swift @@ -0,0 +1,36 @@ +import Foundation +import PackagePlugin + +@main +struct CodeGeneratorPlugin: BuildToolPlugin { + func createBuildCommands(context: PackagePlugin.PluginContext, target _: PackagePlugin.Target) async throws -> [PackagePlugin.Command] { + let output = context.pluginWorkDirectoryURL + + let floatingPointTypes: [String] = ["Float", "Double"] + let simdSizes = [2, 4, 8, 16, 32, 64] + + let outputFiles = floatingPointTypes.flatMap { floatingPointType in + simdSizes.flatMap { simdSize in + [ + output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift"), + output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift"), + ] + } + [ + output.appending(component: "\(floatingPointType)+RealFunctions+Derivatives.swift"), + ] + } + [ + output.appending(component: "SIMD+RealFunctions.swift"), + ] + + return [ + .buildCommand( + displayName: "Generate Code", + executable: try context.tool(named: "CodeGeneratorExecutable").url, + arguments: [output.relativePath], + environment: [:], + inputFiles: [], + outputFiles: outputFiles + ) + ] + } +} diff --git a/Sources/CodeGeneratorExecutable/CodeGenerator.swift b/Sources/CodeGeneratorExecutable/CodeGenerator.swift new file mode 100644 index 0000000..ea0b46f --- /dev/null +++ b/Sources/CodeGeneratorExecutable/CodeGenerator.swift @@ -0,0 +1,81 @@ +import Foundation + +@main +struct CodeGenerator { + static func main() throws { + // Use swift-argument-parser or just CommandLine, here we just imply that 2 paths are passed in: input and output + guard CommandLine.arguments.count == 2 else { + throw CodeGeneratorError.invalidArguments + } + // arguments[0] is the path to this command line tool + let output = URL(filePath: CommandLine.arguments[1]) + + // generate default implementations of RealFunctions for SIMD protocol + let realFunctionSIMDFileURL = output.appending(component: "SIMD+RealFunctions.swift") + let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension(objectType: "SIMD", type: "Self", whereClause: true, simdAccelerated: false) + try realFunctionsSIMDExtension.write(to: realFunctionSIMDFileURL, atomically: true, encoding: .utf8) + + let floatingPointTypes: [String] = ["Float", "Double"] + let simdSizes: [Int] = [2, 4, 8, 16, 32, 64] + + for floatingPointType in floatingPointTypes { + // Generator Derivatives for RealFunctions for floating point types + let realFunctionDerivativesFileURL = output.appending( + component: "\(floatingPointType)+RealFunctions+Derivatives.swift", + directoryHint: .notDirectory + ) + let type = floatingPointType + let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType) + try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8) + + for simdSize in simdSizes { + let realFunctionFileURL = output.appending( + component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift", + directoryHint: .notDirectory + ) + let simdType = "SIMD\(simdSize)<\(floatingPointType)>" + + // no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those. + var simdAccelerated: Bool + if simdSize > 16 || (simdSize == 16 && floatingPointType == "Double") { + simdAccelerated = false + } else { + simdAccelerated = true + } + + // Generate RealFunctions implementations on concrete SIMD types to attach derivatives to + let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension(objectType: simdType, type: simdType, whereClause: false, simdAccelerated: simdAccelerated) + try realFunctionsExtensionCode.write(to: realFunctionFileURL, atomically: true, encoding: .utf8) + + // Generate RealFunctions derivatives for concrete SIMD types + let realFunctionDerivativesFileURL = output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift") + let type = "SIMD\(simdSize)<\(floatingPointType)>" + let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType) + try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8) + } + } + } +} + +struct RealFunction { + var name: String + var simdName: String? + var arguments: [Argument] + + struct Argument { + var name: String + var label: String? + var type: String? = nil + } + + init(name: String, simdName: String? = nil, arguments: [Argument] = [.init(name: "x", label: "_")]) { + self.name = name + self.simdName = simdName + self.arguments = arguments + } +} + +enum CodeGeneratorError: Error { + case invalidArguments + case invalidData +} diff --git a/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift new file mode 100644 index 0000000..448b0c1 --- /dev/null +++ b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift @@ -0,0 +1,212 @@ +struct RealFunctionsDerivativesGenerator { + static func realFunctionsDerivativesExtension(type: String, floatingPointType: String) -> String { + """ + #if canImport(_Differentiation) + import _Differentiation + import RealModule + + // MARK: ElementaryFunctions derivatives + extension \(type) { + @derivative(of: exp) + public static func _vjpExp(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = exp(x) + return (value: value, pullback: { v in v * value }) + } + + @derivative(of: expMinusOne) + public static func _vjpExpMinusOne(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + return (value: expMinusOne(x), pullback: { v in v * exp(x) }) + } + + @derivative(of: cosh) + public static func _vjpCosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: cosh(x), pullback: { v in sinh(x) }) + } + + @derivative(of: sinh) + public static func _vjpSinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: sinh(x), pullback: { v in cosh(x) }) + } + + @derivative(of: tanh) + public static func _vjpTanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + ( + value: tanh(x), + pullback: { v in + let coshx = cosh(x) + return v / (coshx * coshx) + } + ) + } + + @derivative(of: cos) + public static func _vjpCos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: cos(x), pullback: { v in -v * sin(x) }) + } + + @derivative(of: sin) + public static func _vjpSin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: sin(x), pullback: { v in v * cos(x) }) + } + + @derivative(of: tan) + public static func _vjpTan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + ( + value: tan(x), + pullback: { v in + let cosx = cos(x) + return v / (cosx * cosx) + } + ) + } + + @derivative(of: log(_:)) + public static func _vjpLog(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log(x), pullback: { v in v / x }) + } + + @derivative(of: acosh) + public static func _vjpAcosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + // only valid for x > 1 + return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) }) + } + + @derivative(of: asinh) + public static func _vjpAsinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: asinh(x), pullback: { v in v / sqrt(x * x + 1) }) + } + + @derivative(of: atanh) + public static func _vjpAtanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: atanh(x), pullback: { v in v / (1 - x * x) }) + } + + @derivative(of: acos) + public static func _vjpAcos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: acos(x), pullback: { v in -v / (1 - x * x) }) + } + + @derivative(of: asin) + public static func _vjpAsin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: asin(x), pullback: { v in v / (1 - x * x) }) + } + + @derivative(of: atan) + public static func _vjpAtan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: atan(x), pullback: { v in v / (x * x + 1) }) + } + + @derivative(of: log(onePlus:)) + public static func _vjpLog(onePlus x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log(onePlus: x), pullback: { v in v / (1 + x) }) + } + + @derivative(of: pow) + public static func _vjpPow(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { + let value = pow(x, y) + // pullback wrt y is not defined for (x < 0) and (x = 0, y = 0) + return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) }) + } + + @derivative(of: pow) + public static func _vjpPow(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: pow(x, n), pullback: { v in v * \(floatingPointType)(n) * pow(x, n - 1) }) + } + + @derivative(of: sqrt) + public static func _vjpSqrt(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = sqrt(x) + return (value: value, pullback: { v in v / (2 * value) }) + } + + @derivative(of: root) + public static func _vjpRoot(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = root(x, n) + return (value: value, pullback: { v in v * value / (x * \(floatingPointType)(n)) }) + } + } + + // MARK: RealFunctions derivatives + extension \(type) { + @derivative(of: erf) + public static func _vjpErf(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) + } + + @derivative(of: erfc) + public static func _vjpErfc(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) + } + + @derivative(of: exp2) + public static func _vjpExp2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = exp2(x) + return (value, { v in v * value * .log(2) }) + } + + @derivative(of: exp10) + public static func _vjpExp10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + let value = exp10(x) + return (value, { v in v * value * .log(10) }) + } + + @derivative(of: gamma) + public static func _vjpGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + fatalError("unimplemented") + } + + @derivative(of: log2) + public static func _vjpLog2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log2(x), pullback: { v in v / (.log(2) * x) }) + } + + @derivative(of: log10) + public static func _vjpLog10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + (value: log10(x), pullback: { v in v / (.log(10) * x) }) + } + + @derivative(of: logGamma) + public static func _vjpLogGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + fatalError("unimplemented") + } + + @derivative(of: atan2) + public static func _vjpAtan2(y: \(type), x: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { + ( + value: atan2(y: y, x: x), + pullback: { v in + let c = x * x + y * y + return (v * x / c, -v * y / c) + } + ) + } + + @derivative(of: hypot) + public static func _vjpHypot(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { + ( + value: hypot(x, y), + pullback: { v in + let c = sqrt(x * x + y * y) + return (v * x / c, v * y / c) + } + ) + } + } + + // MARK: FloatingPoint functions derivatives + extension \(type) { + @derivative(of: abs) + public static func _vjpAbs(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { + \({ + if type == floatingPointType { + "x < 0 ? (value: -x, pullback: { v in .zero - v }) : (value: x, pullback: { v in v })" + } else { + "(value: abs(x), pullback: { v in v.replacing(with: -v, where: x .< .zero) })" + } + }()) + } + } + #endif + """ + } +} diff --git a/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift b/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift new file mode 100644 index 0000000..4357287 --- /dev/null +++ b/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift @@ -0,0 +1,136 @@ +struct RealFunctionsGenerator { + static func realFunctionsExtension(objectType: String, type: String, whereClause: Bool, simdAccelerated: Bool) -> String { + let elementaryFunctions = [ + RealFunction(name: "exp", simdName: "exp"), + RealFunction(name: "expMinusOne", simdName: "expm1"), + RealFunction(name: "cosh", simdName: "cosh"), + RealFunction(name: "sinh", simdName: "sinh"), + RealFunction(name: "tanh", simdName: "tanh"), + RealFunction(name: "cos", simdName: "cos"), + RealFunction(name: "sin", simdName: "sin"), + RealFunction(name: "tan", simdName: "tan"), + RealFunction(name: "log", simdName: "log"), + RealFunction(name: "log", simdName: "log1p", arguments: [.init(name: "x", label: "onePlus")]), + RealFunction(name: "acosh", simdName: "acosh"), + RealFunction(name: "asinh", simdName: "asinh"), + RealFunction(name: "atanh", simdName: "atanh"), + RealFunction(name: "acos", simdName: "acos"), + RealFunction(name: "asin", simdName: "asin"), + RealFunction(name: "atan", simdName: "atan"), + RealFunction(name: "pow", simdName: "pow", arguments: [.init(name: "x", label: "_"), .init(name: "n", label: "_", type: "Int")]), + RealFunction(name: "pow", simdName: "pow", arguments: [.init(name: "x", label: "_"), .init(name: "y", label: "_")]), + RealFunction(name: "sqrt"), + RealFunction(name: "root", arguments: [.init(name: "x", label: "_"), .init(name: "n", label: "_", type: "Int")]), + ] + + let realFunctions = [ + RealFunction(name: "atan2", simdName: "atan2", arguments: [.init(name: "y"), .init(name: "x")]), + RealFunction(name: "erf", simdName: "erf"), + RealFunction(name: "erfc", simdName: "erfc"), + RealFunction(name: "exp2", simdName: "exp2"), + RealFunction(name: "exp10", simdName: "exp10"), + RealFunction(name: "hypot", simdName: "hypot", arguments: [.init(name: "x", label: "_"), .init(name: "y", label: "_")]), + RealFunction(name: "gamma", simdName: "tgamma"), + RealFunction(name: "log2", simdName: "log2"), + RealFunction(name: "log10", simdName: "log10"), + RealFunction(name: "logGamma", simdName: "lgamma"), + ] + + let floatingPointFunctions = [ + RealFunction(name: "abs", simdName: "simd_abs"), + ] + + let elementaryFunctionsCode = elementaryFunctions.map { + realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) + }.joined(separator: "\n\n") + + let realFunctionsCode = realFunctions.map { + realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) + }.joined(separator: "\n\n") + + let floatingPointFunctionsCode = floatingPointFunctions.map { + realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) + }.joined(separator: "\n\n") + + let acceleratedHeader = """ + #if canImport(simd) + import simd + #endif + """ + + return """ + \(simdAccelerated ? acceleratedHeader : "") + import RealModule + + // MARK: ElementaryFunctions + extension \(objectType)\(whereClause ? " where Scalar: ElementaryFunctions" : "") { + \(elementaryFunctionsCode) + } + + // MARK: RealFunctions + extension \(objectType)\(whereClause ? " where Scalar: RealFunctions" : "") { + \(realFunctionsCode) + + // signGamma is missing here since we cannot return a SIMDX Otherwise we could also conform SIMD types to the RealFunctions protocol. + // @_transparent + // public static func signGamma(_ x: Self) -> SIMDX { + // fatalError() + // } + } + + // MARK: FloatingPointFunctions + extension \(objectType)\(whereClause ? " where Scalar: Real" : "") { + \(floatingPointFunctionsCode) + } + """ + } + + static func realFunctionTemplate(for function: RealFunction, type: String, simdAccelerated: Bool) -> String { + let interfaceArguments: String = function.arguments.map { + if let label = $0.label { + return "\(label) \($0.name): \($0.type ?? type)" + } else { + return "\($0.name): \($0.type ?? type)" + } + }.joined(separator: ", ") + + let implementationArguments = function.arguments.map { + if let label = $0.label { + "\(label == "_" ? "" : "\(label): ")\($0.name)\($0.type == nil ? "[i]" : "")" + } else { + "\($0.name): \($0.name)\($0.type == nil ? "[i]" : "")" + } + }.joined(separator: ", ") + + let regularImplementation = """ + @_transparent + public static func \(function.name)(\(interfaceArguments)) -> \(type) { + var v = Self() + for i in v.indices { + v[i] = .\(function.name)(\(implementationArguments)) + } + return v + } + """ + + guard simdAccelerated else { return regularImplementation } + + // we return the regular implementation if no simd equivalent is present (currently only true for sqrt and root) + guard let simdName = function.simdName else { return regularImplementation } + + let acceleratedArguments = function.arguments.map { arg in "\(arg.type.map { _ in ".init(repeating: .init(\(arg.name)))" } ?? arg.name)" }.joined(separator: ", ") + + let acceleratedImplementation = """ + #if canImport(simd) + @_transparent + public static func \(function.name)(\(interfaceArguments)) -> \(type) { + simd.\(simdName)(\(acceleratedArguments)) + } + #else + \(regularImplementation) + #endif + """ + + return acceleratedImplementation + } +} diff --git a/Sources/RealModuleDifferentiable/FloatingPoint+ConcreteImplementations.swift b/Sources/RealModuleDifferentiable/FloatingPoint+ConcreteImplementations.swift new file mode 100644 index 0000000..7320d30 --- /dev/null +++ b/Sources/RealModuleDifferentiable/FloatingPoint+ConcreteImplementations.swift @@ -0,0 +1,56 @@ +// These extensions are here due to sqrt(_:) being defined on the `Real` protocol and we currently can't define +// default derivatives for protocol requirements. So we have to create a concrete implementation for each type +// to attach the derivatives to. +extension Float { + @_transparent + public static func sqrt(_ x: Float) -> Float { + x.squareRoot() + } +} + +extension Double { + @_transparent + public static func sqrt(_ x: Double) -> Double { + x.squareRoot() + } +} + +#if !(os(macOS) || os(iOS) || os(tvOS) || os(watchOS)) +// This is a concrete version of the default implementation on the `Real` protocol +// this exists here so we can associate a derivative with this function on platforms that do not have a math library that provides exp10 +extension Float { + @_transparent + public static func exp10(_ x: Float) -> Float { + pow(10, x) + } +} + +extension Double { + @_transparent + public static func exp10(_ x: Double) -> Double { + pow(10, x) + } +} +#endif + +// Extensions so that SIMD can have a fallback `abs` implementation with similar api as the RealFunctions protocol +extension Float { + @_transparent + public static func abs(_ x: Float) -> Float { + Swift.abs(x) + } +} + +extension Double { + @_transparent + public static func abs(_ x: Double) -> Double { + Swift.abs(x) + } +} + +extension Real { + @_transparent + public static func abs(_ x: Self) -> Self { + Swift.abs(x) + } +} diff --git a/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift b/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift new file mode 100644 index 0000000..6e72e70 --- /dev/null +++ b/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift @@ -0,0 +1,25 @@ +import RealModule + +#if canImport(_Differentiation) +import _Differentiation +#endif + +#if !canImport(_Differentiation) +// add `AdditiveArithmetic` conformance since this is only present in the _Differentiation module which is not present everywhere +extension SIMD2: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } +extension SIMD4: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } +extension SIMD8: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } +extension SIMD16: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } +extension SIMD32: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } +extension SIMD64: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } +#endif + +// Elementary functions are generated for the SIMD protocol and thus every concrete SIMD type can conform to `ElementaryFunctions` +// Actual implementation is generated by the CodeGeneratorPlugin +// Add actual conformances to `ElementaryFunctions` to individual SIMD types +extension SIMD2: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } +extension SIMD4: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } +extension SIMD8: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } +extension SIMD16: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } +extension SIMD32: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } +extension SIMD64: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } From 448b9cea5786bbcf477666d18b404b0965e44609 Mon Sep 17 00:00:00 2001 From: Jaap Wijnen Date: Fri, 2 May 2025 13:17:33 -0600 Subject: [PATCH 2/6] formatting --- Plugins/CodeGeneratorPlugin.swift | 6 +- .../CodeGenerator.swift | 46 +++++++++---- .../RealFunctionsDerivativesGenerator.swift | 67 ++++++++++--------- .../RealFunctionsGenerator.swift | 51 ++++++++------ .../SIMD+ElementaryFunctions.swift | 24 +++---- 5 files changed, 110 insertions(+), 84 deletions(-) diff --git a/Plugins/CodeGeneratorPlugin.swift b/Plugins/CodeGeneratorPlugin.swift index 63fbab4..886d887 100644 --- a/Plugins/CodeGeneratorPlugin.swift +++ b/Plugins/CodeGeneratorPlugin.swift @@ -8,7 +8,7 @@ struct CodeGeneratorPlugin: BuildToolPlugin { let floatingPointTypes: [String] = ["Float", "Double"] let simdSizes = [2, 4, 8, 16, 32, 64] - + let outputFiles = floatingPointTypes.flatMap { floatingPointType in simdSizes.flatMap { simdSize in [ @@ -21,7 +21,7 @@ struct CodeGeneratorPlugin: BuildToolPlugin { } + [ output.appending(component: "SIMD+RealFunctions.swift"), ] - + return [ .buildCommand( displayName: "Generate Code", @@ -30,7 +30,7 @@ struct CodeGeneratorPlugin: BuildToolPlugin { environment: [:], inputFiles: [], outputFiles: outputFiles - ) + ), ] } } diff --git a/Sources/CodeGeneratorExecutable/CodeGenerator.swift b/Sources/CodeGeneratorExecutable/CodeGenerator.swift index ea0b46f..49711c4 100644 --- a/Sources/CodeGeneratorExecutable/CodeGenerator.swift +++ b/Sources/CodeGeneratorExecutable/CodeGenerator.swift @@ -12,9 +12,14 @@ struct CodeGenerator { // generate default implementations of RealFunctions for SIMD protocol let realFunctionSIMDFileURL = output.appending(component: "SIMD+RealFunctions.swift") - let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension(objectType: "SIMD", type: "Self", whereClause: true, simdAccelerated: false) + let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension( + objectType: "SIMD", + type: "Self", + whereClause: true, + simdAccelerated: false + ) try realFunctionsSIMDExtension.write(to: realFunctionSIMDFileURL, atomically: true, encoding: .utf8) - + let floatingPointTypes: [String] = ["Float", "Double"] let simdSizes: [Int] = [2, 4, 8, 16, 32, 64] @@ -25,32 +30,45 @@ struct CodeGenerator { directoryHint: .notDirectory ) let type = floatingPointType - let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType) + let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension( + type: type, + floatingPointType: floatingPointType + ) try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8) - + for simdSize in simdSizes { let realFunctionFileURL = output.appending( component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift", directoryHint: .notDirectory ) let simdType = "SIMD\(simdSize)<\(floatingPointType)>" - + // no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those. var simdAccelerated: Bool if simdSize > 16 || (simdSize == 16 && floatingPointType == "Double") { simdAccelerated = false - } else { + } + else { simdAccelerated = true } - + // Generate RealFunctions implementations on concrete SIMD types to attach derivatives to - let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension(objectType: simdType, type: simdType, whereClause: false, simdAccelerated: simdAccelerated) + let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension( + objectType: simdType, + type: simdType, + whereClause: false, + simdAccelerated: simdAccelerated + ) try realFunctionsExtensionCode.write(to: realFunctionFileURL, atomically: true, encoding: .utf8) - + // Generate RealFunctions derivatives for concrete SIMD types - let realFunctionDerivativesFileURL = output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift") + let realFunctionDerivativesFileURL = output + .appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift") let type = "SIMD\(simdSize)<\(floatingPointType)>" - let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType) + let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension( + type: type, + floatingPointType: floatingPointType + ) try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8) } } @@ -61,13 +79,13 @@ struct RealFunction { var name: String var simdName: String? var arguments: [Argument] - + struct Argument { var name: String var label: String? - var type: String? = nil + var type: String? } - + init(name: String, simdName: String? = nil, arguments: [Argument] = [.init(name: "x", label: "_")]) { self.name = name self.simdName = simdName diff --git a/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift index 448b0c1..ad29810 100644 --- a/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift +++ b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift @@ -1,10 +1,10 @@ -struct RealFunctionsDerivativesGenerator { +enum RealFunctionsDerivativesGenerator { static func realFunctionsDerivativesExtension(type: String, floatingPointType: String) -> String { """ #if canImport(_Differentiation) import _Differentiation import RealModule - + // MARK: ElementaryFunctions derivatives extension \(type) { @derivative(of: exp) @@ -12,22 +12,22 @@ struct RealFunctionsDerivativesGenerator { let value = exp(x) return (value: value, pullback: { v in v * value }) } - + @derivative(of: expMinusOne) public static func _vjpExpMinusOne(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { return (value: expMinusOne(x), pullback: { v in v * exp(x) }) } - + @derivative(of: cosh) public static func _vjpCosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: cosh(x), pullback: { v in sinh(x) }) } - + @derivative(of: sinh) public static func _vjpSinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: sinh(x), pullback: { v in cosh(x) }) } - + @derivative(of: tanh) public static func _vjpTanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { ( @@ -38,17 +38,17 @@ struct RealFunctionsDerivativesGenerator { } ) } - + @derivative(of: cos) public static func _vjpCos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: cos(x), pullback: { v in -v * sin(x) }) } - + @derivative(of: sin) public static func _vjpSin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: sin(x), pullback: { v in v * cos(x) }) } - + @derivative(of: tan) public static func _vjpTan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { ( @@ -59,117 +59,117 @@ struct RealFunctionsDerivativesGenerator { } ) } - + @derivative(of: log(_:)) public static func _vjpLog(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log(x), pullback: { v in v / x }) } - + @derivative(of: acosh) public static func _vjpAcosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { // only valid for x > 1 return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) }) } - + @derivative(of: asinh) public static func _vjpAsinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: asinh(x), pullback: { v in v / sqrt(x * x + 1) }) } - + @derivative(of: atanh) public static func _vjpAtanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: atanh(x), pullback: { v in v / (1 - x * x) }) } - + @derivative(of: acos) public static func _vjpAcos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: acos(x), pullback: { v in -v / (1 - x * x) }) } - + @derivative(of: asin) public static func _vjpAsin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: asin(x), pullback: { v in v / (1 - x * x) }) } - + @derivative(of: atan) public static func _vjpAtan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: atan(x), pullback: { v in v / (x * x + 1) }) } - + @derivative(of: log(onePlus:)) public static func _vjpLog(onePlus x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log(onePlus: x), pullback: { v in v / (1 + x) }) } - + @derivative(of: pow) public static func _vjpPow(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { let value = pow(x, y) // pullback wrt y is not defined for (x < 0) and (x = 0, y = 0) return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) }) } - + @derivative(of: pow) public static func _vjpPow(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: pow(x, n), pullback: { v in v * \(floatingPointType)(n) * pow(x, n - 1) }) } - + @derivative(of: sqrt) public static func _vjpSqrt(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = sqrt(x) return (value: value, pullback: { v in v / (2 * value) }) } - + @derivative(of: root) public static func _vjpRoot(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = root(x, n) return (value: value, pullback: { v in v * value / (x * \(floatingPointType)(n)) }) } } - + // MARK: RealFunctions derivatives extension \(type) { @derivative(of: erf) public static func _vjpErf(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) } - + @derivative(of: erfc) public static func _vjpErfc(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) } - + @derivative(of: exp2) public static func _vjpExp2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = exp2(x) return (value, { v in v * value * .log(2) }) } - + @derivative(of: exp10) public static func _vjpExp10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = exp10(x) return (value, { v in v * value * .log(10) }) } - + @derivative(of: gamma) public static func _vjpGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { fatalError("unimplemented") } - + @derivative(of: log2) public static func _vjpLog2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log2(x), pullback: { v in v / (.log(2) * x) }) } - + @derivative(of: log10) public static func _vjpLog10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log10(x), pullback: { v in v / (.log(10) * x) }) } - + @derivative(of: logGamma) public static func _vjpLogGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { fatalError("unimplemented") } - + @derivative(of: atan2) public static func _vjpAtan2(y: \(type), x: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { ( @@ -180,7 +180,7 @@ struct RealFunctionsDerivativesGenerator { } ) } - + @derivative(of: hypot) public static func _vjpHypot(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { ( @@ -192,7 +192,7 @@ struct RealFunctionsDerivativesGenerator { ) } } - + // MARK: FloatingPoint functions derivatives extension \(type) { @derivative(of: abs) @@ -200,7 +200,8 @@ struct RealFunctionsDerivativesGenerator { \({ if type == floatingPointType { "x < 0 ? (value: -x, pullback: { v in .zero - v }) : (value: x, pullback: { v in v })" - } else { + } + else { "(value: abs(x), pullback: { v in v.replacing(with: -v, where: x .< .zero) })" } }()) diff --git a/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift b/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift index 4357287..dec0c5d 100644 --- a/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift +++ b/Sources/CodeGeneratorExecutable/RealFunctionsGenerator.swift @@ -1,4 +1,4 @@ -struct RealFunctionsGenerator { +enum RealFunctionsGenerator { static func realFunctionsExtension(objectType: String, type: String, whereClause: Bool, simdAccelerated: Bool) -> String { let elementaryFunctions = [ RealFunction(name: "exp", simdName: "exp"), @@ -17,12 +17,16 @@ struct RealFunctionsGenerator { RealFunction(name: "acos", simdName: "acos"), RealFunction(name: "asin", simdName: "asin"), RealFunction(name: "atan", simdName: "atan"), - RealFunction(name: "pow", simdName: "pow", arguments: [.init(name: "x", label: "_"), .init(name: "n", label: "_", type: "Int")]), + RealFunction( + name: "pow", + simdName: "pow", + arguments: [.init(name: "x", label: "_"), .init(name: "n", label: "_", type: "Int")] + ), RealFunction(name: "pow", simdName: "pow", arguments: [.init(name: "x", label: "_"), .init(name: "y", label: "_")]), RealFunction(name: "sqrt"), RealFunction(name: "root", arguments: [.init(name: "x", label: "_"), .init(name: "n", label: "_", type: "Int")]), ] - + let realFunctions = [ RealFunction(name: "atan2", simdName: "atan2", arguments: [.init(name: "y"), .init(name: "x")]), RealFunction(name: "erf", simdName: "erf"), @@ -35,7 +39,7 @@ struct RealFunctionsGenerator { RealFunction(name: "log10", simdName: "log10"), RealFunction(name: "logGamma", simdName: "lgamma"), ] - + let floatingPointFunctions = [ RealFunction(name: "abs", simdName: "simd_abs"), ] @@ -43,65 +47,67 @@ struct RealFunctionsGenerator { let elementaryFunctionsCode = elementaryFunctions.map { realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) }.joined(separator: "\n\n") - + let realFunctionsCode = realFunctions.map { realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) }.joined(separator: "\n\n") - + let floatingPointFunctionsCode = floatingPointFunctions.map { realFunctionTemplate(for: $0, type: type, simdAccelerated: simdAccelerated) }.joined(separator: "\n\n") - + let acceleratedHeader = """ #if canImport(simd) import simd #endif """ - + return """ \(simdAccelerated ? acceleratedHeader : "") import RealModule - + // MARK: ElementaryFunctions extension \(objectType)\(whereClause ? " where Scalar: ElementaryFunctions" : "") { \(elementaryFunctionsCode) } - + // MARK: RealFunctions extension \(objectType)\(whereClause ? " where Scalar: RealFunctions" : "") { \(realFunctionsCode) - + // signGamma is missing here since we cannot return a SIMDX Otherwise we could also conform SIMD types to the RealFunctions protocol. // @_transparent // public static func signGamma(_ x: Self) -> SIMDX { // fatalError() // } } - + // MARK: FloatingPointFunctions extension \(objectType)\(whereClause ? " where Scalar: Real" : "") { \(floatingPointFunctionsCode) } """ } - + static func realFunctionTemplate(for function: RealFunction, type: String, simdAccelerated: Bool) -> String { let interfaceArguments: String = function.arguments.map { if let label = $0.label { return "\(label) \($0.name): \($0.type ?? type)" - } else { + } + else { return "\($0.name): \($0.type ?? type)" } }.joined(separator: ", ") - + let implementationArguments = function.arguments.map { if let label = $0.label { "\(label == "_" ? "" : "\(label): ")\($0.name)\($0.type == nil ? "[i]" : "")" - } else { + } + else { "\($0.name): \($0.name)\($0.type == nil ? "[i]" : "")" } }.joined(separator: ", ") - + let regularImplementation = """ @_transparent public static func \(function.name)(\(interfaceArguments)) -> \(type) { @@ -112,14 +118,15 @@ struct RealFunctionsGenerator { return v } """ - + guard simdAccelerated else { return regularImplementation } // we return the regular implementation if no simd equivalent is present (currently only true for sqrt and root) guard let simdName = function.simdName else { return regularImplementation } - - let acceleratedArguments = function.arguments.map { arg in "\(arg.type.map { _ in ".init(repeating: .init(\(arg.name)))" } ?? arg.name)" }.joined(separator: ", ") - + + let acceleratedArguments = function.arguments + .map { arg in "\(arg.type.map { _ in ".init(repeating: .init(\(arg.name)))" } ?? arg.name)" }.joined(separator: ", ") + let acceleratedImplementation = """ #if canImport(simd) @_transparent @@ -130,7 +137,7 @@ struct RealFunctionsGenerator { \(regularImplementation) #endif """ - + return acceleratedImplementation } } diff --git a/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift b/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift index 6e72e70..23520a0 100644 --- a/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift +++ b/Sources/RealModuleDifferentiable/SIMD+ElementaryFunctions.swift @@ -6,20 +6,20 @@ import _Differentiation #if !canImport(_Differentiation) // add `AdditiveArithmetic` conformance since this is only present in the _Differentiation module which is not present everywhere -extension SIMD2: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } -extension SIMD4: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } -extension SIMD8: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } -extension SIMD16: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } -extension SIMD32: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } -extension SIMD64: @retroactive AdditiveArithmetic where Scalar: FloatingPoint { } +extension SIMD2: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD4: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD8: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD16: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD32: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} +extension SIMD64: @retroactive AdditiveArithmetic where Scalar: FloatingPoint {} #endif // Elementary functions are generated for the SIMD protocol and thus every concrete SIMD type can conform to `ElementaryFunctions` // Actual implementation is generated by the CodeGeneratorPlugin // Add actual conformances to `ElementaryFunctions` to individual SIMD types -extension SIMD2: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } -extension SIMD4: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } -extension SIMD8: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } -extension SIMD16: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } -extension SIMD32: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } -extension SIMD64: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions { } +extension SIMD2: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD4: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD8: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD16: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD32: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} +extension SIMD64: @retroactive ElementaryFunctions where Scalar: FloatingPoint & ElementaryFunctions {} From 8fd0b5a5ec2ab9c720e3b8262326da48c2059c07 Mon Sep 17 00:00:00 2001 From: Jaap Wijnen Date: Thu, 8 May 2025 10:39:21 +0200 Subject: [PATCH 3/6] add some tests and fix incorrect derivative --- README.md | 8 + .../CodeGenerator.swift | 1 - .../RealFunctionsDerivativesGenerator.swift | 35 ++- .../RealModuleDifferentiableTests.swift | 225 +++++++++++++++++- 4 files changed, 264 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b06dbf9..36d6f69 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ # swift-numerics-differentiable +This package attempts to add more Differentiable capabilities to the existing [swift-numerics](https://github.com/apple/swift-numerics) package. Every target in swift-numerics has a Differentiable counterpart that `@_exported import`s the original module such that when you import `NumericsDifferentiable` you will also get all the contents of the `Numerics` module from swift-numerics. + +## RealModule Differentiable +- Registers derivatives to the `Float` and `Double` conformances to `ElementaryFunctions` and `RealFunctions` from swift-numerics. +- Conforms all `SIMD{n}` types to `ElementaryFunctions` and adds most of the protocol requirements from `RealFunctions` as well (`signGamma` is not implementable) +- Registers derivatives for all the provided `ElementaryFunctions` and `RealFunctions` implementations on SIMD{n} +- Tries to leverage Apple's `simd` framework to accelerate these operations where possible on Apple platforms. + ## Contributing ### Code Formatting This package makes use of [SwiftFormat](https://github.com/nicklockwood/SwiftFormat?tab=readme-ov-file#command-line-tool), which you can install diff --git a/Sources/CodeGeneratorExecutable/CodeGenerator.swift b/Sources/CodeGeneratorExecutable/CodeGenerator.swift index 49711c4..dc6000d 100644 --- a/Sources/CodeGeneratorExecutable/CodeGenerator.swift +++ b/Sources/CodeGeneratorExecutable/CodeGenerator.swift @@ -3,7 +3,6 @@ import Foundation @main struct CodeGenerator { static func main() throws { - // Use swift-argument-parser or just CommandLine, here we just imply that 2 paths are passed in: input and output guard CommandLine.arguments.count == 2 else { throw CodeGeneratorError.invalidArguments } diff --git a/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift index ad29810..8857a24 100644 --- a/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift +++ b/Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift @@ -8,27 +8,32 @@ enum RealFunctionsDerivativesGenerator { // MARK: ElementaryFunctions derivatives extension \(type) { @derivative(of: exp) + @_transparent public static func _vjpExp(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = exp(x) return (value: value, pullback: { v in v * value }) } @derivative(of: expMinusOne) + @_transparent public static func _vjpExpMinusOne(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { return (value: expMinusOne(x), pullback: { v in v * exp(x) }) } @derivative(of: cosh) + @_transparent public static func _vjpCosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: cosh(x), pullback: { v in sinh(x) }) } @derivative(of: sinh) + @_transparent public static func _vjpSinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: sinh(x), pullback: { v in cosh(x) }) } @derivative(of: tanh) + @_transparent public static func _vjpTanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { ( value: tanh(x), @@ -40,16 +45,19 @@ enum RealFunctionsDerivativesGenerator { } @derivative(of: cos) + @_transparent public static func _vjpCos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: cos(x), pullback: { v in -v * sin(x) }) } @derivative(of: sin) + @_transparent public static func _vjpSin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: sin(x), pullback: { v in v * cos(x) }) } @derivative(of: tan) + @_transparent public static func _vjpTan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { ( value: tan(x), @@ -61,47 +69,56 @@ enum RealFunctionsDerivativesGenerator { } @derivative(of: log(_:)) + @_transparent public static func _vjpLog(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log(x), pullback: { v in v / x }) } @derivative(of: acosh) + @_transparent public static func _vjpAcosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { // only valid for x > 1 return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) }) } @derivative(of: asinh) + @_transparent public static func _vjpAsinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: asinh(x), pullback: { v in v / sqrt(x * x + 1) }) } @derivative(of: atanh) + @_transparent public static func _vjpAtanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: atanh(x), pullback: { v in v / (1 - x * x) }) } @derivative(of: acos) + @_transparent public static func _vjpAcos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { - (value: acos(x), pullback: { v in -v / (1 - x * x) }) + (value: acos(x), pullback: { v in -v / .sqrt(1 - x * x) }) } @derivative(of: asin) + @_transparent public static func _vjpAsin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { - (value: asin(x), pullback: { v in v / (1 - x * x) }) + (value: asin(x), pullback: { v in v / .sqrt(1 - x * x) }) } @derivative(of: atan) + @_transparent public static func _vjpAtan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: atan(x), pullback: { v in v / (x * x + 1) }) } @derivative(of: log(onePlus:)) + @_transparent public static func _vjpLog(onePlus x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log(onePlus: x), pullback: { v in v / (1 + x) }) } @derivative(of: pow) + @_transparent public static func _vjpPow(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { let value = pow(x, y) // pullback wrt y is not defined for (x < 0) and (x = 0, y = 0) @@ -109,17 +126,20 @@ enum RealFunctionsDerivativesGenerator { } @derivative(of: pow) + @_transparent public static func _vjpPow(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: pow(x, n), pullback: { v in v * \(floatingPointType)(n) * pow(x, n - 1) }) } @derivative(of: sqrt) + @_transparent public static func _vjpSqrt(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = sqrt(x) return (value: value, pullback: { v in v / (2 * value) }) } @derivative(of: root) + @_transparent public static func _vjpRoot(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = root(x, n) return (value: value, pullback: { v in v * value / (x * \(floatingPointType)(n)) }) @@ -129,48 +149,57 @@ enum RealFunctionsDerivativesGenerator { // MARK: RealFunctions derivatives extension \(type) { @derivative(of: erf) + @_transparent public static func _vjpErf(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) } @derivative(of: erfc) + @_transparent public static func _vjpErfc(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) }) } @derivative(of: exp2) + @_transparent public static func _vjpExp2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = exp2(x) return (value, { v in v * value * .log(2) }) } @derivative(of: exp10) + @_transparent public static func _vjpExp10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { let value = exp10(x) return (value, { v in v * value * .log(10) }) } @derivative(of: gamma) + @_transparent public static func _vjpGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { fatalError("unimplemented") } @derivative(of: log2) + @_transparent public static func _vjpLog2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log2(x), pullback: { v in v / (.log(2) * x) }) } @derivative(of: log10) + @_transparent public static func _vjpLog10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { (value: log10(x), pullback: { v in v / (.log(10) * x) }) } @derivative(of: logGamma) + @_transparent public static func _vjpLogGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { fatalError("unimplemented") } @derivative(of: atan2) + @_transparent public static func _vjpAtan2(y: \(type), x: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { ( value: atan2(y: y, x: x), @@ -182,6 +211,7 @@ enum RealFunctionsDerivativesGenerator { } @derivative(of: hypot) + @_transparent public static func _vjpHypot(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) { ( value: hypot(x, y), @@ -196,6 +226,7 @@ enum RealFunctionsDerivativesGenerator { // MARK: FloatingPoint functions derivatives extension \(type) { @derivative(of: abs) + @_transparent public static func _vjpAbs(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) { \({ if type == floatingPointType { diff --git a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift index 1939214..5b26141 100644 --- a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift +++ b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift @@ -1,6 +1,227 @@ +import _Differentiation @testable import RealModuleDifferentiable import Testing -@Test func example() async throws { - // Write your test here and use APIs like `#expect(...)` to check expected conditions. +@Suite +struct TestRegisteredDerivatives { + // These tests are more about the derivatives being correctly registered than the correct values. Should probably use a form of + // `.isApproximatelyEqual(to:)` for the results in the future but that doesn't combine too well with a SIMD vector comparison. + @Test + func + testExp() + { + let vwpb = valueWithPullback(at: 2.0, of: Float.exp) + #expect(vwpb.value == 7.38905609893065) + #expect(vwpb.pullback(1) == 7.38905609893065) + } + + @Test + func testExpMinusOne() { + let vwpb = valueWithPullback(at: 2.0, of: Double.expMinusOne(_:)) + #expect(vwpb.value == 6.38905609893065) + #expect(vwpb.pullback(1) == 7.38905609893065) + } + + @Test + func testCosh() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 2.0), of: SIMD2.cosh) + #expect(vwpb.value == .init(repeating: 3.7621956)) + #expect(vwpb.pullback(.one) == .init(repeating: 3.6268604)) + } + + @Test + func testSinh() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 2.0), of: SIMD4.sinh) + #expect(vwpb.value == .init(repeating: 3.6268604)) + #expect(vwpb.pullback(.one) == .init(repeating: 3.7621956)) + } + + @Test + func testTanh() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 2.0), of: SIMD8.tanh) + #expect(vwpb.value == .init(repeating: 0.9640276)) + #expect(vwpb.pullback(.one) == .init(repeating: 0.07065083)) + } + + @Test + func testCos() { + let vwpb = valueWithPullback(at: SIMD16(repeating: .pi / 2), of: SIMD16.cos) + #expect(vwpb.value == .init(repeating: 7.54979E-08)) + #expect(vwpb.pullback(.one) == .init(repeating: -1)) + } + + @Test + func testSin() { + let vwpb = valueWithPullback(at: SIMD32(repeating: .pi / 2), of: SIMD32.sin) + #expect(vwpb.value == .init(repeating: 1)) + #expect(vwpb.pullback(.one) == .init(repeating: 7.54979E-08)) + } + + @Test + func testTan() { + let vwpb = valueWithPullback(at: SIMD64(repeating: .pi / 4), of: SIMD64.tan) + #expect(vwpb.value == .init(repeating: 0.99999994)) + #expect(vwpb.pullback(.one) == .init(repeating: 1.9999998)) + } + + @Test + func testLog() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 2), of: SIMD2.log(_:)) + #expect(vwpb.value == SIMD2(repeating: 0.6931471805599453)) + #expect(vwpb.pullback(SIMD2.one) == .init(repeating: 0.5)) + } + + @Test + func testLogOnePlus() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 3), of: SIMD4.log(onePlus:)) + #expect(vwpb.value == .init(repeating: 1.3862943611198906)) + #expect(vwpb.pullback(.one) == .init(repeating: 0.25)) + } + + @Test + func testAcosh() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 2), of: SIMD8.acosh) + #expect(vwpb.value == .init(repeating: 1.3169578969248166)) + #expect(vwpb.pullback(.one) == .init(repeating: 1 / .sqrt(3))) + } + + @Test + func testAsinh() { + let vwpb = valueWithPullback(at: SIMD16(repeating: 2), of: SIMD16.asinh) + #expect(vwpb.value == .init(repeating: 1.4436354751788103)) + #expect(vwpb.pullback(.one) == .init(repeating: 1 / .sqrt(5))) + } + + @Test + func testAtanh() { + let vwpb = valueWithPullback(at: SIMD32(repeating: 0.5), of: SIMD32.atanh) + #expect(vwpb.value == .init(repeating: 0.5493061443340549)) + #expect(vwpb.pullback(.one) == .init(repeating: 4 / 3)) + } + + @Test + func testaCos() { + let vwpb = valueWithPullback(at: SIMD64(repeating: 0.5), of: SIMD64.acos) + #expect(vwpb.value == .init(repeating: 1.0471975511965976)) + #expect(vwpb.pullback(.one) == .init(repeating: -1.1547005383792517)) + } + + @Test + func testaSin() { + let vwpb = valueWithPullback(at: 0.5, of: Float.asin) + #expect(vwpb.value == 0.5235988) + #expect(vwpb.pullback(1) == 1.1547005383792517) + } + + @Test + func testaTan() { + let vwpb = valueWithPullback(at: 0.5, of: Double.atan) + #expect(vwpb.value == 0.46364760900080615) + #expect(vwpb.pullback(1) == 0.8) + } + + @Test + func testPow() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 0.5), SIMD2(repeating: 2), of: SIMD2.pow(_:_:)) + #expect(vwpb.value == .init(repeating: 0.25)) + #expect(vwpb.pullback(.one) == (.init(repeating: 1.0), .init(repeating: -0.1732868))) + } + + @Test + func testPowInt() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 0.5), of: { x in SIMD4.pow(x, 2) }) + #expect(vwpb.value == .init(repeating: 0.25)) + #expect(vwpb.pullback(.one) == .init(repeating: 1.0)) + } + + @Test + func testSqrt() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 4), of: SIMD8.sqrt) + #expect(vwpb.value == .init(repeating: 2)) + #expect(vwpb.pullback(.one) == .init(repeating: 0.25)) + } + + @Test + func testRoot() { + let vwpb = valueWithPullback(at: SIMD16(repeating: 16), of: { x in SIMD16.root(x, 4) }) + #expect(vwpb.value == .init(repeating: 2)) + #expect(vwpb.pullback(.one) == .init(repeating: 1 / 32)) + } + + @Test + func testAtan2() { + let vwpb = valueWithPullback(at: SIMD32(repeating: 1), SIMD32(repeating: 0), of: SIMD32.atan2) + #expect(vwpb.value == .init(repeating: 1.5707964)) // .pi / 2 + #expect(vwpb.pullback(.one) == (.init(repeating: 0), .init(repeating: -1))) + } + + @Test + func testErf() { + let vwpb = valueWithPullback(at: SIMD64(repeating: 0.5), of: SIMD64.erf) + #expect(vwpb.value == .init(repeating: 0.5204999)) + #expect(vwpb.pullback(.one) == .init(repeating: 0.87878263)) + } + + @Test + func testErfc() { + let vwpb = valueWithPullback(at: SIMD2(repeating: 0.5), of: SIMD2.erfc) + #expect(vwpb.value == .init(repeating: 0.4795001221869535)) + #expect(vwpb.pullback(.one) == .init(repeating: -0.8787825789354449)) + } + + @Test + func testExp2() { + let vwpb = valueWithPullback(at: SIMD4(repeating: 2), of: SIMD4.exp2) + #expect(vwpb.value == .init(repeating: 4)) + #expect(vwpb.pullback(.one) == .init(repeating: 4 * .log(2))) + } + + @Test + func testExp10() { + let vwpb = valueWithPullback(at: SIMD8(repeating: 2), of: SIMD8.exp10) + #expect(vwpb.value == .init(repeating: 100)) + #expect(vwpb.pullback(.one) == .init(repeating: 100 * .log(10))) + } + + @Test + func testHypot() { + let vwpb = valueWithPullback(at: SIMD16(repeating: 3), SIMD16(repeating: 4), of: SIMD16.hypot) + #expect(vwpb.value == .init(repeating: 5)) + #expect(vwpb.pullback(.one) == (.init(repeating: 3 / 5), .init(repeating: 4 / 5))) + } + + @Test(.disabled("derivative not implemented")) + func testGamma() { + let vwpb = valueWithPullback(at: SIMD32(repeating: 2), of: SIMD32.gamma) + #expect(vwpb.value == .init(repeating: 1)) + #expect(vwpb.pullback(.one) == .init(repeating: 0)) + } + + @Test + func testLog2() { + let vwpb = valueWithPullback(at: SIMD64(repeating: 2), of: SIMD64.log2) + #expect(vwpb.value == .init(repeating: 1)) + #expect(vwpb.pullback(.one) == .init(repeating: 1 / .log(4))) + } + + @Test + func testLog10() { + let vwpb = valueWithPullback(at: 2.0, of: Float.log10) + #expect(vwpb.value == 0.30103) // .log(2) / .log(10) + #expect(vwpb.pullback(1) == 1 / .log(100)) + } + + @Test(.disabled("derivative not implemented")) + func testLogGamma() { + let vwpb = valueWithPullback(at: 2, of: Double.logGamma) + #expect(vwpb.value == 0) + #expect(vwpb.pullback(1) == 0) + } + + @Test + func testAbs() { + let vwpb = valueWithPullback(at: SIMD2(repeating: -2), of: SIMD2.abs) + #expect(vwpb.value == .init(repeating: 2)) + #expect(vwpb.pullback(.one) == .init(repeating: -1)) + } } From bb000dcf116d8b1e4f26ff1120d621ae625313ff Mon Sep 17 00:00:00 2001 From: Jaap Wijnen Date: Thu, 8 May 2025 10:45:52 +0200 Subject: [PATCH 4/6] process feedback --- Plugins/CodeGeneratorPlugin.swift | 8 ++++---- .../CodeGenerator.swift | 19 +++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/Plugins/CodeGeneratorPlugin.swift b/Plugins/CodeGeneratorPlugin.swift index 886d887..52e007d 100644 --- a/Plugins/CodeGeneratorPlugin.swift +++ b/Plugins/CodeGeneratorPlugin.swift @@ -7,13 +7,13 @@ struct CodeGeneratorPlugin: BuildToolPlugin { let output = context.pluginWorkDirectoryURL let floatingPointTypes: [String] = ["Float", "Double"] - let simdSizes = [2, 4, 8, 16, 32, 64] + let simdWidths = [2, 4, 8, 16, 32, 64] let outputFiles = floatingPointTypes.flatMap { floatingPointType in - simdSizes.flatMap { simdSize in + simdWidths.flatMap { simdWidth in [ - output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift"), - output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift"), + output.appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions.swift"), + output.appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions+Derivatives.swift"), ] } + [ output.appending(component: "\(floatingPointType)+RealFunctions+Derivatives.swift"), diff --git a/Sources/CodeGeneratorExecutable/CodeGenerator.swift b/Sources/CodeGeneratorExecutable/CodeGenerator.swift index dc6000d..60b88a6 100644 --- a/Sources/CodeGeneratorExecutable/CodeGenerator.swift +++ b/Sources/CodeGeneratorExecutable/CodeGenerator.swift @@ -20,7 +20,7 @@ struct CodeGenerator { try realFunctionsSIMDExtension.write(to: realFunctionSIMDFileURL, atomically: true, encoding: .utf8) let floatingPointTypes: [String] = ["Float", "Double"] - let simdSizes: [Int] = [2, 4, 8, 16, 32, 64] + let simdWidths: [Int] = [2, 4, 8, 16, 32, 64] for floatingPointType in floatingPointTypes { // Generator Derivatives for RealFunctions for floating point types @@ -28,23 +28,22 @@ struct CodeGenerator { component: "\(floatingPointType)+RealFunctions+Derivatives.swift", directoryHint: .notDirectory ) - let type = floatingPointType let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension( - type: type, + type: floatingPointType, floatingPointType: floatingPointType ) try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8) - for simdSize in simdSizes { + for simdWidth in simdWidths { let realFunctionFileURL = output.appending( - component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift", + component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions.swift", directoryHint: .notDirectory ) - let simdType = "SIMD\(simdSize)<\(floatingPointType)>" + let simdType = "SIMD\(simdWidth)<\(floatingPointType)>" // no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those. - var simdAccelerated: Bool - if simdSize > 16 || (simdSize == 16 && floatingPointType == "Double") { + let simdAccelerated: Bool + if simdWidth > 16 || (simdWidth == 16 && floatingPointType == "Double") { simdAccelerated = false } else { @@ -62,8 +61,8 @@ struct CodeGenerator { // Generate RealFunctions derivatives for concrete SIMD types let realFunctionDerivativesFileURL = output - .appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift") - let type = "SIMD\(simdSize)<\(floatingPointType)>" + .appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions+Derivatives.swift") + let type = "SIMD\(simdWidth)<\(floatingPointType)>" let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension( type: type, floatingPointType: floatingPointType From 8caa4071eb216eecf2cd4385cf04e9f0f0081e67 Mon Sep 17 00:00:00 2001 From: Jaap Wijnen Date: Thu, 8 May 2025 13:44:39 +0200 Subject: [PATCH 5/6] use isApproximatelyEqual to compare --- .../RealModuleDifferentiableTests.swift | 143 +++++++++--------- 1 file changed, 74 insertions(+), 69 deletions(-) diff --git a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift index 5b26141..b27abb8 100644 --- a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift +++ b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift @@ -4,224 +4,229 @@ import Testing @Suite struct TestRegisteredDerivatives { - // These tests are more about the derivatives being correctly registered than the correct values. Should probably use a form of - // `.isApproximatelyEqual(to:)` for the results in the future but that doesn't combine too well with a SIMD vector comparison. + // These tests are more about the derivatives being correctly registered than the correct values. We're only checking the first value for simds since we run the computation on the same values and only checking the first result makes using `.isApproximatelyEqual(to:)` a lot easier. @Test func testExp() { let vwpb = valueWithPullback(at: 2.0, of: Float.exp) - #expect(vwpb.value == 7.38905609893065) - #expect(vwpb.pullback(1) == 7.38905609893065) + #expect(vwpb.value.isApproximatelyEqual(to: 7.38905609893065)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 7.38905609893065)) } @Test func testExpMinusOne() { let vwpb = valueWithPullback(at: 2.0, of: Double.expMinusOne(_:)) - #expect(vwpb.value == 6.38905609893065) - #expect(vwpb.pullback(1) == 7.38905609893065) + #expect(vwpb.value.isApproximatelyEqual(to: 6.38905609893065)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 7.38905609893065)) } @Test func testCosh() { let vwpb = valueWithPullback(at: SIMD2(repeating: 2.0), of: SIMD2.cosh) - #expect(vwpb.value == .init(repeating: 3.7621956)) - #expect(vwpb.pullback(.one) == .init(repeating: 3.6268604)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 3.7621956)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 3.6268604)) } @Test func testSinh() { let vwpb = valueWithPullback(at: SIMD4(repeating: 2.0), of: SIMD4.sinh) - #expect(vwpb.value == .init(repeating: 3.6268604)) - #expect(vwpb.pullback(.one) == .init(repeating: 3.7621956)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 3.6268604)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 3.7621956)) } @Test func testTanh() { let vwpb = valueWithPullback(at: SIMD8(repeating: 2.0), of: SIMD8.tanh) - #expect(vwpb.value == .init(repeating: 0.9640276)) - #expect(vwpb.pullback(.one) == .init(repeating: 0.07065083)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.9640276)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.07065083)) } @Test func testCos() { let vwpb = valueWithPullback(at: SIMD16(repeating: .pi / 2), of: SIMD16.cos) - #expect(vwpb.value == .init(repeating: 7.54979E-08)) - #expect(vwpb.pullback(.one) == .init(repeating: -1)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -1)) } @Test func testSin() { let vwpb = valueWithPullback(at: SIMD32(repeating: .pi / 2), of: SIMD32.sin) - #expect(vwpb.value == .init(repeating: 1)) - #expect(vwpb.pullback(.one) == .init(repeating: 7.54979E-08)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) } @Test func testTan() { let vwpb = valueWithPullback(at: SIMD64(repeating: .pi / 4), of: SIMD64.tan) - #expect(vwpb.value == .init(repeating: 0.99999994)) - #expect(vwpb.pullback(.one) == .init(repeating: 1.9999998)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 2)) } @Test func testLog() { let vwpb = valueWithPullback(at: SIMD2(repeating: 2), of: SIMD2.log(_:)) - #expect(vwpb.value == SIMD2(repeating: 0.6931471805599453)) - #expect(vwpb.pullback(SIMD2.one) == .init(repeating: 0.5)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.6931471805599453)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.5)) } @Test func testLogOnePlus() { let vwpb = valueWithPullback(at: SIMD4(repeating: 3), of: SIMD4.log(onePlus:)) - #expect(vwpb.value == .init(repeating: 1.3862943611198906)) - #expect(vwpb.pullback(.one) == .init(repeating: 0.25)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1.3862943611198906)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.25)) } @Test func testAcosh() { let vwpb = valueWithPullback(at: SIMD8(repeating: 2), of: SIMD8.acosh) - #expect(vwpb.value == .init(repeating: 1.3169578969248166)) - #expect(vwpb.pullback(.one) == .init(repeating: 1 / .sqrt(3))) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1.3169578969248166)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .sqrt(3))) } @Test func testAsinh() { let vwpb = valueWithPullback(at: SIMD16(repeating: 2), of: SIMD16.asinh) - #expect(vwpb.value == .init(repeating: 1.4436354751788103)) - #expect(vwpb.pullback(.one) == .init(repeating: 1 / .sqrt(5))) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1.4436354751788103)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .sqrt(5))) } @Test func testAtanh() { let vwpb = valueWithPullback(at: SIMD32(repeating: 0.5), of: SIMD32.atanh) - #expect(vwpb.value == .init(repeating: 0.5493061443340549)) - #expect(vwpb.pullback(.one) == .init(repeating: 4 / 3)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.5493061443340549)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 4 / 3)) } @Test - func testaCos() { - let vwpb = valueWithPullback(at: SIMD64(repeating: 0.5), of: SIMD64.acos) - #expect(vwpb.value == .init(repeating: 1.0471975511965976)) - #expect(vwpb.pullback(.one) == .init(repeating: -1.1547005383792517)) + func testAcos() { + let vwpb = valueWithPullback(at: SIMD64(repeating: 1 / .sqrt(2)), of: SIMD64.acos) + #expect(vwpb.value[0].isApproximatelyEqual(to: .pi / 4)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -.sqrt(2))) } @Test - func testaSin() { - let vwpb = valueWithPullback(at: 0.5, of: Float.asin) - #expect(vwpb.value == 0.5235988) - #expect(vwpb.pullback(1) == 1.1547005383792517) + func testAsin() { + let vwpb = valueWithPullback(at: 1 / .sqrt(2), of: Float.asin) + #expect(vwpb.value.isApproximatelyEqual(to: .pi / 4)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: .sqrt(2))) } @Test - func testaTan() { + func testAtan() { let vwpb = valueWithPullback(at: 0.5, of: Double.atan) - #expect(vwpb.value == 0.46364760900080615) - #expect(vwpb.pullback(1) == 0.8) + #expect(vwpb.value.isApproximatelyEqual(to: 0.46364760900080615)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 0.8)) } @Test func testPow() { let vwpb = valueWithPullback(at: SIMD2(repeating: 0.5), SIMD2(repeating: 2), of: SIMD2.pow(_:_:)) - #expect(vwpb.value == .init(repeating: 0.25)) - #expect(vwpb.pullback(.one) == (.init(repeating: 1.0), .init(repeating: -0.1732868))) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.25)) + let gradient = vwpb.pullback(.one) + #expect(gradient.0[0].isApproximatelyEqual(to: 1.0)) + #expect(gradient.1[0].isApproximatelyEqual(to: -0.1732868)) } @Test func testPowInt() { let vwpb = valueWithPullback(at: SIMD4(repeating: 0.5), of: { x in SIMD4.pow(x, 2) }) - #expect(vwpb.value == .init(repeating: 0.25)) - #expect(vwpb.pullback(.one) == .init(repeating: 1.0)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.25)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1)) } @Test func testSqrt() { let vwpb = valueWithPullback(at: SIMD8(repeating: 4), of: SIMD8.sqrt) - #expect(vwpb.value == .init(repeating: 2)) - #expect(vwpb.pullback(.one) == .init(repeating: 0.25)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 2)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.25)) } @Test func testRoot() { let vwpb = valueWithPullback(at: SIMD16(repeating: 16), of: { x in SIMD16.root(x, 4) }) - #expect(vwpb.value == .init(repeating: 2)) - #expect(vwpb.pullback(.one) == .init(repeating: 1 / 32)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 2)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / 32)) } @Test func testAtan2() { let vwpb = valueWithPullback(at: SIMD32(repeating: 1), SIMD32(repeating: 0), of: SIMD32.atan2) - #expect(vwpb.value == .init(repeating: 1.5707964)) // .pi / 2 - #expect(vwpb.pullback(.one) == (.init(repeating: 0), .init(repeating: -1))) + #expect(vwpb.value[0].isApproximatelyEqual(to: .pi / 2)) + let gradient = vwpb.pullback(.one) + #expect(gradient.0[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) + #expect(gradient.1[0].isApproximatelyEqual(to: -1)) } @Test func testErf() { let vwpb = valueWithPullback(at: SIMD64(repeating: 0.5), of: SIMD64.erf) - #expect(vwpb.value == .init(repeating: 0.5204999)) - #expect(vwpb.pullback(.one) == .init(repeating: 0.87878263)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 0.5204998778)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0.8787825789354449)) } @Test func testErfc() { let vwpb = valueWithPullback(at: SIMD2(repeating: 0.5), of: SIMD2.erfc) - #expect(vwpb.value == .init(repeating: 0.4795001221869535)) - #expect(vwpb.pullback(.one) == .init(repeating: -0.8787825789354449)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1 - 0.5204998778)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -0.8787825789354449)) } @Test func testExp2() { let vwpb = valueWithPullback(at: SIMD4(repeating: 2), of: SIMD4.exp2) - #expect(vwpb.value == .init(repeating: 4)) - #expect(vwpb.pullback(.one) == .init(repeating: 4 * .log(2))) + #expect(vwpb.value[0].isApproximatelyEqual(to: 4)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 4 * .log(2))) } @Test func testExp10() { let vwpb = valueWithPullback(at: SIMD8(repeating: 2), of: SIMD8.exp10) - #expect(vwpb.value == .init(repeating: 100)) - #expect(vwpb.pullback(.one) == .init(repeating: 100 * .log(10))) + #expect(vwpb.value[0].isApproximatelyEqual(to: 100)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 100 * .log(10))) } @Test func testHypot() { let vwpb = valueWithPullback(at: SIMD16(repeating: 3), SIMD16(repeating: 4), of: SIMD16.hypot) - #expect(vwpb.value == .init(repeating: 5)) - #expect(vwpb.pullback(.one) == (.init(repeating: 3 / 5), .init(repeating: 4 / 5))) + #expect(vwpb.value[0].isApproximatelyEqual(to: 5)) + let gradient = vwpb.pullback(.one) + #expect(gradient.0[0].isApproximatelyEqual(to: 3 / 5)) + #expect(gradient.1[0].isApproximatelyEqual(to: 4 / 5)) } @Test(.disabled("derivative not implemented")) func testGamma() { let vwpb = valueWithPullback(at: SIMD32(repeating: 2), of: SIMD32.gamma) - #expect(vwpb.value == .init(repeating: 1)) - #expect(vwpb.pullback(.one) == .init(repeating: 0)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) } @Test func testLog2() { let vwpb = valueWithPullback(at: SIMD64(repeating: 2), of: SIMD64.log2) - #expect(vwpb.value == .init(repeating: 1)) - #expect(vwpb.pullback(.one) == .init(repeating: 1 / .log(4))) + #expect(vwpb.value[0].isApproximatelyEqual(to: 1)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: 1 / .log(4))) } @Test func testLog10() { let vwpb = valueWithPullback(at: 2.0, of: Float.log10) - #expect(vwpb.value == 0.30103) // .log(2) / .log(10) - #expect(vwpb.pullback(1) == 1 / .log(100)) + #expect(vwpb.value.isApproximatelyEqual(to: .log(2) / .log(10))) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 1 / .log(100))) } @Test(.disabled("derivative not implemented")) func testLogGamma() { let vwpb = valueWithPullback(at: 2, of: Double.logGamma) - #expect(vwpb.value == 0) - #expect(vwpb.pullback(1) == 0) + #expect(vwpb.value.isApproximatelyEqual(to: 0)) + #expect(vwpb.pullback(1).isApproximatelyEqual(to: 0, absoluteTolerance: .ulpOfOne)) } @Test func testAbs() { let vwpb = valueWithPullback(at: SIMD2(repeating: -2), of: SIMD2.abs) - #expect(vwpb.value == .init(repeating: 2)) - #expect(vwpb.pullback(.one) == .init(repeating: -1)) + #expect(vwpb.value[0].isApproximatelyEqual(to: 2)) + #expect(vwpb.pullback(.one)[0].isApproximatelyEqual(to: -1)) } } From d105d4fd2e03bbd7e2c413d025268940e326dba6 Mon Sep 17 00:00:00 2001 From: Jaap Wijnen Date: Thu, 8 May 2025 13:50:23 +0200 Subject: [PATCH 6/6] add feedback --- Sources/CodeGeneratorExecutable/CodeGenerator.swift | 8 +------- .../RealModuleDifferentiableTests.swift | 4 +++- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/Sources/CodeGeneratorExecutable/CodeGenerator.swift b/Sources/CodeGeneratorExecutable/CodeGenerator.swift index 60b88a6..ca56c60 100644 --- a/Sources/CodeGeneratorExecutable/CodeGenerator.swift +++ b/Sources/CodeGeneratorExecutable/CodeGenerator.swift @@ -42,13 +42,7 @@ struct CodeGenerator { let simdType = "SIMD\(simdWidth)<\(floatingPointType)>" // no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those. - let simdAccelerated: Bool - if simdWidth > 16 || (simdWidth == 16 && floatingPointType == "Double") { - simdAccelerated = false - } - else { - simdAccelerated = true - } + let simdAccelerated = simdWidth < 16 || (simdWidth == 16 && floatingPointType == "Float") // Generate RealFunctions implementations on concrete SIMD types to attach derivatives to let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension( diff --git a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift index b27abb8..e49837e 100644 --- a/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift +++ b/Tests/RealModuleDifferentiableTests/RealModuleDifferentiableTests.swift @@ -4,7 +4,9 @@ import Testing @Suite struct TestRegisteredDerivatives { - // These tests are more about the derivatives being correctly registered than the correct values. We're only checking the first value for simds since we run the computation on the same values and only checking the first result makes using `.isApproximatelyEqual(to:)` a lot easier. + // These tests are more about the derivatives being correctly registered than the correct values. We're only checking the first value + // for simds since we run the computation on the same values and only checking the first result makes using `.isApproximatelyEqual(to:)` + // a lot easier. @Test func testExp()