Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,10 +759,14 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
return dest
end

function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
dispatch_val(x) = x
dispatch_val(::Val{D}) where {D} = D

# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N}
dims = dispatch_val(dims)
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."

# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
A = maybe_expand_dims(A, dims)
Bs = maybe_expand_dims.(Bs, (dims,))

Expand All @@ -775,7 +779,7 @@ function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) wher
MLIR.Dialects.stablehlo.concatenate(
[A.mlir_data, [B.mlir_data for B in Bs]...];
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
dimension=D - 1, # stablehlo expects this to be zero-indexed
dimension=dims - 1, # stablehlo expects this to be zero-indexed
),
1,
),
Expand Down
Loading