Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
name: Test Swift ${{ matrix.swift }} Ubuntu Latest
strategy:
matrix:
swift: ["6.0.3", "6.1"]
swift: ["6.1", "6.1.2"]
runs-on: ubuntu-latest
container: swift:${{ matrix.swift }}
steps:
Expand All @@ -24,7 +24,7 @@ jobs:
name: Test Swift ${{ matrix.swift }} macOS
strategy:
matrix:
swift: ["6.0.3", "6.1"]
swift: ["6.1", "6.1.2"]
runs-on: macos-15
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/apple/swift-numerics", .upToNextMinor(from: "1.0.2")),
.package(url: "https://github.com/apple/swift-numerics", from: "1.1.0"),
],
targets: [
.executableTarget(name: "CodeGeneratorExecutable"),
Expand Down
155 changes: 155 additions & 0 deletions Sources/ComplexModuleDifferentiable/Complex+Differentiable.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#if canImport(_Differentiation)

import ComplexModule

extension Complex: @retroactive Differentiable where RealType: Differentiable, RealType.TangentVector == RealType {
public typealias TangentVector = Self

@inlinable
public mutating func move(by offset: Complex<RealType>) {
self += offset
}
}

extension Complex where RealType: Differentiable, RealType.TangentVector == RealType {
@derivative(of: init(_:_:))
@_transparent
public static func _vjpInit(_ real: RealType, _ imaginary: RealType) -> (value: Complex, pullback: (Complex) -> (RealType, RealType)) {
(
value: .init(real, imaginary),
pullback: { v in (v.real, v.imaginary) }
)
}

@derivative(of: init(_:))
@_transparent
public static func _vjpInit(_ real: RealType) -> (value: Complex, pullback: (Complex) -> RealType) {
(
value: .init(real),
pullback: { v in v.real }
)
}

@derivative(of: init(imaginary:))
@_transparent
public static func _vjpInit(imaginary: RealType) -> (value: Complex, pullback: (Complex) -> RealType) {
(
value: .init(imaginary: imaginary),
pullback: { v in v.imaginary }
)
}

@derivative(of: real)
@_transparent
public func _vjpReal() -> (value: RealType, pullback: (RealType) -> Complex) {
(value: real, pullback: { v in Complex(v, .zero) })
}

@derivative(of: real.set)
@_transparent
public mutating func _vjpRealSet(_ newValue: RealType) -> (value: Void, pullback: (inout Complex) -> RealType) {
self.real = newValue
return (
value: (),
pullback: { v in
let real = v.real
v.real = .zero
return real
}
)
}

@derivative(of: imaginary)
@_transparent
public func _vjpImaginary() -> (value: RealType, pullback: (RealType) -> Complex) {
(value: imaginary, pullback: { v in Complex(.zero, v) })
}

@derivative(of: imaginary.set)
@_transparent
public mutating func _vjpImaginarySet(_ newValue: RealType) -> (value: Void, pullback: (inout Complex) -> RealType) {
self.imaginary = newValue
return (
value: (),
pullback: { v in
let imaginary = v.imaginary
v.imaginary = .zero
return imaginary
}
)
}

@derivative(of: +)
@_transparent
public static func _vjpAdd(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
(value: z + w, pullback: { v in (v, v) })
}

@derivative(of: +=)
@_transparent
public static func _vjpAddAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
z += w
return (value: (), pullback: { v in v })
}

@derivative(of: -)
@_transparent
public static func _vjpSubtract(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
(value: z - w, pullback: { v in (v, -v) })
}

@derivative(of: -=)
@_transparent
public static func _vjpSubtractAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
z -= w
return (value: (), pullback: { v in -v })
}

@derivative(of: *)
@_transparent
public static func _vjpMultiply(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
(value: z * w, pullback: { v in (w * v, z * v) })
}

@derivative(of: *=)
@_transparent
public static func _vjpMultiplyAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
defer { z *= w }
return (
value: (),
pullback: { [z = z] v in
let drhs = z * v
v *= w
return drhs
}
)
}

@derivative(of: /)
@_transparent
public static func _vjpDivide(z: Complex, w: Complex) -> (value: Complex, pullback: (Complex) -> (Complex, Complex)) {
(value: z / w, pullback: { v in (v / w, -z / (w * w) * v) })
}

@derivative(of: /=)
@_transparent
public static func _vjpDivideAssign(z: inout Complex, w: Complex) -> (value: Void, pullback: (inout Complex) -> (Complex)) {
defer { z /= w }
return (
value: (),
pullback: { [z = z] v in
let drhs = -z / (w * w) * v
v /= w
return drhs
}
)
}

@derivative(of: conjugate)
@_transparent
public func _vjpConjugate() -> (value: Complex, pullback: (Complex) -> Complex) {
(value: conjugate, pullback: { v in v.conjugate })
}
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import ComplexModuleDifferentiable
import Testing

#if canImport(_Differentiation)
@Suite
struct ComplexDifferentiableTests {
@Test
func componentGetter() {
#expect(gradient(at: Complex<Float>(5, 5)) { $0.real * 2 } == Complex(2, 0))
#expect(gradient(at: Complex<Float>(5, 5)) { $0.imaginary * 2 } == Complex(0, 2))
#expect(gradient(at: Complex<Float>(5, 5)) {
$0.real * 5 + $0.imaginary * 2
} == Complex(5, 2))
}

@Test
func initializer() {
let pb1 = pullback(at: 4, -3) { r, i in Complex<Float>(r, i) }
let tan1 = pb1(Complex(-1, 2))
#expect(tan1.0 == -1)
#expect(tan1.1 == 2)

let pb2 = pullback(at: 4, -3) { r, i in Complex<Float>(r * r, i + i)
}
let tan2 = pb2(Complex(-1, 1))
#expect(tan2.0 == -8)
#expect(tan2.1 == 2)
}

@Test
func conjugate() {
let pullback = pullback(at: Complex<Float>(20, -4)) { x in x.conjugate }
#expect(pullback(Complex(1, 0)) == Complex(1, 0))
#expect(pullback(Complex(0, 1)) == Complex(0, -1))
#expect(pullback(Complex(-1, 1)) == Complex(-1, -1))
}

@Test
func arithmetics() {
let additionPullback = pullback(at: Complex<Float>(2, 3)) { x in
x + Complex(5, 6)
}
#expect(additionPullback(Complex(1, 1)) == Complex(1, 1))

let subtractPullback = pullback(at: Complex<Float>(2, 3)) { x in
Complex(5, 6) - x
}
#expect(subtractPullback(Complex(1, 1)) == Complex(-1, -1))

let multiplyPullback = pullback(at: Complex<Float>(2, 3)) { x in x * x }
#expect(multiplyPullback(Complex(1, 0)) == Complex(4, 6))
#expect(multiplyPullback(Complex(0, 1)) == Complex(-6, 4))
#expect(multiplyPullback(Complex(1, 1)) == Complex(-2, 10))

let dividePullback = pullback(at: Complex<Float>(20, -4)) { x in
x / Complex(2, 2)
}
#expect(dividePullback(Complex(1, 0)) == Complex(0.25, -0.25))
#expect(dividePullback(Complex(0, 1)) == Complex(0.25, 0.25))
}
}

#endif