Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit e966273

Browse files
authored
Improve apply methods and fix issues with overlap! (#34)
1 parent 6d1dabf commit e966273

File tree

9 files changed

+129
-56
lines changed

9 files changed

+129
-56
lines changed

src/StatsProcedures.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,6 @@ function proceed(sps::Vector{<:StatsSpec};
486486
paused = false
487487
@inbounds for step in steps
488488
tasks = _byid(step) ? tasks_byid : tasks_byeq
489-
ntask = 0
490489
verbose && print("Running ", step, "...")
491490
# Group arguments by objectid or isequal
492491
for i in _sharedby(step)
@@ -495,6 +494,10 @@ function proceed(sps::Vector{<:StatsSpec};
495494
push!(get!(Vector{Int}, tasks, groupargs(step, traces[j])), j)
496495
end
497496
end
497+
ntask = length(tasks)
498+
nprocs = length(_sharedby(step))
499+
verbose && print("Scheduled ", ntask, ntask > 1 ? " tasks" : " task", " for ",
500+
nprocs, nprocs > 1 ? " procedures" : " procedure", "...\n")
498501

499502
for (gargs, ids) in tasks
500503
# Handle potential in-place operations on mutable objects
@@ -514,11 +517,9 @@ function proceed(sps::Vector{<:StatsSpec};
514517
traces[id] = merge(traces[id], ret)
515518
end
516519
end
517-
ntask = length(tasks)
518520
ntask_total += ntask
519521
empty!(tasks)
520-
nprocs = length(_sharedby(step))
521-
verbose && print("Finished ", ntask, ntask > 1 ? " tasks" : " task", " for ",
522+
verbose && print(" Finished ", ntask, ntask > 1 ? " tasks" : " task", " for ",
522523
nprocs, nprocs > 1 ? " procedures\n" : " procedure\n")
523524
step_count += 1
524525
step_count === pause && (paused = true) && break

src/did.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ collect estimation results for difference-in-differences
179179
with treatment of type `TR`.
180180
181181
# Interface definition
182-
| Required methods | Default definition | Brief description |
183-
|---|---|---|
182+
| Required method | Default definition | Brief description |
183+
|:---|:---|:---|
184184
| `coef(r)` | `r.coef` | Vector of point estimates for all coefficients including covariates |
185185
| `vcov(r)` | `r.vcov` | Variance-covariance matrix for estimates in `coef` |
186186
| `vce(r)` | `r.vce` | Covariance estimator |
@@ -189,8 +189,8 @@ with treatment of type `TR`.
189189
| `nobs(r)` | `r.nobs` | Number of observations (table rows) involved in estimation |
190190
| `outcomename(r)` | `r.yname` | Name of the outcome variable |
191191
| `coefnames(r)` | `r.coefnames` | Names (`Vector{String}`) of all coefficients including covariates |
192-
| `treatcells(r)` | `r.treatcells` | Tables.jl-compatible tabular description of treatment coefficients in the order of `coefnames` (without covariates) |
193-
| `weights(r)` | `r.weights` | Column name of the weight variable (if specified) |
192+
| `treatcells(r)` | `r.treatcells` | `Tables.jl`-compatible tabular description of treatment coefficients in the order of `coefnames` (without covariates) |
193+
| `weights(r)` | `r.weights` | Name of the column containing sample weights (if specified) |
194194
| `ntreatcoef(r)` | `size(treatcells(r), 1)` | Number of treatment coefficients |
195195
| `treatcoef(r)` | `view(coef(r), 1:ntreatcoef(r))` | A view of treatment coefficients |
196196
| `treatvcov(r)` | `(N = ntreatcoef(r); view(vcov(r), 1:N, 1:N))` | A view of variance-covariance matrix for treatment coefficients |
@@ -388,7 +388,7 @@ coefnames(r::AbstractDIDResult) = r.coefnames
388388
"""
389389
treatcells(r::AbstractDIDResult)
390390
391-
Return a Tables.jl-compatible tabular description of treatment coefficients
391+
Return a `Tables.jl`-compatible tabular description of treatment coefficients
392392
in the order of coefnames (without covariates).
393393
"""
394394
treatcells(r::AbstractDIDResult) = r.treatcells
@@ -550,7 +550,7 @@ end
550550

551551
# Helper functions for handling the subset option that may involve Pairs
552552
_parse_subset(r::AbstractDIDResult, by::Pair, fill_x::Bool) =
553-
(inds = apply(treatcells(r), by); fill_x && _fill_x!(r, inds); return inds)
553+
(inds = apply_and(treatcells(r), by); fill_x && _fill_x!(r, inds); return inds)
554554

555555
function _parse_subset(r::AbstractDIDResult, inds, fill_x::Bool)
556556
eltype(inds) <: Pair || return inds
@@ -911,7 +911,7 @@ function post!(f, ::StataPostHDF, r::AbstractDIDResult;
911911
if at !== nothing
912912
pat = _postat!(f, r, at)
913913
pat === nothing && throw(ArgumentError(
914-
"Keyword argument of type $(typeof(at)) is not accepted."))
914+
"Keyword argument `at` of type $(typeof(at)) is not accepted."))
915915
pat == false || length(pat) != length(coef(r)) && throw(ArgumentError(
916916
"The length of at ($(length(pat))) does not match the length of b ($(length(coef(r))))"))
917917
end

src/operations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ are the same within each group.
4444
Instead of directly providing the relevant portions of columns as
4545
[`VecColumnTable`](@ref)``,
4646
one may specify the `names` of columns from
47-
`data` of any Tables.jl-compatible table type
47+
`data` of any `Tables.jl`-compatible table type
4848
over selected rows indicated by `esample`.
4949
Note that unless `esample` covers all rows of `data`,
5050
the row indices are those for the subsample selected based on `esample`
@@ -170,7 +170,7 @@ the [`ScaledArray`](@ref) `time`.
170170
If `time` is a [`RotatingTimeArray`](@ref) with the `time` field being a [`ScaledArray`](@ref),
171171
the returned array is also a [`RotatingTimeArray`](@ref)
172172
with the `time` field being the converted [`ScaledArray`](@ref).
173-
Alternative, the arrays may be specified with a Tables.jl-compatible `data` table
173+
Alternatively, the arrays may be specified with a `Tables.jl`-compatible `data` table
174174
and column indices `colname` and `timename`.
175175
See also [`settime`](@ref).
176176
@@ -229,7 +229,7 @@ as a table containing the relevant columns or as arrays.
229229
that is returned by [`settime`](@ref).
230230
231231
# Arguments
232-
- `data`: a Tables.jl-compatible data table.
232+
- `data`: a `Tables.jl`-compatible data table.
233233
- `idname::Union{Symbol,Integer}`: the name of the column in `data` that contains unit IDs.
234234
- `timename::Union{Symbol,Integer}`: the name of the column in `data` that contains time values.
235235
- `id::AbstractArray`: the array containing unit IDs (only needed for the alternative method).

src/parallels.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ show(io::IO, ::Unconditional) =
2222
"""
2323
unconditional()
2424
25-
Alias of [`Unconditional()`](@ref).
25+
Alias for [`Unconditional()`](@ref).
2626
"""
2727
unconditional() = Unconditional()
2828

@@ -56,7 +56,7 @@ show(io::IO, ::Exact) =
5656
"""
5757
exact()
5858
59-
Alias of [`Exact()`](@ref).
59+
Alias for [`Exact()`](@ref).
6060
"""
6161
exact() = Exact()
6262

@@ -252,7 +252,7 @@ See also [`notyettreated`](@ref).
252252
253253
# Fields
254254
- `e::Tuple{Vararg{ValidTimeType}}`: group indices for units that received the treatment relatively late.
255-
- `ecut::Tuple{Vararg{ValidTimeType}}`: user-specified period(s) when units in a group in `e` started to receive treatment.
255+
- `ecut::Tuple{Vararg{ValidTimeType}}`: user-specified period(s) when units in a group in `e` started to receive treatment or show anticipation effects.
256256
- `c::C`: an instance of [`ParallelCondition`](@ref).
257257
- `s::S`: an instance of [`ParallelStrength`](@ref).
258258

src/procedures.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ function checkdata!(data, subset::Union{BitVector, Nothing}, weightname::Union{S
1111
if subset !== nothing
1212
length(subset) == nrow || throw(DimensionMismatch(
1313
"data contain $(nrow) rows while subset has $(length(subset)) elements"))
14-
esample = subset
14+
# Do not modify subset in-place
15+
esample = copy(subset)
1516
else
1617
esample = trues(nrow)
1718
end
@@ -146,16 +147,16 @@ function checktreatvars(::DynamicTreatment{SharpDesign},
146147
end
147148
end
148149

149-
function _overlaptime(tr::DynamicTreatment, tr_rows::BitVector, data)
150+
function _overlaptime(tr::DynamicTreatment, esample::BitVector, tr_rows::BitVector, data)
150151
timeref = refarray(getcolumn(data, tr.time))
151-
control_time = Set(view(timeref, .!tr_rows))
152-
treated_time = Set(view(timeref, tr_rows))
152+
control_time = Set(view(timeref, .!tr_rows.&esample))
153+
treated_time = Set(view(timeref, tr_rows.&esample))
153154
return intersect(control_time, treated_time), control_time, treated_time
154155
end
155156

156157
function overlap!(esample::BitVector, tr_rows::BitVector, aux::BitVector, tr::DynamicTreatment,
157158
::NeverTreatedParallel{Unconditional}, treatname::Symbol, data)
158-
overlap_time, control_time, treated_time = _overlaptime(tr, tr_rows, data)
159+
overlap_time, control_time, treated_time = _overlaptime(tr, esample, tr_rows, data)
159160
if !(length(control_time)==length(treated_time)==length(overlap_time))
160161
aux[esample] .= view(refarray(getcolumn(data, tr.time)), esample) .∈ (overlap_time,)
161162
esample[esample] .&= view(aux, esample)
@@ -165,13 +166,12 @@ end
165166

166167
function overlap!(esample::BitVector, tr_rows::BitVector, aux::BitVector, tr::DynamicTreatment,
167168
pr::NotYetTreatedParallel{Unconditional}, treatname::Symbol, data)
168-
overlap_time, _c, _t = _overlaptime(tr, tr_rows, data)
169169
timecol = getcolumn(data, tr.time)
170+
# First exclude cohorts not suitable for comparisons
170171
if !(eltype(timecol) <: RotatingTimeValue)
171172
invpool = invrefpool(timecol)
172173
e = invpool === nothing ? Set(pr.e) : Set(invpool[c] for c in pr.e)
173174
ecut = invpool === nothing ? pr.ecut[1] : invpool[pr.ecut[1]]
174-
filter!(x -> x < ecut, overlap_time)
175175
isvalidcohort = x -> x < ecut || x in e
176176
else
177177
invpool = invrefpool(timecol.time)
@@ -182,13 +182,19 @@ function overlap!(esample::BitVector, tr_rows::BitVector, aux::BitVector, tr::Dy
182182
e = Set(RotatingTimeValue(c.rotation, invpool[c.time]) for c in pr.e)
183183
ecut = IdDict(e.rotation=>invpool[e.time] for e in pr.ecut)
184184
end
185-
filter!(x -> x.time < ecut[x.rotation], overlap_time)
186185
isvalidcohort = x -> x.time < ecut[x.rotation] || x in e
187186
end
188-
aux[esample] .= view(refarray(timecol), esample) .∈ (overlap_time,)
189-
esample[esample] .&= view(aux, esample)
190187
aux[esample] .= isvalidcohort.(view(refarray(getcolumn(data, treatname)), esample))
191188
esample[esample] .&= view(aux, esample)
189+
# Check overlaps among the remaining cohorts only
190+
overlap_time, _c, _t = _overlaptime(tr, esample, tr_rows, data)
191+
if !(eltype(timecol) <: RotatingTimeValue)
192+
filter!(x -> x < ecut, overlap_time)
193+
else
194+
filter!(x -> x.time < ecut[x.rotation], overlap_time)
195+
end
196+
aux[esample] .= view(refarray(timecol), esample) .∈ (overlap_time,)
197+
esample[esample] .&= view(aux, esample)
192198
tr_rows .&= esample
193199
end
194200

src/tables.jl

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
VecColumnTable <: AbstractColumns
33
4-
A Tables.jl-compatible column table that stores data as `Vector{AbstractVector}`
4+
A `Tables.jl`-compatible column table that stores data as `Vector{AbstractVector}`
55
and column names as `Vector{Symbol}`.
66
Retrieving columns by column names is achieved with a `Dict{Symbol,Int}`
77
that maps names to indices.
@@ -30,7 +30,7 @@ function VecColumnTable(columns::Vector{AbstractVector}, names::Vector{Symbol})
3030
end
3131

3232
function VecColumnTable(data, subset=nothing)
33-
Tables.istable(data) || throw(ArgumentError("input data is not Tables.jl-compatible"))
33+
Tables.istable(data) || throw(ArgumentError("input data is not `Tables.jl`-compatible"))
3434
names = collect(Tables.columnnames(data))
3535
ncol = length(names)
3636
columns = Vector{AbstractVector}(undef, ncol)
@@ -153,7 +153,7 @@ By default, columns are converted to drop support for missing values.
153153
When possible, resulting columns share memory with original columns.
154154
"""
155155
function subcolumns(data, names, rows=Colon(); nomissing=true)
156-
Tables.istable(data) || throw(ArgumentError("input data is not Tables.jl-compatible"))
156+
Tables.istable(data) || throw(ArgumentError("input data is not `Tables.jl`-compatible"))
157157
names = names isa Vector{Symbol} ? names : Symbol[names...]
158158
ncol = length(names)
159159
columns = Vector{AbstractVector}(undef, ncol)
@@ -227,46 +227,55 @@ end
227227

228228
"""
229229
apply(d, by::Pair)
230-
apply(d, bys::Pair...)
231230
232231
Apply a function elementwise to the specified column(s)
233-
in a Tables.jl-compatical table `d` and return the result.
232+
in a `Tables.jl`-compatible table `d` and return the result.
234233
235234
Depending on the argument(s) accepted by a function `f`,
236235
it is specified with argument `by` as either `column_index => f` or `column_indices => f`
237236
where `column_index` is either a `Symbol` or `Int` for a column in `d`
238237
and `column_indices` is an iterable collection of such indices for multiple columns.
239238
`f` is applied elementwise to each specified column
240239
to obtain an array of returned values.
241-
If multiple `Pair`s are provided,
242-
only the first `Pair` is applied and the rest of them are ignored.
243240
"""
244-
apply(d, by::Pair{Symbol,<:Function}) = by[2].(Tables.getcolumn(d, by[1]))
245-
apply(d, by::Pair) = by[2].((Tables.getcolumn(d, c) for c in by[1])...)
246-
apply(d, bys::Pair...) = apply(d, bys[1])
241+
apply(d, by::Pair{Symbol,<:Function}) =
242+
map(by[2], Tables.getcolumn(d, by[1]))
243+
244+
apply(d, @nospecialize(by::Pair)) =
245+
map(by[2], (Tables.getcolumn(d, c) for c in by[1])...)
247246

248247
"""
249248
apply_and!(inds::BitVector, d, by::Pair)
250249
apply_and!(inds::BitVector, d, bys::Pair...)
251250
252251
Apply a function that returns `true` or `false` elementwise
253-
to the specified column(s) in a Tables.jl-compatical table `d`
252+
to the specified column(s) in a `Tables.jl`-compatible table `d`
254253
and then update the elements in `inds` through bitwise `and` with the returned array.
254+
If an array instead of a function is provided,
255+
elementwise equality (`==`) comparison is applied between the column and the array.
255256
See also [`apply_and`](@ref).
256257
257258
The way a function is specified is the same as how it is done with [`apply`](@ref).
258259
If multiple `Pair`s are provided,
259260
`inds` are updated for each returned array through bitwise `and`.
260261
"""
261262
function apply_and!(inds::BitVector, d, by::Pair{Symbol,<:Function})
262-
inds .&= by[2].(Tables.getcolumn(d, by[1]))
263+
aux = map(by[2], Tables.getcolumn(d, by[1]))
264+
inds .&= aux
265+
return inds
263266
end
264267

265-
function apply_and!(inds::BitVector, d, by::Pair)
266-
inds .&= by[2].((Tables.getcolumn(d, c) for c in by[1])...)
268+
function apply_and!(inds::BitVector, d, @nospecialize(by::Pair))
269+
if by[2] isa Function
270+
aux = map(by[2], (Tables.getcolumn(d, c) for c in by[1])...)
271+
else
272+
aux = Tables.getcolumn(d, by[1]).==by[2]
273+
end
274+
inds .&= aux
275+
return inds
267276
end
268277

269-
function apply_and!(inds::BitVector, d, bys::Pair...)
278+
function apply_and!(inds::BitVector, d, @nospecialize(bys::Pair...))
270279
for by in bys
271280
apply_and!(inds, d, by)
272281
end
@@ -278,27 +287,57 @@ end
278287
apply_and(d, bys::Pair...)
279288
280289
Apply a function that returns `true` or `false` elementwise
281-
to the specified column(s) in a Tables.jl-compatical table `d` and return the result.
290+
to the specified column(s) in a `Tables.jl`-compatible table `d` and return the result.
291+
If an array instead of a function is provided,
292+
elementwise equality (`==`) comparison is applied between the column and the array.
282293
See also [`apply_and!`](@ref).
283294
284295
The way a function is specified is the same as how it is done with [`apply`](@ref).
285296
If multiple `Pair`s are provided,
286297
the returned array is obtained by combining
287298
arrays returned by each function through bitwise `and`.
288299
"""
289-
apply_and(d, by::Pair) = apply(d, by)
300+
apply_and(d, by::Pair{Symbol,<:Any}) = Tables.getcolumn(d, by[1]).==by[2]
301+
302+
function apply_and(d, by::Pair{Symbol,<:Function})
303+
src = Tables.getcolumn(d, by[1])
304+
out = similar(BitArray, size(src))
305+
map!(by[2], out, src)
306+
return out
307+
end
290308

291-
function apply_and(d, bys::Pair...)
292-
inds = apply(d, bys[1])
309+
function apply_and(d, @nospecialize(by::Pair))
310+
if by[2] isa Function
311+
src = ((Tables.getcolumn(d, c) for c in by[1])...,)
312+
out = similar(BitArray, size(src[1]))
313+
map!(by[2], out, src...)
314+
return out
315+
else
316+
return Tables.getcolumn(d, by[1]).==by[2]
317+
end
318+
end
319+
320+
function apply_and(d, @nospecialize(bys::Pair...))
321+
inds = apply_and(d, bys[1])
293322
apply_and!(inds, d, bys[2:end]...)
294323
return inds
295324
end
296325

326+
_parse_subset(cols::VecColumnTable, by::Pair) = (inds = apply_and(cols, by); return inds)
327+
328+
function _parse_subset(cols::VecColumnTable, inds)
329+
eltype(inds) <: Pair || return inds
330+
inds = apply_and(cols, inds...)
331+
return inds
332+
end
333+
334+
_parse_subset(::VecColumnTable, ::Colon) = Colon()
335+
297336
"""
298337
TableIndexedMatrix{T,M,R,C} <: AbstractMatrix{T}
299338
300339
Matrix with row and column indices that can be selected
301-
based on row values in a Tables.jl-compatible table respectively.
340+
based on row values in a `Tables.jl`-compatible table respectively.
302341
This is useful when how elements are stored into the matrix
303342
is determined by the rows of the tables.
304343
@@ -319,9 +358,9 @@ struct TableIndexedMatrix{T,M,R,C} <: AbstractMatrix{T}
319358
c::C
320359
function TableIndexedMatrix(m::AbstractMatrix, r, c)
321360
Tables.istable(r) ||
322-
throw(ArgumentError("r is not Tables.jl-compatible"))
361+
throw(ArgumentError("r is not `Tables.jl`-compatible"))
323362
Tables.istable(c) ||
324-
throw(ArgumentError("c is not Tables.jl-compatible"))
363+
throw(ArgumentError("c is not `Tables.jl`-compatible"))
325364
size(m, 1) == Tables.rowcount(r) ||
326365
throw(DimensionMismatch("m and r do not have the same number of rows"))
327366
size(m, 2) == Tables.rowcount(c) || throw(DimensionMismatch(

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ end
156156
# Check whether the input data is a column table
157157
function checktable(data)
158158
istable(data) ||
159-
throw(ArgumentError("data of type $(typeof(data)) is not Tables.jl-compatible"))
159+
throw(ArgumentError("data of type $(typeof(data)) is not `Tables.jl`-compatible"))
160160
Tables.columnaccess(data) ||
161161
throw(ArgumentError("data of type $(typeof(data)) is not a column table"))
162162
end

0 commit comments

Comments
 (0)