diff --git a/src/StructsOfArrays.jl b/src/StructsOfArrays.jl index 12c2751..af64b26 100644 --- a/src/StructsOfArrays.jl +++ b/src/StructsOfArrays.jl @@ -1,10 +1,11 @@ module StructsOfArrays -export StructOfArrays +export StructOfArrays, ScalarRepeat immutable StructOfArrays{T,N,U<:Tuple} <: AbstractArray{T,N} arrays::U end + @generated function StructOfArrays{T}(::Type{T}, dims::Integer...) (!isleaftype(T) || T.mutable) && return :(throw(ArgumentError("can only create an StructOfArrays of leaf type immutables"))) isempty(T.types) && return :(throw(ArgumentError("cannot create an StructOfArrays of an empty or bitstype"))) @@ -14,6 +15,30 @@ end end StructOfArrays(T::Type, dims::Tuple{Vararg{Integer}}) = StructOfArrays(T, dims...) +function StructOfArrays(T::Type, first_array::AbstractArray, rest::AbstractArray...) + (!isleaftype(T) || T.mutable) && throw(ArgumentError( + "can only create an StructOfArrays of leaf type immutables" + )) + arrays = (first_array, rest...) + target_eltypes = flattened_bitstypes(T) + source_eltypes = DataType[] + #flatten array eltypes + for elem in arrays + append!(source_eltypes, flattened_bitstypes(eltype(elem))) + end + # flattened eltypes don't match with flattened struct type + if target_eltypes != source_eltypes + throw(ArgumentError("""$T does not match the given parameters. + Argument types: $(map(typeof, arrays)) + Flattened struct types: $target_eltypes + Flattened argument types: $source_eltypes + """)) + end + # flattened they match! ♥💕 + typetuple = Tuple{map(typeof, arrays)...} + StructOfArrays{T, ndims(first_array), typetuple}(arrays) +end + Base.linearindexing{T<:StructOfArrays}(::Type{T}) = Base.LinearFast() @generated function Base.similar{T}(A::StructOfArrays, ::Type{T}, dims::Dims) @@ -31,13 +56,131 @@ Base.convert{T,S,N}(::Type{StructOfArrays{T}}, A::AbstractArray{S,N}) = Base.convert{T,N}(::Type{StructOfArrays}, A::AbstractArray{T,N}) = convert(StructOfArrays{T,N}, A) -Base.size(A::StructOfArrays) = size(A.arrays[1]) -Base.size(A::StructOfArrays, d) = size(A.arrays[1], d) +Base.size(A::StructOfArrays) = size(first(A.arrays)) +Base.size(A::StructOfArrays, d) = size(first(A.arrays), d) + +""" +returns all field types of a composite type or tuple. +If it's neither composite, nor tuple, it will just return the DataType. +""" +fieldtypes{T<:Tuple}(::Type{T}) = (T.parameters...) +function fieldtypes{T}(::Type{T}) + if nfields(T) > 0 + return ntuple(i->fieldtype(T, i), nfields(T)) + else + return T + end +end + +""" +Returns a flattened and unflattened view of the elemenents of a type +E.g: +immutable X +x::Float32 +y::Float32 +end +immutable Y +a::X # tuples would get expanded as well +b::Float32 +c::Float32 +end +Would return +[Float32, Float32, Float32, Float32] +and +[(Y, [(X, [Float32, Float32]), Float32, Float32]] +""" +function flattened_bitstypes{T}(::Type{T}, flattened=DataType[]) + fields = fieldtypes(T) + if isa(fields, DataType) + if (!isleaftype(T) || T.mutable) + throw(ArgumentError("can only create an StructOfArrays of leaf type immutables")) + end + push!(flattened, fields) + return flattened + else + for T in fields + flattened_bitstypes(T, flattened) + end + end + flattened +end + +""" +Takes a tuple of array types with arbitrary structs as elements. +return `flat_indices` and `temporaries`. `flat_indices` is a vector with indices to every elemen in the array. +`temporaries` is a vector of temporaries, which effectively store the elemens from the arrays +E.g. +flatindexes((Vector{Vec3f0}) will return: +with `array_expr=(A.arrays)` and `index_expr=:([i...])`: +`temporaries`: + [:(value1 = A.arrays[i...])] +`flat_indices`: + [:(value1.(1).(1)), :(value1.(1).(2)), :(value1.(1).(3))] # .(1) to acces tuple of Vec3 +""" +function flatindexes(arrays) + temporaries = [] + flat_indices = [] + for (i, array) in enumerate(arrays) + tmpsym = symbol("value$i") + push!(temporaries, :($(tmpsym) = A.arrays[$i][i...])) + index_expr = :($tmpsym) + flatindexes(eltype(array), index_expr, flat_indices) + end + flat_indices, temporaries +end -@generated function Base.getindex{T}(A::StructOfArrays{T}, i::Integer...) - Expr(:block, Expr(:meta, :inline), - Expr(:new, T, [:(A.arrays[$j][i...]) for j = 1:length(T.types)]...)) +function flatindexes(T, index_expr, flat_indices) + fields = fieldtypes(T) + if isa(fields, DataType) + push!(flat_indices, index_expr) + return nothing + else + for (i,T) in enumerate(fields) + new_index_expr = :($(index_expr).($i)) + flatindexes(T, new_index_expr, flat_indices) + end + end + nothing end + +""" +Creates a nested type T from elements in `flat_indices`. +`flat_indices` can be any array with expressions inside, as long as there is an +element for every field in `T`. +""" +function typecreator(T, lower_constr, flat_indices, i=1) + i>length(flat_indices) && return i + # we need to special case tuples, since e.g. Tuple{Float32, Float32}(1f0, 1f0) + # is not defined. + if T<:Tuple + constructor = Expr(:tuple) + else + constructor = Expr(:call, T) + end + push!(lower_constr.args, constructor) + fields = fieldtypes(T) + if isa(fields, DataType) + push!(constructor.args, flat_indices[i]) + return i+1 + else + for T in fields + i = typecreator(T, constructor, flat_indices, i) + end + end + return i +end + +@generated function Base.getindex{T, N, ArrayTypes}(A::StructOfArrays{T, N, ArrayTypes}, i::Integer...) + #flatten the indices, + flat_indices, temporaries = flatindexes((ArrayTypes.parameters...)) + type_constructor = Expr(:block) + # create a constructor expression, which uses the flattened indexes to create the type + typecreator(T, type_constructor, flat_indices) + # put everything in a block! + Expr(:block, Expr(:meta, :inline), temporaries..., type_constructor) +end + + @generated function Base.setindex!{T}(A::StructOfArrays{T}, x, i::Integer...) quote $(Expr(:meta, :inline)) diff --git a/test/runtests.jl b/test/runtests.jl index a03e03a..cc94086 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,3 +31,56 @@ small = StructOfArrays(Complex64, 2) @test typeof(similar(small, SubString)) === Vector{SubString} @test typeof(similar(small, OneField)) === Vector{OneField} @test typeof(similar(small, Complex128)) <: StructOfArrays + +immutable Vec{N,T} + _::NTuple{N,T} +end +immutable HyperCube{N,T} + origin::Vec{N,T} + width::Vec{N,T} +end +immutable Instance{P, S, T, R} + primitive::P + scale::S + translation::T + rotation::R +end +immutable ScalarRepeat{T,N} <: AbstractArray{T,N} + value::T + size::NTuple{N,Int} +end +Base.size(sr::ScalarRepeat) = sr.size +Base.size(sr::ScalarRepeat, d) = sr.size[d] +Base.getindex(sr::ScalarRepeat, i...) = sr.value +Base.linearindexing{T<:ScalarRepeat}(::Type{T}) = Base.LinearFast() + + +function test_topologic_structs() + hco_x,hco_yz = rand(Float32, 10), [Vec{2,Float32}((rand(Float32), rand(Float32))) for i=1:10] + hcw_z,hcw_xy = rand(Float32, 10), [Vec{2,Float32}((rand(Float32), rand(Float32))) for i=1:10] + scale = ScalarRepeat(1f0, (10,)) + translation = ScalarRepeat(Vec{3, Float32}((2,1,3)), (10,)) + rotation = [Vec{4, Float32}((rand(Float32),rand(Float32),rand(Float32),rand(Float32))) for i=1:10] + soa = StructOfArrays( + Instance{HyperCube{3, Float32}, Float32, Vec{3, Float32}, Vec{4,Float32}}, + hco_x,hco_yz, hcw_xy, hcw_z, scale, translation, rotation + ) + zipped = zip(hco_x,hco_yz, hcw_xy, hcw_z, scale, translation, rotation) + for (i,(ox,oyz, wxy, wz, s, t, r)) in enumerate(zipped) + instance = soa[i] + @test instance.primitive.origin.(1).(1) === ox + @test instance.primitive.origin.(1).(2) === oyz.(1).(1) + @test instance.primitive.origin.(1).(3) === oyz.(1).(2) + + @test instance.primitive.width.(1).(1) === wxy.(1).(1) + @test instance.primitive.width.(1).(2) === wxy.(1).(2) + @test instance.primitive.width.(1).(3) === wz + + @test instance.scale === s + @test instance.translation === t + @test instance.rotation === r + + end +end + +test_topologic_structs()