Skip to content

Commit 1e38e8e

Browse files
committed
transfer Complex, Differentiable conformance here and add some additional derivatives
The whole Complex conformance to Differentiable now lives in swift-numerics-differentiable. I also added some derivatives we previously couldn't express like: +=, -=, *=, /=, real.set, imaginary.set, and some we simply didn't have yet: init(_:), init(imaginary:)
1 parent b9804af commit 1e38e8e

File tree

2 files changed

+156
-1
lines changed

2 files changed

+156
-1
lines changed

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ let package = Package(
2323
),
2424
],
2525
dependencies: [
26-
.package(url: "https://github.com/apple/swift-numerics", .upToNextMinor(from: "1.0.2")),
26+
.package(url: "https://github.com/apple/swift-numerics", exact: "1.1.0-prerelease"),
2727
],
2828
targets: [
2929
.executableTarget(name: "CodeGeneratorExecutable"),
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#if canImport(_Differentiation)
2+
3+
import ComplexModule
4+
5+
extension Complex: @retroactive Differentiable where RealType: Differentiable, RealType.TangentVector == RealType {
6+
public typealias TangentVector = Self
7+
8+
@inlinable
9+
public mutating func move(by offset: Complex<RealType>) {
10+
self += offset
11+
}
12+
}
13+
14+
extension Complex where RealType: Differentiable, RealType.TangentVector == RealType {
15+
@derivative(of: init(_:_:))
16+
@_transparent
17+
public static func _vjpInit(_ real: RealType, _ imaginary: RealType) -> (value: Complex, pullback: (Complex) -> (RealType, RealType)) {
18+
(
19+
value: .init(real, imaginary),
20+
pullback: { v in (v.real, v.imaginary) }
21+
)
22+
}
23+
24+
@derivative(of: init(_:))
25+
@_transparent
26+
public static func _vjpInit(_ real: RealType) -> (value: Complex, pullback: (Complex) -> RealType) {
27+
(
28+
value: .init(real),
29+
pullback: { v in v.real }
30+
)
31+
}
32+
33+
@derivative(of: init(imaginary:))
34+
@_transparent
35+
public static func _vjpInit(imaginary: RealType) -> (value: Complex, pullback: (Complex) -> RealType) {
36+
(
37+
value: .init(imaginary: imaginary),
38+
pullback: { v in v.imaginary }
39+
)
40+
}
41+
42+
@derivative(of: real)
43+
@_transparent
44+
public func _vjpReal() -> (value: RealType, pullback: (RealType) -> Complex) {
45+
(value: real, pullback: { v in Complex(v, .zero) })
46+
}
47+
48+
@derivative(of: real.set)
49+
@_transparent
50+
public mutating func _vjpRealSet(_ newValue: RealType) -> (value: Void, pullback: (inout Complex) -> RealType) {
51+
self.real = newValue
52+
return (
53+
value: (),
54+
pullback: { v in
55+
let real = v.real
56+
v.real = .zero
57+
return real
58+
}
59+
)
60+
}
61+
62+
@derivative(of: imaginary)
63+
@_transparent
64+
public func _vjpImaginary() -> (value: RealType, pullback: (RealType) -> Complex) {
65+
(value: imaginary, pullback: { v in Complex(.zero, v) })
66+
}
67+
68+
@derivative(of: imaginary.set)
69+
@_transparent
70+
public mutating func _vjpImaginarySet(_ newValue: RealType) -> (value: Void, pullback: (inout Complex) -> RealType) {
71+
self.imaginary = newValue
72+
return (
73+
value: (),
74+
pullback: { v in
75+
let imaginary = v.imaginary
76+
v.imaginary = .zero
77+
return imaginary
78+
}
79+
)
80+
}
81+
82+
@derivative(of: +)
83+
@_transparent
84+
public static func _vjpAdd(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
85+
(value: z + w, pullback: { v in (v, v) })
86+
}
87+
88+
@derivative(of: +=)
89+
@_transparent
90+
public static func _vjpAddAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
91+
z += w
92+
return (value: (), pullback: { v in v })
93+
}
94+
95+
@derivative(of: -)
96+
@_transparent
97+
public static func _vjpSubtract(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
98+
(value: z - w, pullback: { v in (v, -v) })
99+
}
100+
101+
@derivative(of: -=)
102+
@_transparent
103+
public static func _vjpSubtractAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
104+
z -= w
105+
return (value: (), pullback: { v in -v })
106+
}
107+
108+
@derivative(of: *)
109+
@_transparent
110+
public static func _vjpMultiply(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
111+
(value: z * w, pullback: { v in (w * v, z * v) })
112+
}
113+
114+
@derivative(of: *=)
115+
@_transparent
116+
public static func _vjpMultiplyAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
117+
defer { z *= w }
118+
return (
119+
value: (),
120+
pullback: { [z = z] v in
121+
let drhs = z * v
122+
v *= w
123+
return drhs
124+
}
125+
)
126+
}
127+
128+
@derivative(of: /)
129+
@_transparent
130+
public static func _vjpDivide(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
131+
(value: z / w, pullback: { v in (v / w, -z / (w * w) * v) })
132+
}
133+
134+
@derivative(of: /=)
135+
@_transparent
136+
public static func _vjpDivideAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
137+
defer { z /= w }
138+
return (
139+
value: (),
140+
pullback: { [z = z] v in
141+
let drhs = -z / (w * w) * v
142+
v /= w
143+
return drhs
144+
}
145+
)
146+
}
147+
148+
@derivative(of: conjugate)
149+
@_transparent
150+
public func _vjpConjugate() -> (value: Complex, pullback: (Complex) -> Complex) {
151+
(value: conjugate, pullback: { v in v.conjugate })
152+
}
153+
}
154+
155+
#endif

0 commit comments

Comments
 (0)