Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 570f837

Browse files
authored
Improve ScaledArray and add RotatingTimeRange (#24)
1 parent fc03843 commit 570f837

File tree

12 files changed

+396
-220
lines changed

12 files changed

+396
-220
lines changed

src/DiffinDiffsBase.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ export coef, vcov, stderror, confint, nobs, dof_residual, responsename, coefname
3232
export cb,
3333
,
3434
exampledata,
35+
3536
RotatingTimeValue,
3637
rotatingtime,
38+
RotatingRange,
3739

3840
VecColumnTable,
3941
VecColsRow,
@@ -46,6 +48,7 @@ export cb,
4648
ScaledArray,
4749
ScaledVector,
4850
ScaledMatrix,
51+
scale,
4952

5053
TreatmentSharpness,
5154
SharpDesign,
@@ -126,6 +129,7 @@ export cb,
126129
rescale
127130

128131
include("utils.jl")
132+
include("time.jl")
129133
include("tables.jl")
130134
include("ScaledArrays.jl")
131135
include("treatments.jl")

src/ScaledArrays.jl

Lines changed: 94 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,31 @@ mutable struct RefArray{R}
66
end
77

88
"""
9-
ScaledArray{T,N,RA,S,P} <: AbstractArray{T,N}
9+
ScaledArray{T,R,N,RA,P} <: AbstractArray{T,N}
1010
1111
An array type that stores data as indices of a range.
1212
1313
# Fields
1414
- `refs::RA<:AbstractArray{<:Any, N}`: an array of indices.
15-
- `start::T`: the starting value of the range.
16-
- `step::S`: the step size of the range.
17-
- `stop::T`: the stopping value of the range.
1815
- `pool::P<:AbstractRange{T}`: a range that covers all possible values stored by the array.
16+
- `invpool::Dict{T,R}`: a map from array elements to indices of `pool`.
1917
"""
20-
mutable struct ScaledArray{T,N,RA,S,P} <: AbstractArray{T,N}
18+
mutable struct ScaledArray{T,R,N,RA,P} <: AbstractArray{T,N}
2119
refs::RA
22-
start::T
23-
step::S
24-
stop::T
2520
pool::P
26-
ScaledArray{T,N,RA,S,P}(rs::RefArray{RA}, start::T, step::S, stop::T,
27-
pool::P=start:step:stop) where
28-
{T, N, RA<:AbstractArray{<:Any, N}, S, P<:AbstractRange{T}} =
29-
new{T,N,RA,S,P}(rs.a, start, step, stop, pool)
21+
invpool::Dict{T,R}
22+
ScaledArray{T,R,N,RA,P}(rs::RefArray{RA}, pool::P, invpool::Dict{T,R}) where
23+
{T, R, N, RA<:AbstractArray{R, N}, P<:AbstractRange{T}} =
24+
new{T,R,N,RA,P}(rs.a, pool, invpool)
3025
end
3126

32-
ScaledArray(rs::RefArray{RA}, start::T, step::S, stop::T, pool::P=start:step:stop) where
33-
{T,RA,S,P} = ScaledArray{T,ndims(RA),RA,S,P}(rs, start, step, stop, pool)
27+
ScaledArray(rs::RefArray{RA}, pool::P, invpool::Dict{T,R}) where
28+
{T,R,RA<:AbstractArray{R},P} = ScaledArray{T,R,ndims(RA),RA,P}(rs, pool, invpool)
3429

35-
const ScaledVector{T} = ScaledArray{T,1}
36-
const ScaledMatrix{T} = ScaledArray{T,2}
30+
const ScaledVector{T,R} = ScaledArray{T,R,1}
31+
const ScaledMatrix{T,R} = ScaledArray{T,R,2}
32+
33+
scale(sa::ScaledArray) = step(sa.pool)
3734

3835
function _validmin(min, xmin, isstart::Bool)
3936
if min === nothing
@@ -55,10 +52,11 @@ function _validmax(max, xmax, isstart::Bool)
5552
return max
5653
end
5754

58-
function validstartstepstop(x::AbstractArray, start, step, stop, usepool)
55+
function validpool(x::AbstractArray, T::Type, start, step, stop, usepool::Bool)
5956
step === nothing && throw(ArgumentError("step cannot be nothing"))
6057
pool = DataAPI.refpool(x)
61-
xmin, xmax = usepool && pool !== nothing ? extrema(pool) : extrema(x)
58+
xs = skipmissing(usepool && pool !== nothing ? pool : x)
59+
xmin, xmax = extrema(xs)
6260
applicable(+, xmin, step) || throw(ArgumentError(
6361
"step of type $(typeof(step)) does not match array with element type $(eltype(x))"))
6462
if xmin + step > xmin
@@ -70,47 +68,83 @@ function validstartstepstop(x::AbstractArray, start, step, stop, usepool)
7068
else
7169
throw(ArgumentError("step cannot be zero"))
7270
end
73-
T = promote_type(eltype(x), eltype(start:step:stop))
74-
return convert(T, start), convert(T, stop)
71+
start = convert(T, start)
72+
stop = convert(T, stop)
73+
return start:step:stop
7574
end
7675

77-
function _scaledlabel(x::AbstractArray, step, reftype::Type{<:Signed}=DEFAULT_REF_TYPE;
78-
start=nothing, stop=nothing, usepool::Bool=true)
79-
start, stop = validstartstepstop(x, start, step, stop, usepool)
80-
pool = start:step:stop
81-
while typemax(reftype) < length(pool)
82-
reftype = widen(reftype)
76+
function _scaledlabel!(labels::AbstractArray, invpool::Dict, xs::AbstractArray, start, step)
77+
z = zero(valtype(invpool))
78+
@inbounds for i in eachindex(labels)
79+
x = xs[i]
80+
lbl = get(invpool, x, z)
81+
if lbl !== z
82+
labels[i] = lbl
83+
elseif ismissing(x)
84+
labels[i] = z
85+
invpool[x] = z
86+
else
87+
r = start:step:x
88+
lbl = length(r)
89+
labels[i] = lbl
90+
invpool[x] = lbl
91+
end
8392
end
84-
refs = similar(x, reftype)
85-
@inbounds for i in eachindex(refs)
86-
refs[i] = length(start:step:x[i])
93+
end
94+
95+
function scaledlabel(xs::AbstractArray, stepsize,
96+
R::Type=DEFAULT_REF_TYPE, T::Type=eltype(xs);
97+
start=nothing, stop=nothing, usepool::Bool=true)
98+
pool = validpool(xs, T, start, stepsize, stop, usepool)
99+
T = Missing <: T ? Union{eltype(pool), Missing} : eltype(pool)
100+
start = first(pool)
101+
stepsize = step(pool)
102+
if R <: Integer
103+
while typemax(R) < length(pool)
104+
R = widen(R)
105+
end
87106
end
88-
return refs, start, step, stop
107+
labels = similar(xs, R)
108+
invpool = Dict{T,R}()
109+
_scaledlabel!(labels, invpool, xs, start, stepsize)
110+
return labels, pool, invpool
89111
end
90112

91-
function ScaledArray(x::AbstractArray, reftype::Type, start, step, stop, usepool::Bool=true)
92-
refs, start, step, stop = _scaledlabel(x, step, reftype; start=start, stop=stop, usepool=usepool)
93-
return ScaledArray(RefArray(refs), start, step, stop)
113+
function ScaledArray(x::AbstractArray, reftype::Type, xtype::Type, start, step, stop, usepool::Bool=true)
114+
refs, pool, invpool = scaledlabel(x, step, reftype, xtype; start=start, stop=stop, usepool=usepool)
115+
return ScaledArray(RefArray(refs), pool, invpool)
94116
end
95117

96-
function ScaledArray(sa::ScaledArray, reftype::Type, start, step, stop, usepool::Bool=true)
97-
if step !== nothing && step != sa.step
98-
refs, start, step, stop = _scaledlabel(sa, step, reftype; start=start, stop=stop, usepool=usepool)
99-
return ScaledArray(RefArray(refs), start, step, stop)
118+
function ScaledArray(sa::ScaledArray, reftype::Type, xtype::Type, start, step, stop, usepool::Bool=true)
119+
if step !== nothing && step != scale(sa)
120+
refs, pool, invpool = scaledlabel(sa, step, reftype, xtype; start=start, stop=stop, usepool=usepool)
121+
return ScaledArray(RefArray(refs), pool, invpool)
100122
else
101-
step = sa.step
102-
start, stop = validstartstepstop(sa, start, step, stop, usepool)
123+
step = scale(sa)
124+
pool = validpool(sa, xtype, start, step, stop, usepool)
125+
T = Missing <: xtype ? Union{eltype(pool), Missing} : eltype(pool)
103126
refs = similar(sa.refs, reftype)
104-
if start == sa.start
127+
invpool = Dict{T, reftype}()
128+
start0 = first(sa.pool)
129+
start = first(pool)
130+
stop = last(pool)
131+
if start == start0
105132
copy!(refs, sa.refs)
106-
elseif start < sa.start && start < stop || start > sa.start && start > stop
107-
offset = length(start:step:sa.start) - 1
133+
copy!(invpool, sa.invpool)
134+
elseif start < start0 && start < stop || start > start0 && start > stop
135+
offset = length(start:step:start0) - 1
108136
refs .= sa.refs .+ offset
137+
for (k, v) in sa.invpool
138+
invpool[k] = v + offset
139+
end
109140
else
110-
offset = length(sa.start:step:start) - 1
141+
offset = length(start0:step:start) - 1
111142
refs .= sa.refs .- offset
143+
for (k, v) in sa.invpool
144+
invpool[k] = v - offset
145+
end
112146
end
113-
return ScaledArray(RefArray(refs), start, step, stop)
147+
return ScaledArray(RefArray(refs), pool, invpool)
114148
end
115149
end
116150

@@ -126,39 +160,43 @@ If `start` or `stop` is not specified, it will be chosen based on the extrema of
126160
- `usepool::Bool=true`: find extrema of `x` based on `DataAPI.refpool`.
127161
"""
128162
ScaledArray(x::AbstractArray, start, step, stop=nothing;
129-
reftype::Type=DEFAULT_REF_TYPE, usepool::Bool=true) =
130-
ScaledArray(x, reftype, start, step, stop, usepool)
163+
reftype::Type=DEFAULT_REF_TYPE, xtype::Type=eltype(x), usepool::Bool=true) =
164+
ScaledArray(x, reftype, xtype, start, step, stop, usepool)
131165

132166
ScaledArray(sa::ScaledArray, start, step, stop=nothing;
133-
reftype::Type=eltype(refarray(sa)), usepool::Bool=true) =
134-
ScaledArray(sa, reftype, start, step, stop, usepool)
167+
reftype::Type=eltype(refarray(sa)), xtype::Type=eltype(sa), usepool::Bool=true) =
168+
ScaledArray(sa, reftype, xtype, start, step, stop, usepool)
135169

136170
ScaledArray(x::AbstractArray, step; reftype::Type=DEFAULT_REF_TYPE,
137-
start=nothing, stop=nothing, usepool::Bool=true) =
138-
ScaledArray(x, reftype, start, step, stop, usepool)
171+
start=nothing, stop=nothing, xtype::Type=eltype(x), usepool::Bool=true) =
172+
ScaledArray(x, reftype, xtype, start, step, stop, usepool)
139173

140174
ScaledArray(sa::ScaledArray, step=nothing; reftype::Type=eltype(refarray(sa)),
141-
start=nothing, stop=nothing, usepool::Bool=true) =
142-
ScaledArray(sa, reftype, start, step, stop, usepool)
175+
start=nothing, stop=nothing, xtype::Type=eltype(sa), usepool::Bool=true) =
176+
ScaledArray(sa, reftype, xtype, start, step, stop, usepool)
143177

144178
Base.size(sa::ScaledArray) = size(sa.refs)
145-
Base.IndexStyle(::Type{<:ScaledArray{T,N,RA}}) where {T,N,RA} = IndexStyle(RA)
179+
Base.IndexStyle(::Type{<:ScaledArray{T,R,N,RA}}) where {T,R,N,RA} = IndexStyle(RA)
146180

147181
DataAPI.refarray(sa::ScaledArray) = sa.refs
148182
DataAPI.refvalue(sa::ScaledArray, n::Integer) = getindex(DataAPI.refpool(sa), n)
149183
DataAPI.refpool(sa::ScaledArray) = sa.pool
184+
DataAPI.invrefpool(sa::ScaledArray) = sa.invpool
150185

151186
DataAPI.refarray(ssa::SubArray{<:Any, <:Any, <:ScaledArray}) =
152187
view(parent(ssa).refs, ssa.indices...)
153188
DataAPI.refvalue(ssa::SubArray{<:Any, <:Any, <:ScaledArray}, n::Integer) =
154189
DataAPI.refvalue(parent(ssa), n)
155190
DataAPI.refpool(ssa::SubArray{<:Any, <:Any, <:ScaledArray}) =
156191
DataAPI.refpool(parent(ssa))
192+
DataAPI.invrefpool(ssa::SubArray{<:Any, <:Any, <:ScaledArray}) =
193+
DataAPI.invrefpool(parent(ssa))
157194

158195
@inline function Base.getindex(sa::ScaledArray, i::Int)
159196
refs = DataAPI.refarray(sa)
160197
@boundscheck checkbounds(refs, i)
161198
@inbounds n = refs[i]
199+
iszero(n) && return missing
162200
pool = DataAPI.refpool(sa)
163201
@boundscheck checkbounds(pool, n)
164202
return @inbounds pool[n]
@@ -169,13 +207,14 @@ end
169207
@boundscheck checkbounds(refs, I...)
170208
@inbounds ns = refs[I...]
171209
pool = DataAPI.refpool(sa)
172-
@boundscheck checkbounds(pool, ns)
210+
N = length(pool)
211+
@boundscheck checkindex(Bool, 0:N, ns) || throw_boundserror(pool, ns)
173212
return @inbounds pool[ns]
174213
end
175214

176215
function Base.:(==)(x::ScaledArray, y::ScaledArray)
177216
size(x) == size(y) || return false
178-
x.start == y.start && x.step == y.step && return x.refs == y.refs
217+
first(x.pool) == first(y.pool) && step(x.pool) == step(y.pool) && return x.refs == y.refs
179218
eq = true
180219
for (p, q) in zip(x, y)
181220
# missing could arise

src/operations.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,17 @@ end
2020

2121
# Obtain unique labels for row-wise pairs of values from a1 and a2 when mult is large enough
2222
function _mult!(a1::AbstractArray, a2::AbstractArray, mult)
23-
a1 .+= mult .* (a2 .- 1)
23+
z = zero(eltype(a1))
24+
@inbounds for i in eachindex(a1)
25+
x1 = a1[i]
26+
x2 = a2[i]
27+
# Handle missing values represented by zeros
28+
if iszero(x1) || iszero(x2)
29+
a1[i] = z
30+
else
31+
a1[i] += mult * (x2 - 1)
32+
end
33+
end
2434
end
2535

2636
"""
@@ -128,13 +138,24 @@ The returned array ensures well-defined time intervals for operations involving
128138
- `rotation=nothing`: rotation groups in a rotating sampling design; use [`RotatingTimeValue`](@ref)s as reference values.
129139
"""
130140
function settime(time::AbstractArray; step=nothing, reftype::Type{<:Signed}=Int32, rotation=nothing)
131-
eltype(time) <: ValidTimeType ||
132-
throw(ArgumentError("unaccepted element type $(eltype(time)) from time column"))
133-
step === nothing && (step = one(eltype(time)))
141+
T = eltype(time)
142+
T <: ValidTimeType && !(T <: RotatingTimeValue) ||
143+
throw(ArgumentError("unaccepted element type $T from time column"))
144+
step === nothing && (step = one(T))
134145
time = ScaledArray(time, step; reftype=reftype)
135146
if rotation !== nothing
136147
refs = rotatingtime(rotation, time.refs)
137-
time = ScaledArray(RefArray(refs), time.start, time.step, time.stop)
148+
rots = unique(rotation)
149+
invpool = Dict{RotatingTimeValue{eltype(rotation), T}, eltype(refs)}()
150+
for (k, v) in time.invpool
151+
for r in rots
152+
rt = RotatingTimeValue(r, k)
153+
invpool[rt] = RotatingTimeValue(r, v)
154+
end
155+
end
156+
rmin, rmax = extrema(rots)
157+
pool = RotatingTimeValue(rmin, first(time.pool)):scale(time):RotatingTimeValue(rmax, last(time.pool))
158+
time = ScaledArray(RefArray(refs), pool, invpool)
138159
end
139160
return time
140161
end

src/procedures.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,16 @@ function overlap!(esample::BitVector, tr_rows::BitVector, tr::DynamicTreatment,
9393
pr::NotYetTreatedParallel{Unconditional}, treatname::Symbol, data)
9494
overlap_time, _c, _t = _overlaptime(tr, tr_rows, data)
9595
timetype = eltype(overlap_time)
96+
invpool = invrefpool(getcolumn(data, tr.time))
9697
if !(timetype <: RotatingTimeValue)
9798
ecut = pr.ecut[1]
99+
invpool === nothing || (ecut = invpool[ecut])
98100
valid_cohort = filter(x -> x < ecut || x in pr.e, overlap_time)
99101
filter!(x -> x < ecut, overlap_time)
100102
else
101-
ecut = IdDict(e.rotation=>e.time for e in pr.ecut)
103+
ecut = pr.ecut
104+
invpool === nothing || (ecut = (invpool[e] for e in ecut))
105+
ecut = IdDict(e.rotation=>e.time for e in ecut)
102106
valid_cohort = filter(x -> x.time < ecut[x.rotation] || x in pr.e, overlap_time)
103107
filter!(x -> x.time < ecut[x.rotation], overlap_time)
104108
end

0 commit comments

Comments
 (0)