1- struct RealFunctionsDerivativesGenerator {
1+ enum RealFunctionsDerivativesGenerator {
22 static func realFunctionsDerivativesExtension( type: String , floatingPointType: String ) -> String {
33 """
44 #if canImport(_Differentiation)
55 import _Differentiation
66 import RealModule
7-
7+
88 // MARK: ElementaryFunctions derivatives
99 extension \( type) {
1010 @derivative(of: exp)
1111 public static func _vjpExp(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
1212 let value = exp(x)
1313 return (value: value, pullback: { v in v * value })
1414 }
15-
15+
1616 @derivative(of: expMinusOne)
1717 public static func _vjpExpMinusOne(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
1818 return (value: expMinusOne(x), pullback: { v in v * exp(x) })
1919 }
20-
20+
2121 @derivative(of: cosh)
2222 public static func _vjpCosh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
2323 (value: cosh(x), pullback: { v in sinh(x) })
2424 }
25-
25+
2626 @derivative(of: sinh)
2727 public static func _vjpSinh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
2828 (value: sinh(x), pullback: { v in cosh(x) })
2929 }
30-
30+
3131 @derivative(of: tanh)
3232 public static func _vjpTanh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
3333 (
@@ -38,17 +38,17 @@ struct RealFunctionsDerivativesGenerator {
3838 }
3939 )
4040 }
41-
41+
4242 @derivative(of: cos)
4343 public static func _vjpCos(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
4444 (value: cos(x), pullback: { v in -v * sin(x) })
4545 }
46-
46+
4747 @derivative(of: sin)
4848 public static func _vjpSin(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
4949 (value: sin(x), pullback: { v in v * cos(x) })
5050 }
51-
51+
5252 @derivative(of: tan)
5353 public static func _vjpTan(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
5454 (
@@ -59,117 +59,117 @@ struct RealFunctionsDerivativesGenerator {
5959 }
6060 )
6161 }
62-
62+
6363 @derivative(of: log(_:))
6464 public static func _vjpLog(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
6565 (value: log(x), pullback: { v in v / x })
6666 }
67-
67+
6868 @derivative(of: acosh)
6969 public static func _vjpAcosh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
7070 // only valid for x > 1
7171 return (value: acosh(x), pullback: { v in v / sqrt(x * x - 1) })
7272 }
73-
73+
7474 @derivative(of: asinh)
7575 public static func _vjpAsinh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
7676 (value: asinh(x), pullback: { v in v / sqrt(x * x + 1) })
7777 }
78-
78+
7979 @derivative(of: atanh)
8080 public static func _vjpAtanh(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
8181 (value: atanh(x), pullback: { v in v / (1 - x * x) })
8282 }
83-
83+
8484 @derivative(of: acos)
8585 public static func _vjpAcos(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
8686 (value: acos(x), pullback: { v in -v / (1 - x * x) })
8787 }
88-
88+
8989 @derivative(of: asin)
9090 public static func _vjpAsin(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
9191 (value: asin(x), pullback: { v in v / (1 - x * x) })
9292 }
93-
93+
9494 @derivative(of: atan)
9595 public static func _vjpAtan(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
9696 (value: atan(x), pullback: { v in v / (x * x + 1) })
9797 }
98-
98+
9999 @derivative(of: log(onePlus:))
100100 public static func _vjpLog(onePlus x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
101101 (value: log(onePlus: x), pullback: { v in v / (1 + x) })
102102 }
103-
103+
104104 @derivative(of: pow)
105105 public static func _vjpPow(_ x: \( type) , _ y: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> ( \( type) , \( type) )) {
106106 let value = pow(x, y)
107107 // pullback wrt y is not defined for (x < 0) and (x = 0, y = 0)
108108 return (value: value, pullback: { v in (v * y * pow(x, y - 1), v * value * log(x)) })
109109 }
110-
110+
111111 @derivative(of: pow)
112112 public static func _vjpPow(_ x: \( type) , _ n: Int) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
113113 (value: pow(x, n), pullback: { v in v * \( floatingPointType) (n) * pow(x, n - 1) })
114114 }
115-
115+
116116 @derivative(of: sqrt)
117117 public static func _vjpSqrt(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
118118 let value = sqrt(x)
119119 return (value: value, pullback: { v in v / (2 * value) })
120120 }
121-
121+
122122 @derivative(of: root)
123123 public static func _vjpRoot(_ x: \( type) , _ n: Int) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
124124 let value = root(x, n)
125125 return (value: value, pullback: { v in v * value / (x * \( floatingPointType) (n)) })
126126 }
127127 }
128-
128+
129129 // MARK: RealFunctions derivatives
130130 extension \( type) {
131131 @derivative(of: erf)
132132 public static func _vjpErf(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
133133 (value: erf(x), pullback: { v in 2 * exp(-x * x) / .sqrt( \( floatingPointType) .pi) })
134134 }
135-
135+
136136 @derivative(of: erfc)
137137 public static func _vjpErfc(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
138138 (value: erfc(x), pullback: { v in -2 * exp(-x * x) / .sqrt( \( floatingPointType) .pi) })
139139 }
140-
140+
141141 @derivative(of: exp2)
142142 public static func _vjpExp2(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
143143 let value = exp2(x)
144144 return (value, { v in v * value * .log(2) })
145145 }
146-
146+
147147 @derivative(of: exp10)
148148 public static func _vjpExp10(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
149149 let value = exp10(x)
150150 return (value, { v in v * value * .log(10) })
151151 }
152-
152+
153153 @derivative(of: gamma)
154154 public static func _vjpGamma(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
155155 fatalError( " unimplemented " )
156156 }
157-
157+
158158 @derivative(of: log2)
159159 public static func _vjpLog2(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
160160 (value: log2(x), pullback: { v in v / (.log(2) * x) })
161161 }
162-
162+
163163 @derivative(of: log10)
164164 public static func _vjpLog10(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
165165 (value: log10(x), pullback: { v in v / (.log(10) * x) })
166166 }
167-
167+
168168 @derivative(of: logGamma)
169169 public static func _vjpLogGamma(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
170170 fatalError( " unimplemented " )
171171 }
172-
172+
173173 @derivative(of: atan2)
174174 public static func _vjpAtan2(y: \( type) , x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> ( \( type) , \( type) )) {
175175 (
@@ -180,7 +180,7 @@ struct RealFunctionsDerivativesGenerator {
180180 }
181181 )
182182 }
183-
183+
184184 @derivative(of: hypot)
185185 public static func _vjpHypot(_ x: \( type) , _ y: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> ( \( type) , \( type) )) {
186186 (
@@ -192,15 +192,16 @@ struct RealFunctionsDerivativesGenerator {
192192 )
193193 }
194194 }
195-
195+
196196 // MARK: FloatingPoint functions derivatives
197197 extension \( type) {
198198 @derivative(of: abs)
199199 public static func _vjpAbs(_ x: \( type) ) -> (value: \( type) , pullback: ( \( type) ) -> \( type) ) {
200200 \( {
201201 if type == floatingPointType {
202202 " x < 0 ? (value: -x, pullback: { v in .zero - v }) : (value: x, pullback: { v in v }) "
203- } else {
203+ }
204+ else {
204205 " (value: abs(x), pullback: { v in v.replacing(with: -v, where: x .< .zero) }) "
205206 }
206207 } ( ) )
0 commit comments