Skip to content

Commit 4c1e25a

Browse files
authored
add SubDitStr function to enable DitStr slicing (#54)
* add `SubDitStr` function to enable `DitStr` slicing * * fix doc and doctests * new `DitStr` function to raise `SubDitStr` struct to `DitStr` * tests added * * add @views macro for `SubDitStr` * add benchmark * remove `using BenchmarkTools` * comment bm() * update * update
1 parent e264b10 commit 4c1e25a

File tree

3 files changed

+162
-20
lines changed

3 files changed

+162
-20
lines changed

src/BitBasis.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ export bitarray, basis, packbits, bfloat, bfloat_r, bint, bint_r, flip
66
export anyone, allone, bmask, baddrs, readbit, setbit, controller
77
export swapbits, ismatch, neg, breflect, btruncate
88
export LongLongUInt
9+
export SubDitStr
910

1011
include("utils.jl")
1112
include("longlonguint.jl")

src/DitStr.jl

Lines changed: 139 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
const UIntStorage = Union{UInt8,UInt16,UInt32,UInt64,UInt128,LongLongUInt}
2-
const IntStorage = Union{Int8,Int16,Int32,Int64,Int128,BigInt,UIntStorage}
1+
const UIntStorage = Union{UInt8, UInt16, UInt32, UInt64, UInt128, LongLongUInt}
2+
const IntStorage = Union{Int8, Int16, Int32,Int64,Int128,BigInt,UIntStorage}
33

44
########## DitStr #########
55
"""
@@ -37,26 +37,26 @@ function DitStr{D,T}(vector::Union{AbstractVector,Tuple}) where {D,T}
3737
val = zero(T)
3838
D_power_k = one(T)
3939
for k in 1:length(vector)
40-
0 <= vector[k] <= D-1 || error("expect 0-$(D-1), got $(vector[k])")
40+
0 <= vector[k] <= D - 1 || error("expect 0-$(D-1), got $(vector[k])")
4141
val = accum(Val{D}(), val, vector[k], D_power_k)
4242
D_power_k = _lshift(Val{D}(), D_power_k, 1)
4343
end
4444
return DitStr{D,length(vector),T}(val)
4545
end
4646
# val += x * y
47-
accum(::Val{D}, val, x, y) where D = val + x * y
47+
accum(::Val{D}, val, x, y) where {D} = val + x * y
4848
accum(::Val{2}, val, x, y) = iszero(x) ? val : val y
4949
DitStr{D}(vector::Tuple{T,Vararg{T,N}}) where {N,T,D} = DitStr{D,T}(vector)
5050
DitStr{D}(vector::AbstractVector{T}) where {D,T} = DitStr{D,T}(vector)
5151
DitStr{D,N,T}(val::DitStr) where {D,N,T<:Integer} = convert(DitStr{D,N,T}, val)
5252
DitStr{D,N,T}(val::DitStr{D,N,T}) where {D,N,T<:Integer} = val
5353

5454
const DitStr64{D,N} = DitStr{D,N,Int64}
55-
const LongDitStr{D,N} = DitStr{D,N,LongLongUInt{C}} where C
55+
const LongDitStr{D,N} = DitStr{D,N,LongLongUInt{C}} where {C}
5656
LongDitStr{D}(vector::AbstractVector{T}) where {D,T} = DitStr{D,longinttype(length(vector), D)}(vector)
5757

5858
Base.show(io::IO, ditstr::DitStr{D,N,<:Integer}) where {D,N} =
59-
print(io, string(buffer(ditstr), base = D, pad = N), "$(''+D)")
59+
print(io, string(buffer(ditstr), base=D, pad=N), "$(''+D)")
6060
Base.show(io::IO, ditstr::DitStr{D,N,<:LongLongUInt}) where {D,N} =
6161
print(io, join(map(string, [ditstr[end:-1:1]...])), "$(''+D)")
6262

@@ -146,7 +146,7 @@ Read the dit config at given location.
146146
"""
147147
@inline @generated function readat(x::DitStr{D,N,T}, locs::Integer...) where {D,N,T}
148148
length(locs) == 0 && return :(zero($T))
149-
Expr(:call, :+, [:($_lshift($(Val(D)), mod($_rshift($(Val{D}()), buffer(x), locs[$i]-1), $D), $(i - 1))) for i=1:length(locs)]...)
149+
Expr(:call, :+, [:($_lshift($(Val(D)), mod($_rshift($(Val{D}()), buffer(x), locs[$i] - 1), $D), $(i - 1))) for i = 1:length(locs)]...)
150150
end
151151

152152
Base.@propagate_inbounds function Base.getindex(dit::DitStr{D,N}, index::Integer) where {D,N}
@@ -159,6 +159,128 @@ Base.@propagate_inbounds function Base.getindex(dit::DitStr{D,N,T}, itr::Abstrac
159159
return map(x -> readat(dit, x), itr)
160160
end
161161

162+
163+
"""
164+
SubDitStr{D,N,T<:Integer} <: Integer
165+
166+
The struct as a `SubString`-like object for `DitStr`(`SubString` is an official implementation of sliced strings, see [String](https://docs.julialang.org/en/v1/base/strings/#Base.SubString) for reference). This slicing returns a view into the parent `DitStr` instead of making a copy (similar to the `@views` macro for strings).
167+
168+
`SubDitStr` can be used to describe the qubit configuration within the subspace of the entire Hilbert space.It provides similar `getindex`, `length` functions as `DitStr`.
169+
170+
SubDitStr(dit::DitStr{D,N,T}, i::Int, j::Int)
171+
SubDitStr(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer})
172+
173+
Or by `@views` macro for `DitStr` (this macro makes your life easier by supporting `begin` and `end` syntax):
174+
175+
@views dit[i:j]
176+
177+
Returns a `SubDitStr`.
178+
179+
### Examples
180+
181+
```jldoctest
182+
julia> x = DitStr{3, 5}(71)
183+
02122 ₍₃₎
184+
185+
julia> sx = SubDitStr(x, 2, 4)
186+
SubDitStr{3, 5, Int64}(02122 ₍₃₎, 1, 3)
187+
188+
julia> @views x[2:end]
189+
SubDitStr{3, 5, Int64}(02122 ₍₃₎, 1, 4)
190+
191+
julia> sx == dit"212;3"
192+
true
193+
```
194+
"""
195+
struct SubDitStr{D,N,T<:Integer} <: Integer
196+
dit::DitStr{D,N,T}
197+
offset::Int
198+
ncodeunits::Int
199+
200+
function SubDitStr(dit::DitStr{D,N,T}, i::Int, j::Int) where {D,N,T}
201+
i j || return new{D,N,T}(dit, 0, 0)
202+
@boundscheck begin
203+
1 i length(dit) || throw(BoundsError(dit, i))
204+
1 j length(dit) || throw(BoundsError(dit, i))
205+
end
206+
return new{D,N,T}(dit, i - 1, j - i + 1)
207+
end
208+
end
209+
210+
Base.@propagate_inbounds Base.view(dit::DitStr{D,N,T}, i::Integer, j::Integer) where {D,N,T} = SubDitStr(dit, i, j)
211+
Base.@propagate_inbounds Base.view(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer}) where {D,N,T} = SubDitStr(dit, first(r), last(r))
212+
Base.@propagate_inbounds Base.maybeview(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer}) where {D,N,T} = view(dit,r)
213+
214+
"""
215+
DitStr(dit::SubDitStr{D,N,T}) -> DitStr{D,N,T}
216+
Raise type `SubDitStr` to `DitStr`.
217+
```jldoctest
218+
julia> x = DitStr{3, 5}(71)
219+
02122 ₍₃₎
220+
221+
julia> sx = SubDitStr(x, 2, 4)
222+
SubDitStr{3, 5, Int64}(02122 ₍₃₎, 1, 3)
223+
224+
julia> DitStr(sx)
225+
212 ₍₃₎
226+
```
227+
"""
228+
function DitStr(dit::SubDitStr{D,N,T}) where {D,N,T}
229+
val = zero(T)
230+
D_power_k = one(T)
231+
len = ncodeunits(dit)
232+
for k in 1:len
233+
val = accum(Val{D}(), val, readat(dit.dit, dit.offset + k), D_power_k)
234+
D_power_k = _lshift(Val{D}(), D_power_k, 1)
235+
end
236+
return DitStr{D,len,T}(val)
237+
end
238+
239+
ncodeunits(dit::SubDitStr{D,N,T}) where {D,N,T} = dit.ncodeunits
240+
241+
## bounds checking ##
242+
Base.checkbounds(::Type{Bool}, dit::SubDitStr{D,N,T}, i::Integer) where {D,N,T} =
243+
1 i ncodeunits(dit)
244+
Base.checkbounds(::Type{Bool}, dit::SubDitStr{D,N,T}, r::AbstractRange{<:Integer}) where {D,N,T} =
245+
isempty(r) || (1 minimum(r) && maximum(r) ncodeunits(dit))
246+
Base.checkbounds(::Type{Bool}, dit::SubDitStr{D,N,T}, I::AbstractArray{<:Integer}) where {D,N,T} =
247+
all(i -> checkbounds(Bool, dit, i), I)
248+
Base.checkbounds(dit::SubDitStr{D,N,T}, I::Union{Integer,AbstractArray}) where {D,N,T} = checkbounds(Bool, dit, I) ? nothing : throw(BoundsError(dit, I))
249+
250+
Base.@propagate_inbounds SubDitStr(dit::DitStr{D,N,T}, i::Integer, j::Integer) where {D,N,T} = SubDitStr{D,N,T}(dit, i, j)
251+
Base.@propagate_inbounds SubDitStr(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer}) where {D,N,T} = SubDitStr{D,N,T}(dit, first(r), last(r))
252+
253+
Base.@propagate_inbounds function SubDitStr(dit::SubDitStr{D,N,T}, i::Int, j::Int) where {D,N,T}
254+
@boundscheck i j && checkbounds(dit, i:j)
255+
SubString(dit.dit, dit.offset + i, dit.offset + j)
256+
end
257+
258+
Base.length(dit::SubDitStr{D,N,T}) where {D,N,T} = ncodeunits(dit)
259+
260+
"""
261+
==(lhs::SubDitStr{D,N,T}, rhs::DitStr{D,N,T}) -> Bool
262+
==(lhs::DitStr{D,N,T}, rhs::SubDitStr{D,N,T}) -> Bool
263+
==(lhs::SubDitStr{D,N,T}, rhs::SubDitStr{D,N,T}) -> Bool
264+
Compare the equality between `SubDitStr` and `DitStr`.
265+
"""
266+
function Base.:(==)(lhs::SubDitStr{D,N1}, rhs::DitStr{D,N2}) where {D,N1,N2}
267+
length(lhs) == length(rhs) && @inbounds all(i -> lhs[i] == rhs[i], 1:length(lhs))
268+
end
269+
270+
function Base.:(==)(lhs::SubDitStr{D,N1}, rhs::SubDitStr{D,N2}) where {D,N1,N2}
271+
length(lhs) == length(rhs) && @inbounds all(i -> lhs[i] == rhs[i], 1:length(lhs))
272+
end
273+
274+
function Base.:(==)(lhs::DitStr{D,N1}, rhs::SubDitStr{D,N2}) where {D,N1,N2}
275+
length(lhs) == length(rhs) && @inbounds all(i -> lhs[i] == rhs[i], 1:length(lhs))
276+
end
277+
278+
function Base.getindex(dit::SubDitStr{D,N,T}, i::Integer) where {D,N,T}
279+
@boundscheck checkbounds(dit, i)
280+
@inbounds return getindex(dit.dit, dit.offset + i)
281+
end
282+
283+
162284
# TODO: support AbstractArray, should return its corresponding shape
163285

164286
Base.@propagate_inbounds function Base.getindex(
@@ -178,7 +300,7 @@ end
178300

179301
Base.eltype(::DitStr{D,N,T}) where {D,N,T} = T
180302

181-
function Base.iterate(dit::DitStr, state::Integer = 1)
303+
function Base.iterate(dit::DitStr, state::Integer=1)
182304
if state > length(dit)
183305
return nothing
184306
else
@@ -201,8 +323,8 @@ function Base.rand(::Type{T}) where {D,N,Ti,T<:DitStr{D,N,Ti}}
201323
end
202324

203325
######################### Operations #####################
204-
_lshift(::Val{D}, x::Integer, i::Integer) where D = x * (D^i)
205-
_rshift(::Val{D}, x::Integer, i::Integer) where D = x ÷ (D^i)
326+
_lshift(::Val{D}, x::Integer, i::Integer) where {D} = x * (D^i)
327+
_rshift(::Val{D}, x::Integer, i::Integer) where {D} = x ÷ (D^i)
206328
_lshift(::Val{2}, x::Integer, i::Integer) = x << i
207329
_rshift(::Val{2}, x::Integer, i::Integer) = x >> i
208330

@@ -230,10 +352,10 @@ Base.repeat(s::DitStr, n::Integer) = join([s for i in 1:n]...)
230352
Create an onehot vector in type `Vector{T}` or a batch of onehot vector in type `Matrix{T}`, where index `x + 1` is one.
231353
One can specify the value of the nonzero entry by inputing a pair.
232354
"""
233-
onehot(::Type{T}, n::DitStr{D,N,T1}; nbatch=nothing) where {D,T, N,T1} = _onehot(T, D^N, buffer(n)+1; nbatch)
355+
onehot(::Type{T}, n::DitStr{D,N,T1}; nbatch=nothing) where {D,T,N,T1} = _onehot(T, D^N, buffer(n) + 1; nbatch)
234356
onehot(n::DitStr; nbatch=nothing) = onehot(ComplexF64, n; nbatch)
235357

236-
readbit(x::DitStr{D, N, LongLongUInt{C}}, loc::Int) where {D, N, C} = readbit(x.buf, loc)
358+
readbit(x::DitStr{D,N,LongLongUInt{C}}, loc::Int) where {D,N,C} = readbit(x.buf, loc)
237359

238360
########## @dit_str macro ##############
239361
"""
@@ -297,20 +419,20 @@ function parse_dit(::Type{T}, str::String) where {T<:Integer}
297419
if res === nothing
298420
error("Input string literal format error, should be e.g. `dit\"01121;3\"`")
299421
end
300-
return _parse_dit(Val(parse(Int,res[2])), T, res[1])
422+
return _parse_dit(Val(parse(Int, res[2])), T, res[1])
301423
end
302424

303-
function _parse_dit(::Val{D}, ::Type{T}, str::AbstractString) where {D, T<:Integer}
425+
function _parse_dit(::Val{D}, ::Type{T}, str::AbstractString) where {D,T<:Integer}
304426
TT = T <: LongLongUInt ? longinttype(count(isdigit, str), D) : T
305427
_parse_dit_safe(Val(D), TT, str)
306428
end
307429

308-
function _parse_dit_safe(::Val{D}, ::Type{T}, str::AbstractString) where {D, T<:Integer}
430+
function _parse_dit_safe(::Val{D}, ::Type{T}, str::AbstractString) where {D,T<:Integer}
309431
val = zero(T)
310432
k = 0
311433
maxk = max_num_elements(T, D)
312434
for each in reverse(str)
313-
k >= maxk-1 && error("string length is larger than $(maxk), use @ldit_str instead")
435+
k >= maxk - 1 && error("string length is larger than $(maxk), use @ldit_str instead")
314436
v = each - '0'
315437
if 0 <= v < D
316438
val += _lshift(Val(D), T(v), k)
@@ -324,6 +446,6 @@ function _parse_dit_safe(::Val{D}, ::Type{T}, str::AbstractString) where {D, T<:
324446
return DitStr{D,k,T}(val)
325447
end
326448

327-
max_num_elements(::Type{T}, D::Int) where T<:Integer = floor(Int, log(typemax(T))/log(D))
449+
max_num_elements(::Type{T}, D::Int) where {T<:Integer} = floor(Int, log(typemax(T)) / log(D))
328450
max_num_elements(::Type{BigInt}, D::Int) = typemax(Int)
329451
max_num_elements(::Type{LongLongUInt{C}}, D::Int) where {C} = max_num_elements(UInt, D) * C

test/DitStr.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ using BitBasis, Test
55
@test x isa DitStr
66
@test x |> length == 6
77
println(x)
8-
@test buffer(x) === Int64(3^5 + 3^4 + 2*3^3 + 3^2)
8+
@test buffer(x) === Int64(3^5 + 3^4 + 2 * 3^3 + 3^2)
99
@test_throws ErrorException randn(100)[x]
1010
@test BitBasis.readat(x, 2) == Int64(0)
1111
@test x[2] == Int64(0)
1212
@test x[3] == Int64(1)
1313
@test x[4] == Int64(2)
14-
@test [x...] == Int64[0,0,1,2,1,1]
15-
@test [DitStr{3}(Int64[0,0,1,2,1,1])...] == Int64[0,0,1,2,1,1]
14+
@test [x...] == Int64[0, 0, 1, 2, 1, 1]
15+
@test [DitStr{3}(Int64[0, 0, 1, 2, 1, 1])...] == Int64[0, 0, 1, 2, 1, 1]
1616
@test_throws ErrorException BitBasis.parse_dit(Int64, "112103;3")
1717
@test_throws ErrorException BitBasis.parse_dit(Int64, "112101;")
1818

@@ -24,4 +24,23 @@ using BitBasis, Test
2424
@test_throws ErrorException BitBasis.parse_dit(Int64, "12341111111111111111111111111111111111111111111111111111111;5")
2525

2626
@test hash(x) isa UInt64
27+
end
28+
29+
@testset "SubDitStr" begin
30+
x = dit"112100;3"
31+
sx = SubDitStr(x, 2, 4) # bit"210"
32+
@test_throws BoundsError SubDitStr(x, 2, 7)
33+
@test checkbounds(sx, 1) == nothing
34+
@test getindex(sx, 1) == 0
35+
@test getindex(sx, 2) == 1
36+
@test getindex(sx, 3) == 2
37+
@test_throws BoundsError getindex(sx, 4)
38+
@test_throws BoundsError getindex(sx, 0)
39+
@test length(sx) == 3
40+
@test sx == dit"210;3"
41+
@test dit"210;3" == sx
42+
@test DitStr(sx) == dit"210;3"
43+
@test SubDitStr(dit"210;3",1,length(dit"210;3")) == sx
44+
@test (@views x[4:end]) == dit"112;3"
45+
@test (@views x[begin:3]) == dit"100;3"
2746
end

0 commit comments

Comments
 (0)