Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 418d540

Browse files
authored
Improve treatment and parallel types and add pause option to proceed (#12)
1 parent 120963f commit 418d540

File tree

9 files changed

+90
-63
lines changed

9 files changed

+90
-63
lines changed

src/DiffinDiffsBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import Base: eltype, firstindex, lastindex, getindex, iterate, length, sym_in
1414
import StatsBase: coef, vcov, responsename, coefnames, weights, nobs, dof_residual
1515
import StatsModels: termvars, hasintercept, omitsintercept
1616

17+
const TimeType = Int
18+
1719
# Reexport objects from StatsBase
1820
export coef, vcov, responsename, coefnames, weights, nobs, dof_residual
1921

src/StatsProcedures.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ See also [`@specset`](@ref).
458458
- `verbose::Bool=false`: print the name of each step when it is called.
459459
- `keep=nothing`: names (of type `Symbol`) of additional objects to be returned.
460460
- `keepall::Bool=false`: return all objects generated by procedures along with arguments from the [`StatsSpec`](@ref)s.
461+
- `pause::Int=0`: break the iteration over [`StatsStep`](@ref)s after finishing the specified number of steps (for debugging).
461462
462463
# Returns
463464
- `Vector`: results for each specification in the same order of `sps`.
@@ -469,7 +470,7 @@ When either `keep` or `keepall` is specified,
469470
a `NamedTuple` with additional objects is formed for each [`StatsSpec`](@ref).
470471
"""
471472
function proceed(sps::AbstractVector{<:StatsSpec};
472-
verbose::Bool=false, keep=nothing, keepall::Bool=false)
473+
verbose::Bool=false, keep=nothing, keepall::Bool=false, pause::Int=0)
473474
nsps = length(sps)
474475
nsps == 0 && throw(ArgumentError("expect a nonempty vector"))
475476

@@ -486,6 +487,8 @@ function proceed(sps::AbstractVector{<:StatsSpec};
486487
steps = pool((p for p in keys(gids))...)
487488
tasks = IdDict{Tuple, Vector{Int}}()
488489
ntask_total = 0
490+
step_count = 0
491+
paused = false
489492
@inbounds for step in steps
490493
ntask = 0
491494
verbose && print("Running ", step, "...")
@@ -521,14 +524,18 @@ function proceed(sps::AbstractVector{<:StatsSpec};
521524
nprocs = length(_sharedby(step))
522525
verbose && print("Finished ", ntask, ntask > 1 ? " tasks" : " task", " for ",
523526
nprocs, nprocs > 1 ? " procedures\n" : " procedure\n")
527+
step_count += 1
528+
step_count === pause && (paused = true) && break
524529
end
525530

526531
nprocs = length(steps.procs)
527532
verbose && printstyled("All steps finished (", ntask_total,
528533
ntask_total > 1 ? " tasks" : " task", " for ", nprocs,
529534
nprocs > 1 ? " procedures)\n" : " procedure)\n", bold=true, color=:green)
530-
@inbounds for i in 1:nsps
531-
traces[i] = result(_procedure(sps[i]), traces[i])
535+
if !paused
536+
@inbounds for i in 1:nsps
537+
traces[i] = result(_procedure(sps[i]), traces[i])
538+
end
532539
end
533540

534541
if keepall
@@ -635,6 +642,7 @@ The following options are available for altering the behavior of `@specset`:
635642
- `verbose::Bool=false`: print the name of each step when it is called.
636643
- `keep=nothing`: names (of type `Symbol`) of additional objects to be returned.
637644
- `keepall::Bool=false`: return all objects generated by procedures along with arguments from the [`StatsSpec`](@ref)s.
645+
- `pause::Int=0`: break the iteration over [`StatsStep`](@ref)s after finishing the specified number of steps (for debugging).
638646
"""
639647
macro specset(args...)
640648
nargs = length(args)

src/did.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ The following options are available for altering the behavior of `@did`:
134134
- `verbose::Bool=false`: print the name of each step when it is called.
135135
- `keep=nothing`: names (of type `Symbol`) of additional objects to be returned.
136136
- `keepall::Bool=false`: return all objects generated by procedures along with arguments from the [`StatsSpec`](@ref)s.
137+
- `pause::Int=0`: break the iteration over [`StatsStep`](@ref)s after finishing the specified number of steps (for debugging).
137138
"""
138139
macro did(args...)
139140
nargs = length(args)

src/parallels.jl

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,16 @@ and a group that did not receive any treatment in any sample period.
9393
See also [`nevertreated`](@ref).
9494
9595
# Fields
96-
- `e::Vector{Int}`: group indices for units that did not receive any treatment.
96+
- `e::Tuple{Vararg{TimeType}}`: group indices for units that did not receive any treatment.
9797
- `c::C`: an instance of [`ParallelCondition`](@ref).
9898
- `s::S`: an instance of [`ParallelStrength`](@ref).
9999
"""
100100
struct NeverTreatedParallel{C,S} <: TrendParallel{C,S}
101-
e::Vector{Int}
101+
e::Tuple{Vararg{TimeType}}
102102
c::C
103103
s::S
104104
function NeverTreatedParallel(e, c::ParallelCondition, s::ParallelStrength)
105-
e = unique!(sort!([e...]))
105+
e = (unique!(sort!([e...]))...,)
106106
isempty(e) && error("field `e` cannot be empty")
107107
return new{typeof(c),typeof(s)}(e, c, s)
108108
end
@@ -111,11 +111,12 @@ end
111111
istreated(pr::NeverTreatedParallel, x) = !(x in pr.e)
112112

113113
show(io::IO, pr::NeverTreatedParallel) =
114-
print(IOContext(io, :compact=>true), "NeverTreated{", pr.c, ",", pr.s, "}(", pr.e, ")")
114+
print(IOContext(io, :compact=>true), "NeverTreated{", pr.c, ",", pr.s, "}",
115+
length(pr.e)==1 ? string("(", pr.e[1], ")") : pr.e)
115116

116117
function show(io::IO, ::MIME"text/plain", pr::NeverTreatedParallel)
117118
println(io, pr.s, " trends with any never-treated group:")
118-
print(io, " Never-treated groups: ", pr.e)
119+
print(io, " Never-treated groups: ", join(string.(pr.e), ", "))
119120
pr.c isa Unconditional || print(io, "\n ", pr.c)
120121
end
121122

@@ -133,14 +134,14 @@ a wrapper method of `nevertreated` calls this method.
133134
```jldoctest; setup = :(using DiffinDiffsBase)
134135
julia> nevertreated(-1)
135136
Parallel trends with any never-treated group:
136-
Never-treated groups: [-1]
137+
Never-treated groups: -1
137138
138139
julia> typeof(nevertreated(-1))
139-
NeverTreatedParallel{Unconditional,Exact,Tuple{Int64}}
140+
NeverTreatedParallel{Unconditional,Exact}
140141
141142
julia> nevertreated([-1, 0])
142143
Parallel trends with any never-treated group:
143-
Never-treated groups: [-1, 0]
144+
Never-treated groups: -1, 0
144145
145146
julia> nevertreated([-1, 0]) == nevertreated(-1:0) == nevertreated(Set([-1, 0]))
146147
true
@@ -167,8 +168,8 @@ and any group that received the treatment relatively late (or never receved).
167168
See also [`notyettreated`](@ref).
168169
169170
# Fields
170-
- `e::Vector{Int}`: group indices for units that received the treatment relatively late.
171-
- `ecut::Vector{Int}`: user-specified period(s) when units in a group in `e` started to receive treatment.
171+
- `e::Tuple{Vararg{TimeType}}`: group indices for units that received the treatment relatively late.
172+
- `ecut::Tuple{Vararg{TimeType}}`: user-specified period(s) when units in a group in `e` started to receive treatment.
172173
- `c::C`: an instance of [`ParallelCondition`](@ref).
173174
- `s::S`: an instance of [`ParallelStrength`](@ref).
174175
@@ -178,14 +179,14 @@ See also [`notyettreated`](@ref).
178179
- the sample has a rotating panel structure with periods overlapping with some others.
179180
"""
180181
struct NotYetTreatedParallel{C,S} <: TrendParallel{C,S}
181-
e::Vector{Int}
182-
ecut::Vector{Int}
182+
e::Tuple{Vararg{TimeType}}
183+
ecut::Tuple{Vararg{TimeType}}
183184
c::C
184185
s::S
185186
function NotYetTreatedParallel(e, ecut, c::ParallelCondition, s::ParallelStrength)
186-
e = unique!(sort!([e...]))
187+
e = (unique!(sort!([e...]))...,)
187188
isempty(e) && error("field `e` cannot be empty")
188-
ecut = unique!(sort!([ecut...]))
189+
ecut = (unique!(sort!([ecut...]))...,)
189190
isempty(ecut) && error("field `ecut` cannot be empty")
190191
return new{typeof(c),typeof(s)}(e, ecut, c, s)
191192
end
@@ -194,12 +195,13 @@ end
194195
istreated(pr::NotYetTreatedParallel, x) = !(x in pr.e)
195196

196197
show(io::IO, pr::NotYetTreatedParallel) =
197-
print(IOContext(io, :compact=>true), "NotYetTreated{", pr.c, ",", pr.s, "}(", pr.e, ")")
198+
print(IOContext(io, :compact=>true), "NotYetTreated{", pr.c, ",", pr.s, "}",
199+
length(pr.e)==1 ? string("(", pr.e[1], ")") : pr.e)
198200

199201
function show(io::IO, ::MIME"text/plain", pr::NotYetTreatedParallel)
200202
println(io, pr.s, " trends with any not-yet-treated group:")
201-
println(io, " Not-yet-treated groups: ", pr.e)
202-
print(io, " Treated since: ", pr.ecut)
203+
println(io, " Not-yet-treated groups: ", join(string.(pr.e), ", "))
204+
print(io, " Treated since: ", join(string.(pr.ecut), ", "))
203205
pr.c isa Unconditional || print(io, "\n ", pr.c)
204206
end
205207

@@ -218,21 +220,21 @@ a wrapper method of `notyettreated` calls this method.
218220
```jldoctest; setup = :(using DiffinDiffsBase)
219221
julia> notyettreated(5)
220222
Parallel trends with any not-yet-treated group:
221-
Not-yet-treated groups: [5]
222-
Treated since: [5]
223+
Not-yet-treated groups: 5
224+
Treated since: 5
223225
224226
julia> typeof(notyettreated(5))
225-
NotYetTreatedParallel{Unconditional,Exact,Tuple{Int64},Tuple{Int64}}
227+
NotYetTreatedParallel{Unconditional,Exact}
226228
227229
julia> notyettreated([-1, 5, 6], 5)
228230
Parallel trends with any not-yet-treated group:
229-
Not-yet-treated groups: [-1, 5, 6]
230-
Treated since: [5]
231+
Not-yet-treated groups: -1, 5, 6
232+
Treated since: 5
231233
232234
julia> notyettreated([4, 5, 6], [4, 5, 6])
233235
Parallel trends with any not-yet-treated group:
234-
Not-yet-treated groups: [4, 5, 6]
235-
Treated since: [4, 5, 6]
236+
Not-yet-treated groups: 4, 5, 6
237+
Treated since: 4, 5, 6
236238
```
237239
"""
238240
notyettreated(e, ecut, c::ParallelCondition, s::ParallelStrength) =

src/treatments.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,28 @@ See also [`dynamic`](@ref).
4141
4242
# Fields
4343
- `time::Symbol`: column name of data representing calendar time.
44-
- `exc::Vector{Int}`: excluded relative time.
44+
- `exc::Tuple{Vararg{Int}}`: excluded relative time.
4545
- `s::S`: an instance of [`TreatmentSharpness`](@ref).
4646
"""
4747
struct DynamicTreatment{S<:TreatmentSharpness} <: AbstractTreatment
4848
time::Symbol
49-
exc::Vector{Int}
49+
exc::Tuple{Vararg{Int}}
5050
s::S
5151
function DynamicTreatment(time::Symbol, exc, s::TreatmentSharpness)
52-
exc = exc !== nothing ? unique!(sort!([exc...])) : Int[]
52+
exc = exc !== nothing ? (unique!(sort!([exc...]))...,) : ()
5353
return new{typeof(s)}(time, exc, s)
5454
end
5555
end
5656

5757
show(io::IO, tr::DynamicTreatment) =
58-
print(IOContext(io, :compact=>true), "Dynamic{", tr.s, "}(",
59-
isempty(tr.exc) ? "none" : tr.exc, ")")
58+
print(IOContext(io, :compact=>true), "Dynamic{", tr.s, "}",
59+
length(tr.exc)==1 ? string("(", tr.exc[1], ")") : tr.exc)
6060

6161
function show(io::IO, ::MIME"text/plain", tr::DynamicTreatment)
6262
println(io, tr.s, " dynamic treatment:")
6363
println(io, " column name of time variable: ", tr.time)
64-
print(io, " excluded relative time: ", isempty(tr.exc) ? "none" : tr.exc)
64+
print(io, " excluded relative time: ",
65+
isempty(tr.exc) ? "none" : join(string.(tr.exc), ", "))
6566
end
6667

6768
"""
@@ -77,20 +78,20 @@ a wrapper method of `dynamic` calls this method.
7778
julia> dynamic(:month, -1)
7879
Sharp dynamic treatment:
7980
column name of time variable: month
80-
excluded relative time: [-1]
81+
excluded relative time: -1
8182
8283
julia> typeof(dynamic(:month, -1))
83-
DynamicTreatment{SharpDesign,Tuple{Int64}}
84+
DynamicTreatment{SharpDesign}
8485
8586
julia> dynamic(:month, -3:-1)
8687
Sharp dynamic treatment:
8788
column name of time variable: month
88-
excluded relative time: [-3, -2, -1]
89+
excluded relative time: -3, -2, -1
8990
9091
julia> dynamic(:month, [-2,-1], sharp())
9192
Sharp dynamic treatment:
9293
column name of time variable: month
93-
excluded relative time: [-2, -1]
94+
excluded relative time: -2, -1
9495
```
9596
"""
9697
dynamic(time::Symbol, exc, s::TreatmentSharpness=sharp()) =

test/StatsProcedures.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ end
190190
@test p6 == PooledStatsProcedure(ps, shared)
191191
@test length(p6) == 5
192192

193+
ps = (rp, cp)
194+
shared = [SharedStatsStep(rp[1], 1), SharedStatsStep(rp[2], 1),
195+
SharedStatsStep(rp[3], 1), SharedStatsStep(cp[1], 2), SharedStatsStep(cp[2], 2)]
196+
p7 = pool(rp, cp)
197+
@test p7 == PooledStatsProcedure(ps, shared)
198+
@test length(p7) == 5
199+
193200
@test sprint(show, p1) == "PooledStatsProcedure"
194201
@test sprint(show, MIME("text/plain"), p1) == """
195202
PooledStatsProcedure with 3 steps from 1 procedure:
@@ -305,6 +312,8 @@ testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a
305312
@test proceed([s10], keepall=true) == NamedTuple[NamedTuple()]
306313
@test proceed([s10], keep=:result) == NamedTuple[NamedTuple()]
307314

315+
@test proceed([s1], pause=1) == ["b"]
316+
308317
@test_throws ArgumentError proceed(StatsSpec[])
309318
end
310319

@@ -354,6 +363,10 @@ end
354363
r = @specset [verbose keep=[:a]] a=a begin
355364
StatsSpec(testformatter(testparser(RP; b="b"))...) end
356365
@test r == [(a="a0", result="a0a0b")]
366+
367+
r = @specset [verbose pause=1] a=a begin
368+
StatsSpec(testformatter(testparser(RP; b="b"))...) end
369+
@test r == ["b"]
357370

358371
s0 = @specset [noproceed] for i in 1:3
359372
a = "a"*string(i)

test/did.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,20 @@ end
106106
@test sprint(show, sp) == "unnamed"
107107
@test sprint(show, MIME("text/plain"), sp) == """
108108
unnamed (StatsSpec for TestDID):
109-
Dynamic{S}([-1])
110-
NeverTreated{U,P}([-1])"""
109+
Dynamic{S}(-1)
110+
NeverTreated{U,P}(-1)"""
111111

112112
sp = StatsSpec("", TestDID, (tr=dynamic(:time,-1),))
113113
@test sprint(show, sp) == "unnamed"
114114
@test sprint(show, MIME("text/plain"), sp) == """
115115
unnamed (StatsSpec for TestDID):
116-
Dynamic{S}([-1])"""
116+
Dynamic{S}(-1)"""
117117

118-
sp = StatsSpec("name", TestDID, (pr=nevertreated(-1),))
118+
sp = StatsSpec("name", TestDID, (pr=notyettreated(-1),))
119119
@test sprint(show, sp) == "name"
120120
@test sprint(show, MIME("text/plain"), sp) == """
121121
name (StatsSpec for TestDID):
122-
NeverTreated{U,P}([-1])"""
122+
NotYetTreated{U,P}(-1)"""
123123
end
124124
end
125125

test/parallels.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ end
7272
end
7373

7474
@testset "show" begin
75-
@test sprint(show, nt0) == "NeverTreated{U,P}([0])"
75+
@test sprint(show, nt0) == "NeverTreated{U,P}(0)"
7676
@test sprint(show, MIME("text/plain"), nt0) == """
7777
Parallel trends with any never-treated group:
78-
Never-treated groups: [0]"""
78+
Never-treated groups: 0"""
7979

80-
@test sprint(show, nt1) == "NeverTreated{U,P}([0, 1])"
80+
@test sprint(show, nt1) == "NeverTreated{U,P}(0, 1)"
8181
@test sprint(show, MIME("text/plain"), nt1) == """
8282
Parallel trends with any never-treated group:
83-
Never-treated groups: [0, 1]"""
83+
Never-treated groups: 0, 1"""
8484
end
8585
end
8686

@@ -158,29 +158,29 @@ end
158158
end
159159

160160
@testset "show" begin
161-
@test sprint(show, ny0) == "NotYetTreated{U,P}([0])"
161+
@test sprint(show, ny0) == "NotYetTreated{U,P}(0)"
162162
@test sprint(show, MIME("text/plain"), ny0) == """
163163
Parallel trends with any not-yet-treated group:
164-
Not-yet-treated groups: [0]
165-
Treated since: [0]"""
164+
Not-yet-treated groups: 0
165+
Treated since: 0"""
166166

167-
@test sprint(show, ny1) == "NotYetTreated{U,P}([0, 1])"
167+
@test sprint(show, ny1) == "NotYetTreated{U,P}(0, 1)"
168168
@test sprint(show, MIME("text/plain"), ny1) == """
169169
Parallel trends with any not-yet-treated group:
170-
Not-yet-treated groups: [0, 1]
171-
Treated since: [0]"""
170+
Not-yet-treated groups: 0, 1
171+
Treated since: 0"""
172172

173-
@test sprint(show, ny2) == "NotYetTreated{U,P}([0, 1])"
173+
@test sprint(show, ny2) == "NotYetTreated{U,P}(0, 1)"
174174
@test sprint(show, MIME("text/plain"), ny2) == """
175175
Parallel trends with any not-yet-treated group:
176-
Not-yet-treated groups: [0, 1]
177-
Treated since: [0]"""
176+
Not-yet-treated groups: 0, 1
177+
Treated since: 0"""
178178

179-
@test sprint(show, ny3) == "NotYetTreated{U,P}([0, 1])"
179+
@test sprint(show, ny3) == "NotYetTreated{U,P}(0, 1)"
180180
@test sprint(show, MIME("text/plain"), ny3) == """
181181
Parallel trends with any not-yet-treated group:
182-
Not-yet-treated groups: [0, 1]
183-
Treated since: [0, 1]"""
182+
Not-yet-treated groups: 0, 1
183+
Treated since: 0, 1"""
184184
end
185185
end
186186

0 commit comments

Comments
 (0)