Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/batched/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,18 @@ _semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:An
batched_mul(batched_transpose(reshape(parent(A), size(parent(A))..., 1)), B)

"""
batched_vec(A::Array{T,3}, B::Matrix)
batched_vec(A::Array{T,3}, b::Vector)
batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix)
batched_vec(A::AbstractArray{T,3}, b::AbstractVector)
batched_vec(A::AbstractArray, B::AbstractArray)

Batched matrix-vector multiplication:
Batched matrix-vector multiplication. For the 3D case:
the result has `C[:,:,k] == A[:,:,k] * B[:,k]` for all `k`,
or else `C[:,:,k] == A[:,:,k] * b` for `b::Vector`.

For the general N-D case where `ndims(A) == ndims(B) + 1`:
the result has `C[:,k...] == A[:,:,k...] * B[:,k...]` for all batch indices `k...`.
The batch dimensions must match: `size(A)[3:end] == size(B)[2:end]`.

With the same argument types, `batched_mul(A, B)` would regard `B` as
a fixed matrix, not a batch of vectors. Both reshape and then
call `batched_mul(::Array{T,3}, ::Array{T,3})`.
Expand All @@ -181,8 +186,27 @@ julia> batched_vec(A,B) |> size

julia> batched_vec(A,b) |> size
(16, 32)

julia> A4d, B3d = randn(16,8,10,32), randn(8,10,32); # 4D and 3D arrays

julia> batched_vec(A4d, B3d) |> size
(16, 10, 32)
```
"""
function batched_vec(A::AbstractArray, B::AbstractArray)
ndims(A) == ndims(B) + 1 || throw(DimensionMismatch(
"batched_vec requires ndims(A) == ndims(B) + 1, got ndims(A)=$(ndims(A)) and ndims(B)=$(ndims(B))"))
size(A)[3:end] == size(B)[2:end] || throw(DimensionMismatch(
"batch dimensions must match: size(A)[3:end]=$(size(A)[3:end]) != size(B)[2:end]=$(size(B)[2:end])"))

# Reshape B to add a singleton dimension for matrix multiplication
B_reshaped = reshape(B, size(B, 1), 1, size(B)[2:end]...)
# Perform batched multiplication
C = batched_mul(A, B_reshaped)
# Remove the singleton dimension
return dropdims(C, dims=2)
end

batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) =
reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3))

Expand Down
30 changes: 30 additions & 0 deletions test/batchedmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,33 @@ FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect

gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P))
end

@testset "batched_vec: N-D batches" begin
# Test 4D case: A is 4D, B is 3D
A4d = randn(4, 5, 3, 2) # (matrix_rows, matrix_cols, batch_dim1, batch_dim2)
B3d = randn(5, 3, 2) # (vector_length, batch_dim1, batch_dim2)

C = batched_vec(A4d, B3d)
@test size(C) == (4, 3, 2)

# Manual verification
for i in 1:3, j in 1:2
@test C[:, i, j] ≈ A4d[:, :, i, j] * B3d[:, i, j]
end

# Test 5D case: A is 5D, B is 4D
A5d = randn(3, 4, 2, 3, 2) # (matrix_rows, matrix_cols, batch1, batch2, batch3)
B4d = randn(4, 2, 3, 2) # (vector_length, batch1, batch2, batch3)

C5 = batched_vec(A5d, B4d)
@test size(C5) == (3, 2, 3, 2)

# Manual verification for a few cases
@test C5[:, 1, 1, 1] ≈ A5d[:, :, 1, 1, 1] * B4d[:, 1, 1, 1]
@test C5[:, 2, 3, 2] ≈ A5d[:, :, 2, 3, 2] * B4d[:, 2, 3, 2]

# Test dimension mismatch errors
@test_throws DimensionMismatch batched_vec(randn(3, 4, 2), randn(4, 3)) # ndims mismatch
@test_throws DimensionMismatch batched_vec(randn(3, 4, 2, 3), randn(4, 2, 2)) # batch size mismatch

end
Loading