Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 888d6d4

Browse files
authored
Allow sorting VecColumnTable and define TermSet as wrapped Set (#19)
1 parent 7bed71e commit 888d6d4

File tree

13 files changed

+352
-69
lines changed

13 files changed

+352
-69
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,11 @@ Reexport = "0.2, 1"
2929
StatsBase = "0.33"
3030
StatsModels = "0.6.18"
3131
Tables = "1.2"
32-
TypedTables = "1.2"
3332
julia = "1.3"
3433

3534
[extras]
3635
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
3736
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
38-
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
3937

4038
[targets]
41-
test = ["DataFrames", "Test", "TypedTables"]
39+
test = ["DataFrames", "Test"]

src/DiffinDiffsBase.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ using PooledArrays: _label
1010
using Reexport
1111
using StatsBase: Weights, uweights
1212
@reexport using StatsModels
13+
using StatsModels: Schema
1314
using Tables
1415
using Tables: AbstractColumns, table, istable, columnnames, getcolumn
1516

1617
import Base: ==, show, union
1718
import Base: eltype, firstindex, lastindex, getindex, iterate, length, sym_in
1819
import StatsBase: coef, vcov, responsename, coefnames, weights, nobs, dof_residual
19-
import StatsModels: termvars
20+
import StatsModels: concrete_term, schema, termvars
2021

2122
const TimeType = Int
2223

@@ -27,6 +28,10 @@ export cb,
2728
,
2829
exampledata,
2930

31+
VecColumnTable,
32+
VecColsRow,
33+
subcolumns,
34+
3035
TreatmentSharpness,
3136
SharpDesign,
3237
sharp,
@@ -51,13 +56,11 @@ export cb,
5156
istreated,
5257

5358
TermSet,
59+
termset,
5460
eachterm,
5561
TreatmentTerm,
5662
treat,
5763

58-
VecColumnTable,
59-
subcolumns,
60-
6164
findcell,
6265
cellrows,
6366

@@ -84,10 +87,10 @@ export cb,
8487
treatnames
8588

8689
include("utils.jl")
90+
include("tables.jl")
8791
include("treatments.jl")
8892
include("parallels.jl")
8993
include("terms.jl")
90-
include("tables.jl")
9194
include("operations.jl")
9295
include("StatsProcedures.jl")
9396
include("procedures.jl")

src/did.jl

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ _key(::Any) = throw(ArgumentError("unacceptable positional arguments"))
2020

2121
function _totermset!(args::Dict{Symbol,Any}, s::Symbol)
2222
if haskey(args, s) && !(args[s] isa TermSet)
23-
ts = TermSet()
24-
foreach(t->setindex!(ts, nothing, t), args[s])
25-
args[s] = ts
23+
arg = args[s]
24+
args[s] = arg isa Symbol ? termset(arg) : termset(arg...)
2625
end
2726
end
2827

@@ -185,7 +184,7 @@ All concrete subtypes of `DIDResult` are expected to have the following fields:
185184
- `yname::String`: name of the outcome variable (generated by `StatsModels.coefnames`).
186185
- `coefnames::Vector{String}`: names of all treatment coefficients and covariates.
187186
- `coefinds::Dict{String, Int}`: a map from `coefnames` to integer indices for retrieving estimates by name.
188-
- `treatinds::Table`: tabular descriptions of treatment coefficients in the order of `coefnames` (do not contain covariates).
187+
- `treatcells`: Tables.jl-compatible tabular descriptions of treatment coefficients in the order of `coefnames` (do not contain covariates).
189188
- `weights::Union{Symbol, Nothing}`: column name of the weight variable (if specified).
190189
"""
191190
abstract type DIDResult <: StatisticalModel end
@@ -237,18 +236,36 @@ coef(r::DIDResult, i::Int) = r.coef[i]
237236
coef(r::DIDResult, inds) = [coef(r, ind) for ind in inds]
238237

239238
"""
239+
coef(r::DIDResult, bys::Pair{Symbol,<:Function}...)
240240
coef(f::Function, r::DIDResult)
241241
242242
Return a vector of point estimates for treatment coefficients
243-
selected based on whether `f` returns `true` or `false`
244-
for each corresponding row in `treatinds`.
243+
selected based on the specified functions that return either `true` or `false`.
244+
If `bys` are specified,
245+
each function is applied column-wise to each value of the specified column in `treatcells`.
246+
A coefficient is selected if the returned values from the functions are all `true`.
247+
If a single function `f` is provided without specifing column name,
248+
`f` is applied row-wise to each row of `treatcells`.
245249
246250
!!! note
247251
This method only selects estimates for treatment coefficients.
248252
Covariates are not taken into account.
249253
"""
254+
@inline function coef(r::DIDResult, @nospecialize(bys::Pair{Symbol,<:Function}...))
255+
nby = length(bys)
256+
by = bys[1]
257+
inds = by[2].(getcolumn(r.treatcells, by[1]))
258+
if nby > 1
259+
for i in 2:nby
260+
by = bys[i]
261+
inds .&= by[2].(getcolumn(r.treatcells, by[1]))
262+
end
263+
end
264+
return view(r.coef, 1:length(r.treatcells[1]))[inds]
265+
end
266+
250267
@inline coef(f::Function, r::DIDResult) =
251-
view(r.coef, 1:length(r.treatinds))[f.(r.treatinds)]
268+
view(r.coef, 1:length(r.treatcells[1]))[f.(Tables.rows(r.treatcells))]
252269

253270
"""
254271
vcov(r::DIDResult)
@@ -282,19 +299,38 @@ function vcov(r::DIDResult, inds)
282299
end
283300

284301
"""
302+
vcov(r::DIDResult, bys::Pair{Symbol,<:Function}...)
285303
vcov(f::Function, r::DIDResult)
286304
287305
Return a variance-covariance matrix for treatment coefficients
288-
selected based on whether `f` returns `true` or `false`
289-
for each corresponding row in `treatinds`.
306+
selected based on the specified functions that return either `true` or `false`.
307+
If `bys` are specified,
308+
each function is applied column-wise to each value of the specified column in `treatcells`.
309+
A coefficient is selected if the returned values from the functions are all `true`.
310+
If a single function `f` is provided without specifing column name,
311+
`f` is applied row-wise to each row of `treatcells`.
290312
291313
!!! note
292314
This method only selects estimates for treatment coefficients.
293315
Covariates are not taken into account.
294316
"""
317+
@inline function vcov(r::DIDResult, @nospecialize(bys::Pair{Symbol,<:Function}...))
318+
N = length(r.treatcells[1])
319+
nby = length(bys)
320+
by = bys[1]
321+
inds = by[2].(getcolumn(r.treatcells, by[1]))
322+
if nby > 1
323+
for i in 2:nby
324+
by = bys[i]
325+
inds .&= by[2].(getcolumn(r.treatcells, by[1]))
326+
end
327+
end
328+
return view(r.vcov, 1:N, 1:N)[inds, inds]
329+
end
330+
295331
@inline function vcov(f::Function, r::DIDResult)
296-
N = length(r.treatinds)
297-
inds = f.(r.treatinds)
332+
N = length(r.treatcells[1])
333+
inds = f.(Tables.rows(r.treatcells))
298334
return view(r.vcov, 1:N, 1:N)[inds, inds]
299335
end
300336

@@ -340,7 +376,7 @@ coefnames(r::DIDResult) = r.coefnames
340376
341377
Return a vector of names for treatment coefficients.
342378
"""
343-
treatnames(r::DIDResult) = r.coefnames[1:size(r.treatinds,1)]
379+
treatnames(r::DIDResult) = r.coefnames[1:size(r.treatcells,1)]
344380

345381
"""
346382
weights(r::DIDResult)

src/operations.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ findcell(names, data, esample=Colon()) =
7171
A utility function for processing the object `refrows` returned by [`findcell`](@ref).
7272
Unique row values from `cols` corresponding to
7373
the keys in `refrows` are sorted lexicographically
74-
and stored as rows in a `Tables.MatrixTable`.
74+
and stored as rows in a new `VecColumnTable`.
7575
Groups of row indices from the values of `refrows` are permuted to
7676
match the order of row values and collected in a `Vector`.
7777
7878
# Returns
79-
- `cells::MatrixTable`: unique row values from columns in `cols`.
79+
- `cells::VecColumnTable`: unique row values from columns in `cols`.
8080
- `rows::Vector{Vector{Int}}`: row indices for each combination.
8181
"""
8282
function cellrows(cols::VecColumnTable, refrows::IdDict)
@@ -85,21 +85,26 @@ function cellrows(cols::VecColumnTable, refrows::IdDict)
8585
ncol = length(cols)
8686
ncell = length(refrows)
8787
rows = Vector{Vector{Int}}(undef, ncell)
88-
cache = Matrix{Any}(undef, ncell, ncol+1)
88+
columns = AbstractVector[Vector{eltype(c)}(undef, ncell) for c in cols]
89+
refs = Vector{keytype(refrows)}(undef, ncell)
8990
r = 0
9091
@inbounds for (k, v) in refrows
9192
r += 1
92-
cache[r, end] = k
9393
row1 = v[1]
94+
refs[r] = k
9495
for c in 1:ncol
95-
cache[r, c] = cols[c][row1]
96+
columns[c][r] = cols[c][row1]
9697
end
9798
end
98-
sorted = sortslices(cache, dims=1)
99-
cells = table(sorted[:,1:ncol], header=columnnames(cols))
99+
cells = VecColumnTable(columns, _names(cols), _lookup(cols))
100+
p = sortperm(cells)
101+
# Replace each column of cells with a new one in the sorted order
102+
@inbounds for i in 1:ncol
103+
columns[i] = cells[i][p]
104+
end
100105
# Collect rows in the same order as cells
101106
@inbounds for i in 1:ncell
102-
rows[i] = refrows[sorted[i,end]]
107+
rows[i] = refrows[refs[p[i]]]
103108
end
104109
return cells, rows
105110
end

src/tables.jl

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,33 @@ function VecColumnTable(columns::Vector{AbstractVector}, names::Vector{Symbol})
2929
return VecColumnTable(columns, names, lookup)
3030
end
3131

32-
function VecColumnTable(data)
32+
function VecColumnTable(data, esample::Union{BitVector, Nothing}=nothing)
3333
Tables.istable(data) || throw(ArgumentError("input data is not Tables.jl-compatible"))
3434
names = collect(Tables.columnnames(data))
3535
ncol = length(names)
3636
columns = Vector{AbstractVector}(undef, ncol)
3737
cols = Tables.columns(data)
38-
@inbounds for i in keys(names)
39-
columns[i] = Tables.getcolumn(cols, i)
38+
if esample === nothing
39+
@inbounds for i in keys(names)
40+
columns[i] = Tables.getcolumn(cols, i)
41+
end
42+
else
43+
@inbounds for i in keys(names)
44+
columns[i] = view(Tables.getcolumn(cols, i), esample)
45+
end
4046
end
4147
return VecColumnTable(columns, names)
4248
end
4349

50+
function VecColumnTable(cols::VecColumnTable, esample::Union{BitVector, Nothing}=nothing)
51+
esample === nothing && return cols
52+
columns = similar(_columns(cols))
53+
@inbounds for i in keys(columns)
54+
columns[i] = view(cols[i], esample)
55+
end
56+
return VecColumnTable(columns, _names(cols), _lookup(cols))
57+
end
58+
4459
_columns(cols::VecColumnTable) = getfield(cols, :columns)
4560
_names(cols::VecColumnTable) = getfield(cols, :names)
4661
_lookup(cols::VecColumnTable) = getfield(cols, :lookup)
@@ -150,3 +165,58 @@ end
150165

151166
subcolumns(data, names::Symbol, rows=Colon(); nomissing=true) =
152167
subcolumns(data, [names], rows, nomissing=nomissing)
168+
169+
const VecColsRow = Tables.ColumnsRow{VecColumnTable}
170+
171+
ncol(r::VecColsRow) = length(r)
172+
173+
_rowhash(cols::Tuple{AbstractVector}, r::Int, h::UInt=zero(UInt))::UInt =
174+
hash(cols[1][r], h)
175+
176+
function _rowhash(cols::Tuple{Vararg{AbstractVector}}, r::Int, h::UInt=zero(UInt))::UInt
177+
h = hash(cols[1][r], h)
178+
_rowhash(Base.tail(cols), r, h)
179+
end
180+
181+
# hash is implemented following DataFrames.DataFrameRow for getting unique row values
182+
# Column names are not taken into account
183+
Base.hash(r::VecColsRow, h::UInt=zero(UInt)) =
184+
_rowhash(ntuple(c->Tables.getcolumns(r)[c], ncol(r)), Tables.getrow(r), h)
185+
186+
# Column names are not taken into account
187+
function Base.isequal(r1::VecColsRow, r2::VecColsRow)
188+
length(r1) == length(r2) || return false
189+
return all(((a, b),) -> isequal(a, b), zip(r1, r2))
190+
end
191+
192+
# Column names are not taken into account
193+
function Base.isless(r1::VecColsRow, r2::VecColsRow)
194+
length(r1) == length(r2) ||
195+
throw(ArgumentError("compared VecColsRow do not have the same length"))
196+
for (a, b) in zip(r1, r2)
197+
isequal(a, b) || return isless(a, b)
198+
end
199+
return false
200+
end
201+
202+
Base.sortperm(cols::VecColumnTable; @nospecialize(kwargs...)) =
203+
sortperm(collect(Tables.rows(cols)); kwargs...)
204+
205+
# names and lookup are not copied
206+
function Base.sort(cols::VecColumnTable; @nospecialize(kwargs...))
207+
p = sortperm(cols; kwargs...)
208+
columns = similar(_columns(cols))
209+
i = 0
210+
@inbounds for col in cols
211+
i += 1
212+
columns[i] = col[p]
213+
end
214+
return VecColumnTable(columns, _names(cols), _lookup(cols))
215+
end
216+
217+
function Base.sort!(cols::VecColumnTable; @nospecialize(kwargs...))
218+
p = sortperm(cols; kwargs...)
219+
@inbounds for col in cols
220+
col .= col[p]
221+
end
222+
end

0 commit comments

Comments
 (0)