Skip to content

Commit f0531bb

Browse files
committed
add tests from swift-numerics implementation
1 parent 1e38e8e commit f0531bb

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import ComplexModuleDifferentiable
2+
import Testing
3+
4+
#if canImport(_Differentiation)
5+
@Suite
6+
struct ComplexDifferentiableTests {
7+
@Test
8+
func componentGetter() {
9+
#expect(gradient(at: Complex<Float>(5, 5)) { $0.real * 2 } == Complex(2, 0))
10+
#expect(gradient(at: Complex<Float>(5, 5)) { $0.imaginary * 2 } == Complex(0, 2))
11+
#expect(gradient(at: Complex<Float>(5, 5)) {
12+
$0.real * 5 + $0.imaginary * 2
13+
} == Complex(5, 2))
14+
}
15+
16+
@Test
17+
func initializer() {
18+
let pb1 = pullback(at: 4, -3) { r, i in Complex<Float>(r, i) }
19+
let tan1 = pb1(Complex(-1, 2))
20+
#expect(tan1.0 == -1)
21+
#expect(tan1.1 == 2)
22+
23+
let pb2 = pullback(at: 4, -3) { r, i in Complex<Float>(r * r, i + i)
24+
}
25+
let tan2 = pb2(Complex(-1, 1))
26+
#expect(tan2.0 == -8)
27+
#expect(tan2.1 == 2)
28+
}
29+
30+
@Test
31+
func conjugate() {
32+
let pullback = pullback(at: Complex<Float>(20, -4)) { x in x.conjugate }
33+
#expect(pullback(Complex(1, 0)) == Complex(1, 0))
34+
#expect(pullback(Complex(0, 1)) == Complex(0, -1))
35+
#expect(pullback(Complex(-1, 1)) == Complex(-1, -1))
36+
}
37+
38+
@Test
39+
func arithmetics() {
40+
let additionPullback = pullback(at: Complex<Float>(2, 3)) { x in
41+
x + Complex(5, 6)
42+
}
43+
#expect(additionPullback(Complex(1, 1)) == Complex(1, 1))
44+
45+
let subtractPullback = pullback(at: Complex<Float>(2, 3)) { x in
46+
Complex(5, 6) - x
47+
}
48+
#expect(subtractPullback(Complex(1, 1)) == Complex(-1, -1))
49+
50+
let multiplyPullback = pullback(at: Complex<Float>(2, 3)) { x in x * x }
51+
#expect(multiplyPullback(Complex(1, 0)) == Complex(4, 6))
52+
#expect(multiplyPullback(Complex(0, 1)) == Complex(-6, 4))
53+
#expect(multiplyPullback(Complex(1, 1)) == Complex(-2, 10))
54+
55+
let dividePullback = pullback(at: Complex<Float>(20, -4)) { x in
56+
x / Complex(2, 2)
57+
}
58+
#expect(dividePullback(Complex(1, 0)) == Complex(0.25, -0.25))
59+
#expect(dividePullback(Complex(0, 1)) == Complex(0.25, 0.25))
60+
}
61+
}
62+
63+
#endif

0 commit comments

Comments
 (0)