Skip to content

Commit 448b9ce

Browse files
committed
formatting
1 parent 4662bca commit 448b9ce

File tree

5 files changed

+110
-84
lines changed

5 files changed

+110
-84
lines changed

Plugins/CodeGeneratorPlugin.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ struct CodeGeneratorPlugin: BuildToolPlugin {
88

99
let floatingPointTypes: [String] = ["Float", "Double"]
1010
let simdSizes = [2, 4, 8, 16, 32, 64]
11-
11+
1212
let outputFiles = floatingPointTypes.flatMap { floatingPointType in
1313
simdSizes.flatMap { simdSize in
1414
[
@@ -21,7 +21,7 @@ struct CodeGeneratorPlugin: BuildToolPlugin {
2121
} + [
2222
output.appending(component: "SIMD+RealFunctions.swift"),
2323
]
24-
24+
2525
return [
2626
.buildCommand(
2727
displayName: "Generate Code",
@@ -30,7 +30,7 @@ struct CodeGeneratorPlugin: BuildToolPlugin {
3030
environment: [:],
3131
inputFiles: [],
3232
outputFiles: outputFiles
33-
)
33+
),
3434
]
3535
}
3636
}

Sources/CodeGeneratorExecutable/CodeGenerator.swift

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@ struct CodeGenerator {
1212

1313
// generate default implementations of RealFunctions for SIMD protocol
1414
let realFunctionSIMDFileURL = output.appending(component: "SIMD+RealFunctions.swift")
15-
let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension(objectType: "SIMD", type: "Self", whereClause: true, simdAccelerated: false)
15+
let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension(
16+
objectType: "SIMD",
17+
type: "Self",
18+
whereClause: true,
19+
simdAccelerated: false
20+
)
1621
try realFunctionsSIMDExtension.write(to: realFunctionSIMDFileURL, atomically: true, encoding: .utf8)
17-
22+
1823
let floatingPointTypes: [String] = ["Float", "Double"]
1924
let simdSizes: [Int] = [2, 4, 8, 16, 32, 64]
2025

@@ -25,32 +30,45 @@ struct CodeGenerator {
2530
directoryHint: .notDirectory
2631
)
2732
let type = floatingPointType
28-
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType)
33+
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(
34+
type: type,
35+
floatingPointType: floatingPointType
36+
)
2937
try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8)
30-
38+
3139
for simdSize in simdSizes {
3240
let realFunctionFileURL = output.appending(
3341
component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions.swift",
3442
directoryHint: .notDirectory
3543
)
3644
let simdType = "SIMD\(simdSize)<\(floatingPointType)>"
37-
45+
3846
// no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those.
3947
var simdAccelerated: Bool
4048
if simdSize > 16 || (simdSize == 16 && floatingPointType == "Double") {
4149
simdAccelerated = false
42-
} else {
50+
}
51+
else {
4352
simdAccelerated = true
4453
}
45-
54+
4655
// Generate RealFunctions implementations on concrete SIMD types to attach derivatives to
47-
let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension(objectType: simdType, type: simdType, whereClause: false, simdAccelerated: simdAccelerated)
56+
let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension(
57+
objectType: simdType,
58+
type: simdType,
59+
whereClause: false,
60+
simdAccelerated: simdAccelerated
61+
)
4862
try realFunctionsExtensionCode.write(to: realFunctionFileURL, atomically: true, encoding: .utf8)
49-
63+
5064
// Generate RealFunctions derivatives for concrete SIMD types
51-
let realFunctionDerivativesFileURL = output.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift")
65+
let realFunctionDerivativesFileURL = output
66+
.appending(component: "SIMD\(simdSize)+\(floatingPointType)+RealFunctions+Derivatives.swift")
5267
let type = "SIMD\(simdSize)<\(floatingPointType)>"
53-
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(type: type, floatingPointType: floatingPointType)
68+
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(
69+
type: type,
70+
floatingPointType: floatingPointType
71+
)
5472
try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8)
5573
}
5674
}
@@ -61,13 +79,13 @@ struct RealFunction {
6179
var name: String
6280
var simdName: String?
6381
var arguments: [Argument]
64-
82+
6583
struct Argument {
6684
var name: String
6785
var label: String?
68-
var type: String? = nil
86+
var type: String?
6987
}
70-
88+
7189
init(name: String, simdName: String? = nil, arguments: [Argument] = [.init(name: "x", label: "_")]) {
7290
self.name = name
7391
self.simdName = simdName

Sources/CodeGeneratorExecutable/RealFunctionsDerivativesGenerator.swift

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,33 @@
1-
struct RealFunctionsDerivativesGenerator {
1+
enum RealFunctionsDerivativesGenerator {
22
static func realFunctionsDerivativesExtension(type: String, floatingPointType: String) -> String {
33
"""
44
#if canImport(_Differentiation)
55
import _Differentiation
66
import RealModule
7-
7+
88
// MARK: ElementaryFunctions derivatives
99
extension \(type) {
1010
@derivative(of: exp)
1111
public static func _vjpExp(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
1212
let value = exp(x)
1313
return (value: value, pullback: { v in v * value })
1414
}
15-
15+
1616
@derivative(of: expMinusOne)
1717
public static func _vjpExpMinusOne(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
1818
return (value: expMinusOne(x), pullback: { v in v * exp(x) })
1919
}
20-
20+
2121
@derivative(of: cosh)
2222
public static func _vjpCosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
2323
(value: cosh(x), pullback: { v in sinh(x) })
2424
}
25-
25+
2626
@derivative(of: sinh)
2727
public static func _vjpSinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
2828
(value: sinh(x), pullback: { v in cosh(x) })
2929
}
30-
30+
3131
@derivative(of: tanh)
3232
public static func _vjpTanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
3333
(
@@ -38,17 +38,17 @@ struct RealFunctionsDerivativesGenerator {
3838
}
3939
)
4040
}
41-
41+
4242
@derivative(of: cos)
4343
public static func _vjpCos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
4444
(value: cos(x), pullback: { v in -v * sin(x) })
4545
}
46-
46+
4747
@derivative(of: sin)
4848
public static func _vjpSin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
4949
(value: sin(x), pullback: { v in v * cos(x) })
5050
}
51-
51+
5252
@derivative(of: tan)
5353
public static func _vjpTan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
5454
(
@@ -59,117 +59,117 @@ struct RealFunctionsDerivativesGenerator {
5959
}
6060
)
6161
}
62-
62+
6363
@derivative(of: log(_:))
6464
public static func _vjpLog(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
6565
(value: log(x), pullback: { v in v / x })
6666
}
67-
67+
6868
@derivative(of: acosh)
6969
public static func _vjpAcosh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
7070
// only valid for x > 1
7171
return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) })
7272
}
73-
73+
7474
@derivative(of: asinh)
7575
public static func _vjpAsinh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
7676
(value: asinh(x), pullback: { v in v / sqrt(x * x + 1) })
7777
}
78-
78+
7979
@derivative(of: atanh)
8080
public static func _vjpAtanh(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
8181
(value: atanh(x), pullback: { v in v / (1 - x * x) })
8282
}
83-
83+
8484
@derivative(of: acos)
8585
public static func _vjpAcos(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
8686
(value: acos(x), pullback: { v in -v / (1 - x * x) })
8787
}
88-
88+
8989
@derivative(of: asin)
9090
public static func _vjpAsin(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
9191
(value: asin(x), pullback: { v in v / (1 - x * x) })
9292
}
93-
93+
9494
@derivative(of: atan)
9595
public static func _vjpAtan(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
9696
(value: atan(x), pullback: { v in v / (x * x + 1) })
9797
}
98-
98+
9999
@derivative(of: log(onePlus:))
100100
public static func _vjpLog(onePlus x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
101101
(value: log(onePlus: x), pullback: { v in v / (1 + x) })
102102
}
103-
103+
104104
@derivative(of: pow)
105105
public static func _vjpPow(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
106106
let value = pow(x, y)
107107
// pullback wrt y is not defined for (x < 0) and (x = 0, y = 0)
108108
return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) })
109109
}
110-
110+
111111
@derivative(of: pow)
112112
public static func _vjpPow(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) {
113113
(value: pow(x, n), pullback: { v in v * \(floatingPointType)(n) * pow(x, n - 1) })
114114
}
115-
115+
116116
@derivative(of: sqrt)
117117
public static func _vjpSqrt(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
118118
let value = sqrt(x)
119119
return (value: value, pullback: { v in v / (2 * value) })
120120
}
121-
121+
122122
@derivative(of: root)
123123
public static func _vjpRoot(_ x: \(type), _ n: Int) -> (value: \(type), pullback: (\(type)) -> \(type)) {
124124
let value = root(x, n)
125125
return (value: value, pullback: { v in v * value / (x * \(floatingPointType)(n)) })
126126
}
127127
}
128-
128+
129129
// MARK: RealFunctions derivatives
130130
extension \(type) {
131131
@derivative(of: erf)
132132
public static func _vjpErf(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
133133
(value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) })
134134
}
135-
135+
136136
@derivative(of: erfc)
137137
public static func _vjpErfc(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
138138
(value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt(\(floatingPointType).pi) })
139139
}
140-
140+
141141
@derivative(of: exp2)
142142
public static func _vjpExp2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
143143
let value = exp2(x)
144144
return (value, { v in v * value * .log(2) })
145145
}
146-
146+
147147
@derivative(of: exp10)
148148
public static func _vjpExp10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
149149
let value = exp10(x)
150150
return (value, { v in v * value * .log(10) })
151151
}
152-
152+
153153
@derivative(of: gamma)
154154
public static func _vjpGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
155155
fatalError("unimplemented")
156156
}
157-
157+
158158
@derivative(of: log2)
159159
public static func _vjpLog2(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
160160
(value: log2(x), pullback: { v in v / (.log(2) * x) })
161161
}
162-
162+
163163
@derivative(of: log10)
164164
public static func _vjpLog10(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
165165
(value: log10(x), pullback: { v in v / (.log(10) * x) })
166166
}
167-
167+
168168
@derivative(of: logGamma)
169169
public static func _vjpLogGamma(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
170170
fatalError("unimplemented")
171171
}
172-
172+
173173
@derivative(of: atan2)
174174
public static func _vjpAtan2(y: \(type), x: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
175175
(
@@ -180,7 +180,7 @@ struct RealFunctionsDerivativesGenerator {
180180
}
181181
)
182182
}
183-
183+
184184
@derivative(of: hypot)
185185
public static func _vjpHypot(_ x: \(type), _ y: \(type)) -> (value: \(type), pullback: (\(type)) -> (\(type), \(type))) {
186186
(
@@ -192,15 +192,16 @@ struct RealFunctionsDerivativesGenerator {
192192
)
193193
}
194194
}
195-
195+
196196
// MARK: FloatingPoint functions derivatives
197197
extension \(type) {
198198
@derivative(of: abs)
199199
public static func _vjpAbs(_ x: \(type)) -> (value: \(type), pullback: (\(type)) -> \(type)) {
200200
\({
201201
if type == floatingPointType {
202202
"x < 0 ? (value: -x, pullback: { v in .zero - v }) : (value: x, pullback: { v in v })"
203-
} else {
203+
}
204+
else {
204205
"(value: abs(x), pullback: { v in v.replacing(with: -v, where: x .< .zero) })"
205206
}
206207
}())

0 commit comments

Comments
 (0)