Skip to content

Commit 4662bca

Browse files
committed
Add derivatives for RealFunctions
Also adds RealFunctions conformance to SIMD types and their derivatives
1 parent b8c1f4b commit 4662bca

File tree

7 files changed

+558
-0
lines changed

7 files changed

+558
-0
lines changed

Package.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import PackageDescription
44

55
let package = Package(
66
name: "swift-numerics-differentiable",
7+
platforms: [
8+
.macOS(.v13),
9+
],
710
products: [
811
.library(
912
name: "NumericsDifferentiable",
@@ -22,6 +25,12 @@ let package = Package(
2225
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.2"),
2326
],
2427
targets: [
28+
.executableTarget(name: "CodeGeneratorExecutable"),
29+
.plugin(
30+
name: "CodeGeneratorPlugin",
31+
capability: .buildTool,
32+
dependencies: ["CodeGeneratorExecutable"]
33+
),
2534
.target(
2635
name: "NumericsDifferentiable",
2736
dependencies: [
@@ -34,6 +43,9 @@ let package = Package(
3443
name: "RealModuleDifferentiable",
3544
dependencies: [
3645
.product(name: "RealModule", package: "swift-numerics"),
46+
],
47+
plugins: [
48+
"CodeGeneratorPlugin",
3749
]
3850
),
3951
.target(

Plugins/CodeGeneratorPlugin.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import Foundation
2+
import PackagePlugin
3+
4+
@main
5+
struct CodeGeneratorPlugin: BuildToolPlugin {
6+
func createBuildCommands(context: PackagePlugin.PluginContext, target _: PackagePlugin.Target) async throws -> [PackagePlugin.Command] {
7+
let output = context.pluginWorkDirectoryURL
8+
9+
let floatingPointTypes: [String] = ["Float", "Double"]
10+
let simdSizes = [2, 4, 8, 16, 32, 64]
11+
12+
let outputFiles = floatingPointTypes.flatMap { floatingPointType in
13+
simdSizes.flatMap { simdSize in
14+
[
15+
output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift"),
16+
output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift"),
17+
]
18+
} + [
19+
output.appending(component: "\(floatingPointType)+RealFunctions+Derivatives.swift"),
20+
]
21+
} + [
22+
output.appending(component: "SIMD+RealFunctions.swift"),
23+
]
24+
25+
return [
26+
.buildCommand(
27+
displayName: "Generate Code",
28+
executable: try context.tool(named: "CodeGeneratorExecutable").url,
29+
arguments: [output.relativePath],
30+
environment: [:],
31+
inputFiles: [],
32+
outputFiles: outputFiles
33+
)
34+
]
35+
}
36+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import Foundation
2+
3+
@main
4+
struct CodeGenerator {
5+
static func main() throws {
6+
// Use swift-argument-parser or just CommandLine, here we just imply that 2 paths are passed in: input and output
7+
guard CommandLine.arguments.count == 2 else {
8+
throw CodeGeneratorError.invalidArguments
9+
}
10+
// arguments[0] is the path to this command line tool
11+
let output = URL(filePath: CommandLine.arguments[1])
12+
13+
// generate default implementations of RealFunctions for SIMD protocol
14+
let realFunctionSIMDFileURL = output.appending(component: "SIMD+RealFunctions.swift")
15+
let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension(objectType: "SIMD", type: "Self", whereClause: true, simdAccelerated: false)
16+
try realFunctionsSIMDExtension.write(to: realFunctionSIMDFileURL, atomically: true, encoding: .utf8)
17+
18+
let floatingPointTypes: [String] = ["Float", "Double"]
19+
let simdSizes: [Int] = [2, 4, 8, 16, 32, 64]
20+
21+
for floatingPointType in floatingPointTypes {
22+
// Generator Derivatives for RealFunctions for floating point types
23+
let realFunctionDerivativesFileURL = output.appending(
24+
component: "\(floatingPointType)+RealFunctions+Derivatives.swift",
25+
directoryHint: .notDirectory
26+
)
27+
let type = floatingPointType
28+
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType)
29+
try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8)
30+
31+
for simdSize in simdSizes {
32+
let realFunctionFileURL = output.appending(
33+
component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift",
34+
directoryHint: .notDirectory
35+
)
36+
let simdType = "SIMD\(simdSize)<\(floatingPointType)>"
37+
38+
// no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those.
39+
var simdAccelerated: Bool
40+
if simdSize > 16 || (simdSize == 16 && floatingPointType == "Double") {
41+
simdAccelerated = false
42+
} else {
43+
simdAccelerated = true
44+
}
45+
46+
// Generate RealFunctions implementations on concrete SIMD types to attach derivatives to
47+
let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension(objectType: simdType, type: simdType, whereClause: false, simdAccelerated: simdAccelerated)
48+
try realFunctionsExtensionCode.write(to: realFunctionFileURL, atomically: true, encoding: .utf8)
49+
50+
// Generate RealFunctions derivatives for concrete SIMD types
51+
let realFunctionDerivativesFileURL = output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift")
52+
let type = "SIMD\(simdSize)<\(floatingPointType)>"
53+
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType)
54+
try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8)
55+
}
56+
}
57+
}
58+
}
59+
60+
struct RealFunction {
61+
var name: String
62+
var simdName: String?
63+
var arguments: [Argument]
64+
65+
struct Argument {
66+
var name: String
67+
var label: String?
68+
var type: String? = nil
69+
}
70+
71+
init(name: String, simdName: String? = nil, arguments: [Argument] = [.init(name: "x", label: "_")]) {
72+
self.name = name
73+
self.simdName = simdName
74+
self.arguments = arguments
75+
}
76+
}
77+
78+
enum CodeGeneratorError: Error {
79+
case invalidArguments
80+
case invalidData
81+
}
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
struct RealFunctionsDerivativesGenerator {
2+
static func realFunctionsDerivativesExtension(type: String, floatingPointType: String) -> String {
3+
"""
4+
#if canImport(_Differentiation)
5+
import _Differentiation
6+
import RealModule
7+
8+
// MARK: ElementaryFunctions derivatives
9+
extension \(type) {
10+
@derivative(of: exp)
11+
public static func _vjpExp(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
12+
let value = exp(x)
13+
return (value: value, pullback: { v in v * value })
14+
}
15+
16+
@derivative(of: expMinusOne)
17+
public static func _vjpExpMinusOne(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
18+
return (value: expMinusOne(x), pullback: { v in v * exp(x) })
19+
}
20+
21+
@derivative(of: cosh)
22+
public static func _vjpCosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
23+
(value: cosh(x), pullback: { v in sinh(x) })
24+
}
25+
26+
@derivative(of: sinh)
27+
public static func _vjpSinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
28+
(value: sinh(x), pullback: { v in cosh(x) })
29+
}
30+
31+
@derivative(of: tanh)
32+
public static func _vjpTanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
33+
(
34+
value: tanh(x),
35+
pullback: { v in
36+
let coshx = cosh(x)
37+
return v / (coshx * coshx)
38+
}
39+
)
40+
}
41+
42+
@derivative(of: cos)
43+
public static func _vjpCos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
44+
(value: cos(x), pullback: { v in -v * sin(x) })
45+
}
46+
47+
@derivative(of: sin)
48+
public static func _vjpSin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
49+
(value: sin(x), pullback: { v in v * cos(x) })
50+
}
51+
52+
@derivative(of: tan)
53+
public static func _vjpTan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
54+
(
55+
value: tan(x),
56+
pullback: { v in
57+
let cosx = cos(x)
58+
return v / (cosx * cosx)
59+
}
60+
)
61+
}
62+
63+
@derivative(of: log(_:))
64+
public static func _vjpLog(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
65+
(value: log(x), pullback: { v in v / x })
66+
}
67+
68+
@derivative(of: acosh)
69+
public static func _vjpAcosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
70+
// only valid for x > 1
71+
return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) })
72+
}
73+
74+
@derivative(of: asinh)
75+
public static func _vjpAsinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
76+
(value: asinh(x), pullback: { v in v / sqrt(x * x + 1) })
77+
}
78+
79+
@derivative(of: atanh)
80+
public static func _vjpAtanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
81+
(value: atanh(x), pullback: { v in v / (1 - x * x) })
82+
}
83+
84+
@derivative(of: acos)
85+
public static func _vjpAcos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
86+
(value: acos(x), pullback: { v in -v / (1 - x * x) })
87+
}
88+
89+
@derivative(of: asin)
90+
public static func _vjpAsin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
91+
(value: asin(x), pullback: { v in v / (1 - x * x) })
92+
}
93+
94+
@derivative(of: atan)
95+
public static func _vjpAtan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
96+
(value: atan(x), pullback: { v in v / (x * x + 1) })
97+
}
98+
99+
@derivative(of: log(onePlus:))
100+
public static func _vjpLog(onePlus x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
101+
(value: log(onePlus: x), pullback: { v in v / (1 + x) })
102+
}
103+
104+
@derivative(of: pow)
105+
public static func _vjpPow(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
106+
let value = pow(x, y)
107+
// pullback wrt y is not defined for (x < 0) and (x = 0, y = 0)
108+
return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) })
109+
}
110+
111+
@derivative(of: pow)
112+
public static func _vjpPow(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) {
113+
(value: pow(x, n), pullback: { v in v * \(floatingPointType)(n) * pow(x, n - 1) })
114+
}
115+
116+
@derivative(of: sqrt)
117+
public static func _vjpSqrt(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
118+
let value = sqrt(x)
119+
return (value: value, pullback: { v in v / (2 * value) })
120+
}
121+
122+
@derivative(of: root)
123+
public static func _vjpRoot(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) {
124+
let value = root(x, n)
125+
return (value: value, pullback: { v in v * value / (x * \(floatingPointType)(n)) })
126+
}
127+
}
128+
129+
// MARK: RealFunctions derivatives
130+
extension \(type) {
131+
@derivative(of: erf)
132+
public static func _vjpErf(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
133+
(value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) })
134+
}
135+
136+
@derivative(of: erfc)
137+
public static func _vjpErfc(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
138+
(value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) })
139+
}
140+
141+
@derivative(of: exp2)
142+
public static func _vjpExp2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
143+
let value = exp2(x)
144+
return (value, { v in v * value * .log(2) })
145+
}
146+
147+
@derivative(of: exp10)
148+
public static func _vjpExp10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
149+
let value = exp10(x)
150+
return (value, { v in v * value * .log(10) })
151+
}
152+
153+
@derivative(of: gamma)
154+
public static func _vjpGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
155+
fatalError("unimplemented")
156+
}
157+
158+
@derivative(of: log2)
159+
public static func _vjpLog2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
160+
(value: log2(x), pullback: { v in v / (.log(2) * x) })
161+
}
162+
163+
@derivative(of: log10)
164+
public static func _vjpLog10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
165+
(value: log10(x), pullback: { v in v / (.log(10) * x) })
166+
}
167+
168+
@derivative(of: logGamma)
169+
public static func _vjpLogGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
170+
fatalError("unimplemented")
171+
}
172+
173+
@derivative(of: atan2)
174+
public static func _vjpAtan2(y: \(type), x: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
175+
(
176+
value: atan2(y: y, x: x),
177+
pullback: { v in
178+
let c = x * x + y * y
179+
return (v * x / c, -v * y / c)
180+
}
181+
)
182+
}
183+
184+
@derivative(of: hypot)
185+
public static func _vjpHypot(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
186+
(
187+
value: hypot(x, y),
188+
pullback: { v in
189+
let c = sqrt(x * x + y * y)
190+
return (v * x / c, v * y / c)
191+
}
192+
)
193+
}
194+
}
195+
196+
// MARK: FloatingPoint functions derivatives
197+
extension \(type) {
198+
@derivative(of: abs)
199+
public static func _vjpAbs(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
200+
\({
201+
if type == floatingPointType {
202+
"x < 0 ? (value: -x, pullback: { v in .zero - v }) : (value: x, pullback: { v in v })"
203+
} else {
204+
"(value: abs(x), pullback: { v in v.replacing(with: -v, where: x .< .zero) })"
205+
}
206+
}())
207+
}
208+
}
209+
#endif
210+
"""
211+
}
212+
}

0 commit comments

Comments
 (0)