Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit dcd1124

Browse files
authored
Add GroupTreatintterms, GroupXterms, GroupSample and post! (#32)
1 parent 8c572ae commit dcd1124

File tree

7 files changed

+216
-36
lines changed

7 files changed

+216
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Reexport = "0.2, 1"
3535
StatsBase = "0.33"
3636
StatsFuns = "0.9"
3737
StatsModels = "0.6.18"
38-
StructArrays = "0.5"
38+
StructArrays = "0.5, 0.6"
3939
Tables = "1.2"
4040
julia = "1.3"
4141

src/DiffinDiffsBase.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,10 @@ export cb,
109109
@specset,
110110

111111
CheckData,
112-
GroupTerms,
112+
GroupTreatintterms,
113+
GroupXterms,
113114
CheckVars,
115+
GroupSample,
114116
MakeWeights,
115117

116118
DiffinDiffsEstimator,
@@ -135,7 +137,12 @@ export cb,
135137
TransformedDIDResult,
136138
TransSubDIDResult,
137139
lincom,
138-
rescale
140+
rescale,
141+
ExportFormat,
142+
StataPostHDF,
143+
getexportformat,
144+
setexportformat!,
145+
post!
139146

140147
include("tables.jl")
141148
include("utils.jl")

src/did.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,3 +780,67 @@ function rescale(r::AbstractDIDResult, by::Pair, subset)
780780
tinds = treatindex(ntreatcoef(r), inds)
781781
return rescale(r, apply(view(treatcells(r), tinds), by), inds)
782782
end
783+
784+
"""
785+
ExportFormat
786+
787+
Supertype for all types representing the format for exporting an [`AbstractDIDResult`](@ref).
788+
"""
789+
abstract type ExportFormat end
790+
791+
"""
792+
StataPostHDF <: ExportFormat
793+
794+
Export an [`AbstractDIDResult`](@ref) for Stata module
795+
[`posthdf`](https://github.com/junyuan-chen/posthdf).
796+
"""
797+
struct StataPostHDF <: ExportFormat end
798+
799+
const DefaultExportFormat = ExportFormat[StataPostHDF()]
800+
801+
"""
802+
getexportformat()
803+
804+
Return the default [`ExportFormat`](@ref) for [`post!`](@ref).
805+
"""
806+
getexportformat() = DefaultExportFormat[1]
807+
808+
"""
809+
setexportformat!(format::ExportFormat)
810+
811+
Set the default [`ExportFormat`](@ref) for [`post!`](@ref).
812+
"""
813+
setexportformat!(format::ExportFormat) = (DefaultExportFormat[1] = format)
814+
815+
"""
816+
post!(f, r::AbstractDIDResult; kwargs...)
817+
818+
Export result `r` in a default [`ExportFormat`](@ref).
819+
820+
The default format can be retrieved via [`getexportformat`](@ref)
821+
and modified via [`setexportformat!`](@ref).
822+
"""
823+
post!(f, r::AbstractDIDResult; kwargs...) =
824+
post!(f, getexportformat(), r; kwargs...)
825+
826+
"""
827+
post!(f, ::StataPostHDF, r::AbstractDIDResult; model="DiffinDiffsBase.AbstractDIDResult")
828+
829+
Export result `r` for Stata module
830+
[`posthdf`](https://github.com/junyuan-chen/posthdf).
831+
A subset of field values from `r` are placed in `f` by setting key-value pairs,
832+
where `f` can be either an `HDF5.Group` or any object that can be indexed by strings.
833+
"""
834+
function post!(f, ::StataPostHDF, r::AbstractDIDResult;
835+
model::String="DiffinDiffsBase.AbstractDIDResult")
836+
f["model"] = model
837+
f["b"] = coef(r)
838+
f["V"] = vcov(r)
839+
f["vce"] = repr(vce(r))
840+
f["N"] = nobs(r)
841+
f["depvar"] = string(outcomename(r))
842+
f["coefnames"] = convert(AbstractVector{String}, coefnames(r))
843+
f["weights"] = (w = weights(r); w === nothing ? "" : string(w))
844+
f["ntreatcoef"] = ntreatcoef(r)
845+
return f
846+
end

src/procedures.jl

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,48 @@ required(::CheckData) = (:data,)
4545
default(::CheckData) = (subset=nothing, weightname=nothing)
4646

4747
"""
48-
groupterms(args...)
48+
grouptreatintterms(treatintterms)
4949
50-
Return the arguments for allowing later comparisons based on object-id.
51-
See also [`GroupTerms`](@ref).
50+
Return the argument without change for allowing later comparisons based on object-id.
51+
See also [`GroupTreatintterms`](@ref).
5252
"""
53-
groupterms(treatintterms::TermSet, xterms::TermSet) =
54-
(treatintterms = treatintterms, xterms = xterms)
53+
grouptreatintterms(treatintterms::TermSet) = (treatintterms=treatintterms,)
5554

5655
"""
57-
GroupTerms <: StatsStep
56+
GroupTreatintterms <: StatsStep
5857
59-
Call [`DiffinDiffsBase.groupterms`](@ref)
60-
to obtain one of the instances of `treatintterms` and `xterms`
61-
that have been grouped by `==`
58+
Call [`DiffinDiffsBase.grouptreatintterms`](@ref)
59+
to obtain one of the instances of `treatintterms`
60+
that have been grouped by equality (`hash`)
6261
for allowing later comparisons based on object-id.
6362
6463
This step is only useful when working with [`@specset`](@ref) and [`proceed`](@ref).
6564
"""
66-
const GroupTerms = StatsStep{:GroupTerms, typeof(groupterms), false}
65+
const GroupTreatintterms = StatsStep{:GroupTreatintterms, typeof(grouptreatintterms), false}
6766

68-
required(::GroupTerms) = (:treatintterms, :xterms)
67+
default(::GroupTreatintterms) = (treatintterms=TermSet(),)
68+
69+
"""
70+
groupxterms(xterms)
71+
72+
Return the argument without change for allowing later comparisons based on object-id.
73+
See also [`GroupXterms`](@ref).
74+
"""
75+
groupxterms(xterms::TermSet) = (xterms=xterms,)
76+
77+
"""
78+
GroupXterms <: StatsStep
79+
80+
Call [`DiffinDiffsBase.groupxterms`](@ref)
81+
to obtain one of the instances of `xterms`
82+
that have been grouped by equality (`hash`)
83+
for allowing later comparisons based on object-id.
84+
85+
This step is only useful when working with [`@specset`](@ref) and [`proceed`](@ref).
86+
"""
87+
const GroupXterms = StatsStep{:GroupXterms, typeof(groupxterms), false}
88+
89+
default(::GroupXterms) = (xterms=TermSet(),)
6990

7091
function _checkscales(col1::AbstractArray, col2::AbstractArray, treatvars::Vector{Symbol})
7192
if col1 isa ScaledArrOrSub || col2 isa ScaledArrOrSub
@@ -159,11 +180,12 @@ Exclude rows with missing data or violate the overlap condition
159180
and find rows with data from treated units.
160181
See also [`CheckVars`](@ref).
161182
"""
162-
function checkvars!(data, tr::AbstractTreatment, pr::AbstractParallel,
183+
function checkvars!(data, pr::AbstractParallel,
163184
yterm::AbstractTerm, treatname::Symbol, esample::BitVector, aux::BitVector,
164-
treatintterms::TermSet, xterms::TermSet)
185+
treatintterms::TermSet, xterms::TermSet, ::Type, @nospecialize(trvars::Tuple),
186+
tr::AbstractTreatment)
165187
# Do not check eltype of treatintterms
166-
treatvars = union([treatname], termvars(tr), termvars(pr))
188+
treatvars = union([treatname], trvars, termvars(pr))
167189
checktreatvars(tr, pr, treatvars, data)
168190

169191
allvars = union(treatvars, termvars(yterm), termvars(xterms))
@@ -202,9 +224,40 @@ Call [`DiffinDiffsBase.checkvars!`](@ref) to exclude invalid rows for relevant v
202224
"""
203225
const CheckVars = StatsStep{:CheckVars, typeof(checkvars!), true}
204226

205-
required(::CheckVars) = (:data, :tr, :pr, :yterm, :treatname, :esample, :aux)
206-
default(::CheckVars) = (treatintterms=TermSet(), xterms=TermSet())
207-
copyargs(::CheckVars) = (6,)
227+
required(::CheckVars) = (:data, :pr, :yterm, :treatname, :esample, :aux,
228+
:treatintterms, :xterms)
229+
transformed(::CheckVars, @nospecialize(nt::NamedTuple)) =
230+
(typeof(nt.tr), (termvars(nt.tr)...,))
231+
232+
combinedargs(step::CheckVars, allntargs) =
233+
combinedargs(step, allntargs, typeof(allntargs[1].tr))
234+
235+
combinedargs(::CheckVars, allntargs, ::Type{DynamicTreatment{SharpDesign}}) =
236+
(allntargs[1].tr,)
237+
238+
copyargs(::CheckVars) = (5,)
239+
240+
"""
241+
groupsample(esample)
242+
243+
Return the argument without change for allowing later comparisons based on object-id.
244+
See also [`GroupSample`](@ref).
245+
"""
246+
groupsample(esample::BitVector) = (esample=esample,)
247+
248+
"""
249+
GroupSample <: StatsStep
250+
251+
Call [`DiffinDiffsBase.groupsample`](@ref)
252+
to obtain one of the instances of `esample`
253+
that have been grouped by equality (`hash`)
254+
for allowing later comparisons based on object-id.
255+
256+
This step is only useful when working with [`@specset`](@ref) and [`proceed`](@ref).
257+
"""
258+
const GroupSample = StatsStep{:GroupSample, typeof(groupsample), false}
259+
260+
required(::GroupSample) = (:esample,)
208261

209262
"""
210263
makeweights(args...)

test/did.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,22 @@ end
545545
m = Diagonal(r.treatcells.rel[1:3])
546546
@test vcov(tr) == m * r.vcov[1:3,1:3] * m'
547547
end
548+
549+
@testset "post!" begin
550+
@test getexportformat() == DefaultExportFormat[1]
551+
setexportformat!(StataPostHDF())
552+
@test DefaultExportFormat[1] == StataPostHDF()
553+
554+
f = Dict{String,Any}()
555+
r = TestResult(2, 2)
556+
post!(f, r)
557+
@test f["model"] == "DiffinDiffsBase.AbstractDIDResult"
558+
@test f["b"] == coef(r)
559+
@test f["V"] == vcov(r)
560+
@test f["vce"] == "nothing"
561+
@test f["N"] == nobs(r)
562+
@test f["depvar"] == "y"
563+
@test f["coefnames"][1] == "rel: 1 & c: 1"
564+
@test f["weights"] == "w"
565+
@test f["ntreatcoef"] == 4
566+
end

test/procedures.jl

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,35 @@
4242
end
4343
end
4444

45-
@testset "GroupTerms" begin
46-
@testset "groupterms" begin
47-
nt = (treatintterms=TermSet(), xterms=TermSet(term(:x)))
48-
@test groupterms(nt...) == nt
45+
@testset "GroupTreatintterms" begin
46+
@testset "grouptreatintterms" begin
47+
nt = (treatintterms=TermSet(),)
48+
@test grouptreatintterms(nt...) == nt
4949
end
5050

5151
@testset "StatsStep" begin
52-
@test sprint(show, GroupTerms()) == "GroupTerms"
53-
@test sprint(show, MIME("text/plain"), GroupTerms()) ==
54-
"GroupTerms (StatsStep that calls DiffinDiffsBase.groupterms)"
55-
@test _byid(GroupTerms()) == false
56-
nt = (treatintterms=TermSet(), xterms=TermSet(term(:x)))
57-
@test GroupTerms()(nt) == nt
52+
@test sprint(show, GroupTreatintterms()) == "GroupTreatintterms"
53+
@test sprint(show, MIME("text/plain"), GroupTreatintterms()) ==
54+
"GroupTreatintterms (StatsStep that calls DiffinDiffsBase.grouptreatintterms)"
55+
@test _byid(GroupTreatintterms()) == false
56+
nt = (treatintterms=TermSet(),)
57+
@test GroupTreatintterms()() == nt
58+
end
59+
end
60+
61+
@testset "GroupXterms" begin
62+
@testset "groupxterms" begin
63+
nt = (xterms=TermSet(term(:x)),)
64+
@test groupxterms(nt...) == nt
65+
end
66+
67+
@testset "StatsStep" begin
68+
@test sprint(show, GroupXterms()) == "GroupXterms"
69+
@test sprint(show, MIME("text/plain"), GroupXterms()) ==
70+
"GroupXterms (StatsStep that calls DiffinDiffsBase.groupxterms)"
71+
@test _byid(GroupXterms()) == false
72+
nt = (xterms=TermSet(),)
73+
@test GroupXterms()() == nt
5874
end
5975
end
6076

@@ -63,9 +79,11 @@ end
6379
hrs = exampledata("hrs")
6480
N = size(hrs,1)
6581
us = unspecifiedpr()
66-
nt = (data=hrs, tr=dynamic(:wave, -1), pr=us, yterm=term(:oop_spend),
82+
tr = dynamic(:wave, -1)
83+
nt = (data=hrs, pr=us, yterm=term(:oop_spend),
6784
treatname=:wave_hosp, esample=trues(N), aux=BitVector(undef, N),
68-
treatintterms=TermSet(), xterms=TermSet())
85+
treatintterms=TermSet(), xterms=TermSet(),
86+
tytr=typeof(tr), trvars=(termvars(tr)...,), tr=tr)
6987
@test checkvars!(nt...) == (esample=trues(N), tr_rows=trues(N))
7088

7189
nt = merge(nt, (pr=nevertreated(11),))
@@ -140,7 +158,8 @@ end
140158
df.wave_hosp = rotatingtime(rot, df.wave_hosp)
141159
df.wave = rotatingtime(rot, df.wave)
142160
e = rotatingtime((1,2), 11)
143-
nt = merge(nt, (data=df, tr=dynamic(:wave, -1), pr=nevertreated(e), treatintterms=TermSet(), xterms=TermSet(), esample=trues(N)))
161+
nt = merge(nt, (data=df, pr=nevertreated(e), treatintterms=TermSet(),
162+
xterms=TermSet(), esample=trues(N)))
144163
# Check RotatingTimeArray
145164
@test_throws ArgumentError checkvars!(nt...)
146165
df.wave_hosp = settime(hrs.wave_hosp, rotation=rot)
@@ -198,13 +217,30 @@ end
198217
@test CheckVars()(nt) ==
199218
merge(nt, (esample=trues(N), tr_rows=hrs.wave_hosp.!=11))
200219
nt = (data=hrs, tr=dynamic(:wave, -1), pr=nevertreated(11), yterm=term(:oop_spend),
201-
treatname=:wave_hosp, esample=trues(N), aux=BitVector(undef, N))
220+
treatname=:wave_hosp, esample=trues(N), aux=BitVector(undef, N),
221+
treatintterms=TermSet(), xterms=TermSet())
202222
@test CheckVars()(nt) ==
203223
merge(nt, (esample=trues(N), tr_rows=hrs.wave_hosp.!=11))
204224
@test_throws ErrorException CheckVars()()
205225
end
206226
end
207227

228+
@testset "GroupSample" begin
229+
@testset "groupsample" begin
230+
nt = (esample=trues(3),)
231+
@test groupsample(nt...) == nt
232+
end
233+
234+
@testset "StatsStep" begin
235+
@test sprint(show, GroupSample()) == "GroupSample"
236+
@test sprint(show, MIME("text/plain"), GroupSample()) ==
237+
"GroupSample (StatsStep that calls DiffinDiffsBase.groupsample)"
238+
@test _byid(GroupSample()) == false
239+
nt = (esample=trues(3),)
240+
@test GroupSample()(nt) == nt
241+
end
242+
end
243+
208244
@testset "MakeWeights" begin
209245
@testset "makeweights" begin
210246
hrs = exampledata("hrs")

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ using Dates: Date, Year
77
using DiffinDiffsBase: @fieldequal, unpack, @unpack, checktable, hastreat, parse_treat,
88
isintercept, isomitsintercept, parse_intercept!,
99
ncol, nrow, _mult!,
10-
_f, _byid, groupargs, copyargs, pool, checkdata!, groupterms, checkvars!, makeweights,
10+
_f, _byid, groupargs, copyargs, pool,
11+
checkdata!, grouptreatintterms, groupxterms, checkvars!, groupsample, makeweights,
1112
_totermset!, parse_didargs!, _treatnames, _parse_bycells!, _parse_subset, _nselected,
12-
treatindex, checktreatindex
13+
treatindex, checktreatindex, DefaultExportFormat
1314
using LinearAlgebra: Diagonal
1415
using Missings: allowmissing, disallowmissing
1516
using PooledArrays: PooledArray

0 commit comments

Comments
 (0)