|
| 1 | +import ComplexModuleDifferentiable |
| 2 | +import Testing |
| 3 | + |
| 4 | +#if canImport(_Differentiation) |
| 5 | +@Suite |
| 6 | +struct ComplexDifferentiableTests { |
| 7 | + @Test |
| 8 | + func componentGetter() { |
| 9 | + #expect(gradient(at: Complex<Float>(5, 5)) { $0.real * 2 } == Complex(2, 0)) |
| 10 | + #expect(gradient(at: Complex<Float>(5, 5)) { $0.imaginary * 2 } == Complex(0, 2)) |
| 11 | + #expect(gradient(at: Complex<Float>(5, 5)) { |
| 12 | + $0.real * 5 + $0.imaginary * 2 |
| 13 | + } == Complex(5, 2)) |
| 14 | + } |
| 15 | + |
| 16 | + @Test |
| 17 | + func initializer() { |
| 18 | + let pb1 = pullback(at: 4, -3) { r, i in Complex<Float>(r, i) } |
| 19 | + let tan1 = pb1(Complex(-1, 2)) |
| 20 | + #expect(tan1.0 == -1) |
| 21 | + #expect(tan1.1 == 2) |
| 22 | + |
| 23 | + let pb2 = pullback(at: 4, -3) { r, i in Complex<Float>(r * r, i + i) |
| 24 | + } |
| 25 | + let tan2 = pb2(Complex(-1, 1)) |
| 26 | + #expect(tan2.0 == -8) |
| 27 | + #expect(tan2.1 == 2) |
| 28 | + } |
| 29 | + |
| 30 | + @Test |
| 31 | + func conjugate() { |
| 32 | + let pullback = pullback(at: Complex<Float>(20, -4)) { x in x.conjugate } |
| 33 | + #expect(pullback(Complex(1, 0)) == Complex(1, 0)) |
| 34 | + #expect(pullback(Complex(0, 1)) == Complex(0, -1)) |
| 35 | + #expect(pullback(Complex(-1, 1)) == Complex(-1, -1)) |
| 36 | + } |
| 37 | + |
| 38 | + @Test |
| 39 | + func arithmetics() { |
| 40 | + let additionPullback = pullback(at: Complex<Float>(2, 3)) { x in |
| 41 | + x + Complex(5, 6) |
| 42 | + } |
| 43 | + #expect(additionPullback(Complex(1, 1)) == Complex(1, 1)) |
| 44 | + |
| 45 | + let subtractPullback = pullback(at: Complex<Float>(2, 3)) { x in |
| 46 | + Complex(5, 6) - x |
| 47 | + } |
| 48 | + #expect(subtractPullback(Complex(1, 1)) == Complex(-1, -1)) |
| 49 | + |
| 50 | + let multiplyPullback = pullback(at: Complex<Float>(2, 3)) { x in x * x } |
| 51 | + #expect(multiplyPullback(Complex(1, 0)) == Complex(4, 6)) |
| 52 | + #expect(multiplyPullback(Complex(0, 1)) == Complex(-6, 4)) |
| 53 | + #expect(multiplyPullback(Complex(1, 1)) == Complex(-2, 10)) |
| 54 | + |
| 55 | + let dividePullback = pullback(at: Complex<Float>(20, -4)) { x in |
| 56 | + x / Complex(2, 2) |
| 57 | + } |
| 58 | + #expect(dividePullback(Complex(1, 0)) == Complex(0.25, -0.25)) |
| 59 | + #expect(dividePullback(Complex(0, 1)) == Complex(0.25, 0.25)) |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +#endif |
0 commit comments