diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 152e331..553a2ca 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -13,7 +13,7 @@ jobs: name: Test Swift ${{ matrix.swift }} Ubuntu Latest strategy: matrix: - swift: ["6.0.3", "6.1"] + swift: ["6.1", "6.1.2"] runs-on: ubuntu-latest container: swift:${{ matrix.swift }} steps: @@ -24,7 +24,7 @@ jobs: name: Test Swift ${{ matrix.swift }} macOS strategy: matrix: - swift: ["6.0.3", "6.1"] + swift: ["6.1", "6.1.2"] runs-on: macos-15 steps: - uses: actions/checkout@v4 diff --git a/Package.swift b/Package.swift index 7593324..a2d88b7 100644 --- a/Package.swift +++ b/Package.swift @@ -23,7 +23,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/apple/swift-numerics", .upToNextMinor(from: "1.0.2")), + .package(url: "https://github.com/apple/swift-numerics", from: "1.1.0"), ], targets: [ .executableTarget(name: "CodeGeneratorExecutable"), diff --git a/Sources/ComplexModuleDifferentiable/Complex+Differentiable.swift b/Sources/ComplexModuleDifferentiable/Complex+Differentiable.swift new file mode 100644 index 0000000..89b7845 --- /dev/null +++ b/Sources/ComplexModuleDifferentiable/Complex+Differentiable.swift @@ -0,0 +1,155 @@ +#if canImport(_Differentiation) + +import ComplexModule + +extension Complex: @retroactive Differentiable where RealType: Differentiable, RealType.TangentVector == RealType { + public typealias TangentVector = Self + + @inlinable + public mutating func move(by offset: Complex) { + self += offset + } +} + +extension Complex where RealType: Differentiable, RealType.TangentVector == RealType { + @derivative(of: init(_:_:)) + @_transparent + public static func _vjpInit(_ real: RealType, _ imaginary: RealType) -> (value: Complex, pullback: (Complex) -> (RealType, RealType)) { + ( + value: .init(real, imaginary), + pullback: { v in (v.real, v.imaginary) } + ) + } + + @derivative(of: init(_:)) + @_transparent + public static func _vjpInit(_ real: RealType) -> (value: Complex, pullback: (Complex) -> RealType) { + ( + value: .init(real), + pullback: { v in v.real } + ) + } + + @derivative(of: init(imaginary:)) + @_transparent + public static func _vjpInit(imaginary: RealType) -> (value: Complex, pullback: (Complex) -> RealType) { + ( + value: .init(imaginary: imaginary), + pullback: { v in v.imaginary } + ) + } + + @derivative(of: real) + @_transparent + public func _vjpReal() -> (value: RealType, pullback: (RealType) -> Complex) { + (value: real, pullback: { v in Complex(v, .zero) }) + } + + @derivative(of: real.set) + @_transparent + public mutating func _vjpRealSet(_ newValue: RealType) -> (value: Void, pullback: (inout Complex) -> RealType) { + self.real = newValue + return ( + value: (), + pullback: { v in + let real = v.real + v.real = .zero + return real + } + ) + } + + @derivative(of: imaginary) + @_transparent + public func _vjpImaginary() -> (value: RealType, pullback: (RealType) -> Complex) { + (value: imaginary, pullback: { v in Complex(.zero, v) }) + } + + @derivative(of: imaginary.set) + @_transparent + public mutating func _vjpImaginarySet(_ newValue: RealType) -> (value: Void, pullback: (inout Complex) -> RealType) { + self.imaginary = newValue + return ( + value: (), + pullback: { v in + let imaginary = v.imaginary + v.imaginary = .zero + return imaginary + } + ) + } + + @derivative(of: +) + @_transparent + public static func _vjpAdd(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) { + (value: z + w, pullback: { v in (v, v) }) + } + + @derivative(of: +=) + @_transparent + public static func _vjpAddAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) { + z += w + return (value: (), pullback: { v in v }) + } + + @derivative(of: -) + @_transparent + public static func _vjpSubtract(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) { + (value: z - w, pullback: { v in (v, -v) }) + } + + @derivative(of: -=) + @_transparent + public static func _vjpSubtractAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) { + z -= w + return (value: (), pullback: { v in -v }) + } + + @derivative(of: *) + @_transparent + public static func _vjpMultiply(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) { + (value: z * w, pullback: { v in (w * v, z * v) }) + } + + @derivative(of: *=) + @_transparent + public static func _vjpMultiplyAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) { + defer { z *= w } + return ( + value: (), + pullback: { [z = z] v in + let drhs = z * v + v *= w + return drhs + } + ) + } + + @derivative(of: /) + @_transparent + public static func _vjpDivide(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) { + (value: z / w, pullback: { v in (v / w, -z / (w * w) * v) }) + } + + @derivative(of: /=) + @_transparent + public static func _vjpDivideAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) { + defer { z /= w } + return ( + value: (), + pullback: { [z = z] v in + let drhs = -z / (w * w) * v + v /= w + return drhs + } + ) + } + + @derivative(of: conjugate) + @_transparent + public func _vjpConjugate() -> (value: Complex, pullback: (Complex) -> Complex) { + (value: conjugate, pullback: { v in v.conjugate }) + } +} + +#endif diff --git a/Tests/ComplexModuleDifferentiableTests/Complex+DifferentiableTests.swift b/Tests/ComplexModuleDifferentiableTests/Complex+DifferentiableTests.swift new file mode 100644 index 0000000..0c75b25 --- /dev/null +++ b/Tests/ComplexModuleDifferentiableTests/Complex+DifferentiableTests.swift @@ -0,0 +1,63 @@ +import ComplexModuleDifferentiable +import Testing + +#if canImport(_Differentiation) +@Suite +struct ComplexDifferentiableTests { + @Test + func componentGetter() { + #expect(gradient(at: Complex(5, 5)) { $0.real * 2 } == Complex(2, 0)) + #expect(gradient(at: Complex(5, 5)) { $0.imaginary * 2 } == Complex(0, 2)) + #expect(gradient(at: Complex(5, 5)) { + $0.real * 5 + $0.imaginary * 2 + } == Complex(5, 2)) + } + + @Test + func initializer() { + let pb1 = pullback(at: 4, -3) { r, i in Complex(r, i) } + let tan1 = pb1(Complex(-1, 2)) + #expect(tan1.0 == -1) + #expect(tan1.1 == 2) + + let pb2 = pullback(at: 4, -3) { r, i in Complex(r * r, i + i) + } + let tan2 = pb2(Complex(-1, 1)) + #expect(tan2.0 == -8) + #expect(tan2.1 == 2) + } + + @Test + func conjugate() { + let pullback = pullback(at: Complex(20, -4)) { x in x.conjugate } + #expect(pullback(Complex(1, 0)) == Complex(1, 0)) + #expect(pullback(Complex(0, 1)) == Complex(0, -1)) + #expect(pullback(Complex(-1, 1)) == Complex(-1, -1)) + } + + @Test + func arithmetics() { + let additionPullback = pullback(at: Complex(2, 3)) { x in + x + Complex(5, 6) + } + #expect(additionPullback(Complex(1, 1)) == Complex(1, 1)) + + let subtractPullback = pullback(at: Complex(2, 3)) { x in + Complex(5, 6) - x + } + #expect(subtractPullback(Complex(1, 1)) == Complex(-1, -1)) + + let multiplyPullback = pullback(at: Complex(2, 3)) { x in x * x } + #expect(multiplyPullback(Complex(1, 0)) == Complex(4, 6)) + #expect(multiplyPullback(Complex(0, 1)) == Complex(-6, 4)) + #expect(multiplyPullback(Complex(1, 1)) == Complex(-2, 10)) + + let dividePullback = pullback(at: Complex(20, -4)) { x in + x / Complex(2, 2) + } + #expect(dividePullback(Complex(1, 0)) == Complex(0.25, -0.25)) + #expect(dividePullback(Complex(0, 1)) == Complex(0.25, 0.25)) + } +} + +#endif