Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit b3c99b2

Browse files
authored
Improve StatsSpec and fix an issue with macro did (#10)
1 parent e653da5 commit b3c99b2

File tree

3 files changed

+65
-40
lines changed

3 files changed

+65
-40
lines changed

src/StatsProcedures.jl

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ See also [`required`](@ref) and [`default`](@ref).
5454
"""
5555
transformed(::StatsStep, ::NamedTuple) = ()
5656

57-
_get(a::NTuple{N,Symbol}, @nospecialize(nt::NamedTuple)) where N =
58-
map(s->getfield(nt, s), a)
59-
_get(a::NamedTuple{S1}, nt::NamedTuple{S2}) where {S1,S2} =
60-
map(s->getfield(ifelse(sym_in(s, S2), nt, a), s), S1)
57+
_get(nt::NamedTuple, args::Tuple{Vararg{Symbol}}) =
58+
map(s->getfield(nt, s), args)
59+
60+
# This method is useful for obtaining default values
61+
_get(nt::NamedTuple, args::NamedTuple) =
62+
map(s->getfield(ifelse(sym_in(s, keys(nt)), nt, args), s), keys(args))
6163

6264
"""
6365
groupargs(s::StatsStep, ntargs::NamedTuple)
@@ -75,7 +77,7 @@ Instead, one should define methods for
7577
See also [`combinedargs`](@ref).
7678
"""
7779
groupargs(s::StatsStep, @nospecialize(ntargs::NamedTuple)) =
78-
(_get(required(s), ntargs)..., _get(default(s), ntargs)...,
80+
(_get(ntargs, required(s))..., _get(ntargs, default(s))...,
7981
transformed(s, ntargs)...)
8082

8183
"""
@@ -162,6 +164,15 @@ function show(io::IO, ::MIME"text/plain", p::AbstractStatsProcedure{A,T}) where
162164
end
163165
end
164166

167+
"""
168+
_get_default(p::AbstractStatsProcedure, ntargs::NamedTuple)
169+
170+
Obtain default values for arguments of each [`StatsStep`](@ref) in procedure `p`
171+
and then merge any default value for arguments not found in `ntargs` into `ntargs`.
172+
"""
173+
_get_default(p::AbstractStatsProcedure, @nospecialize(ntargs::NamedTuple)) =
174+
merge((default(s) for s in p)..., ntargs)
175+
165176
"""
166177
SharedStatsStep
167178
@@ -336,19 +347,20 @@ function show(io::IO, ::MIME"text/plain", p::PooledStatsProcedure)
336347
end
337348

338349
"""
339-
StatsSpec{Alias, T<:AbstractStatsProcedure}
350+
StatsSpec{T<:AbstractStatsProcedure}
340351
341352
Record the specification for a statistical procedure of type `T`.
342353
343354
An instance of `StatsSpec` is callable and
344355
its fields provide all information necessary for conducting the procedure.
345-
An optional name for the specification can be attached as parameter `Alias`.
356+
An optional name for the specification can be specified.
346357
347358
# Fields
348-
- `args::NamedTuple`: arguments for the [`StatsStep`](@ref)s in `T`.
359+
- `name::String`: a name for the specification (takes `""` if not specified).
360+
- `args::NamedTuple`: arguments for the [`StatsStep`](@ref)s in `T` (default values are merged into `args` if not found in `args`).
349361
350362
# Methods
351-
(sp::StatsSpec{A,T})(; verbose::Bool=false, keep=nothing, keepall::Bool=false)
363+
(sp::StatsSpec{T})(; verbose::Bool=false, keep=nothing, keepall::Bool=false)
352364
353365
Execute the procedure of type `T` with the arguments specified in `args`.
354366
By default, a dedicated result object for `T` is returned if it is available.
@@ -359,37 +371,38 @@ Otherwise, the last value returned by the last [`StatsStep`](@ref) is returned.
359371
- `keep=nothing`: names (of type `Symbol`) of additional objects to be returned.
360372
- `keepall::Bool=false`: return all objects returned by each step.
361373
"""
362-
struct StatsSpec{Alias, T<:AbstractStatsProcedure}
374+
struct StatsSpec{T<:AbstractStatsProcedure}
375+
name::String
363376
args::NamedTuple
364377
StatsSpec(name::Union{Symbol,String},
365378
T::Type{<:AbstractStatsProcedure}, @nospecialize(args::NamedTuple)) =
366-
new{Symbol(name),T}(args)
379+
new{T}(string(name), args)
367380
end
368381

369382
"""
370-
==(x::StatsSpec{A1,T}, y::StatsSpec{A2,T})
383+
==(x::StatsSpec{T}, y::StatsSpec{T})
371384
372385
Test whether two instances of [`StatsSpec`](@ref)
373386
with the same parameter `T` also have the same field `args`.
374387
375388
See also [`≊`](@ref).
376389
"""
377-
==(x::StatsSpec{A1,T}, y::StatsSpec{A2,T}) where {A1,A2,T} = x.args == y.args
390+
==(x::StatsSpec{T}, y::StatsSpec{T}) where T = x.args == y.args
378391

379392
"""
380-
≊(x::StatsSpec{A1,T}, y::StatsSpec{A2,T})
393+
≊(x::StatsSpec{T}, y::StatsSpec{T})
381394
382395
Test whether two instances of [`StatsSpec`](@ref)
383396
with the same parameter `T` also have the field `args`
384397
containing the same sets of key-value pairs
385398
while ignoring the orders.
386399
"""
387-
(x::StatsSpec{A1,T}, y::StatsSpec{A2,T}) where {A1,A2,T} = x.args y.args
400+
(x::StatsSpec{T}, y::StatsSpec{T}) where T = x.args y.args
388401

389-
_procedure(::StatsSpec{A,T}) where {A,T} = T
402+
_procedure(::StatsSpec{T}) where T = T
390403

391-
function (sp::StatsSpec{A,T})(;
392-
verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T}
404+
function (sp::StatsSpec{T})(;
405+
verbose::Bool=false, keep=nothing, keepall::Bool=false) where T
393406
args = verbose ? merge(sp.args, (verbose=true,)) : sp.args
394407
ntall = result(T, foldl(|>, T(), init=args))
395408
if keepall
@@ -416,12 +429,12 @@ function (sp::StatsSpec{A,T})(;
416429
end
417430
end
418431

419-
show(io::IO, ::StatsSpec{A}) where {A} = print(io, A==Symbol("") ? "unnamed" : A)
432+
show(io::IO, sp::StatsSpec) = print(io, sp.name=="" ? "unnamed" : sp.name)
420433

421434
_show_args(::IO, ::StatsSpec) = nothing
422435

423-
function show(io::IO, ::MIME"text/plain", sp::StatsSpec{A,T}) where {A,T}
424-
print(io, A==Symbol("") ? "unnamed" : A, " (", typeof(sp).name.name,
436+
function show(io::IO, ::MIME"text/plain", sp::StatsSpec{T}) where T
437+
print(io, sp.name=="" ? "unnamed" : sp.name, " (", typeof(sp).name.name,
425438
" for ", T.parameters[1], ")")
426439
_show_args(io, sp)
427440
end
@@ -517,19 +530,18 @@ function proceed(sps::AbstractVector{<:StatsSpec};
517530
end
518531

519532
function _parse!(options::Expr, args)
520-
noproceed = false
521533
for arg in args
522534
# Assume a symbol means the kwarg takes value true
523535
if isa(arg, Symbol)
524536
if arg == :noproceed
525-
noproceed = true
537+
return true
526538
else
527539
key = Expr(:quote, arg)
528540
push!(options.args, Expr(:call, :(=>), key, true))
529541
end
530542
elseif isexpr(arg, :(=))
531543
if arg.args[1] == :noproceed
532-
noproceed = arg.args[2]
544+
return arg.args[2]
533545
else
534546
key = Expr(:quote, arg.args[1])
535547
push!(options.args, Expr(:call, :(=>), key, arg.args[2]))
@@ -538,18 +550,24 @@ function _parse!(options::Expr, args)
538550
throw(ArgumentError("unexpected option $arg"))
539551
end
540552
end
541-
return noproceed
553+
return false
542554
end
543555

544556
function _spec_walker1(x, parsers, formatters, ntargs_set)
545-
@capture(x, StatsSpec(formatter_(parser_(rawargs__))...)(;o__)) || return x
557+
@capture(x, spec_(formatter_(parser_(rawargs__))...)(;o__)) || return x
558+
# spec may be a GlobalRef
559+
name = spec isa Symbol ? spec : spec.name
560+
name == :StatsSpec || return x
546561
push!(parsers, parser)
547562
push!(formatters, formatter)
548563
return :(push!($ntargs_set, $parser($(rawargs...))))
549564
end
550565

551566
function _spec_walker2(x, parsers, formatters, ntargs_set)
552-
@capture(x, StatsSpec(formatter_(parser_(rawargs__))...)) || return x
567+
@capture(x, spec_(formatter_(parser_(rawargs__))...)) || return x
568+
# spec may be a GlobalRef
569+
name = spec isa Symbol ? spec : spec.name
570+
name == :StatsSpec || return x
553571
push!(parsers, parser)
554572
push!(formatters, formatter)
555573
return :(push!($ntargs_set, $parser($(rawargs...))))
@@ -613,7 +631,8 @@ macro specset(args...)
613631
isexpr(specs, :block, :for) ||
614632
throw(ArgumentError("last argument to @specset must be begin/end block or for loop"))
615633

616-
parsers, formatters, ntargs_set = Symbol[], Symbol[], NamedTuple[]
634+
# parser and formatter may be GlobalRef
635+
parsers, formatters, ntargs_set = [], [], NamedTuple[]
617636
walked = postwalk(x->_spec_walker1(x, parsers, formatters, ntargs_set), specs)
618637
walked = postwalk(x->_spec_walker2(x, parsers, formatters, ntargs_set), walked)
619638
nparser = length(unique!(parsers))

src/did.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ with the specified arguments.
8787
"""
8888
didspec(args...; kwargs...) = StatsSpec(_didargs(args...; kwargs...)...)
8989

90-
function _show_args(io::IO, sp::StatsSpec{A,<:DiffinDiffsEstimator}) where A
90+
function _show_args(io::IO, sp::StatsSpec{<:DiffinDiffsEstimator})
9191
if haskey(sp.args, :tr) || haskey(sp.args, :pr)
9292
print(io, ":")
9393
haskey(sp.args, :tr) && print(io, "\n ", sp.args[:tr])
@@ -139,7 +139,7 @@ macro did(args...)
139139
nargs = length(args)
140140
options = :(Dict{Symbol, Any}())
141141
noproceed = false
142-
didargs = []
142+
didargs = ()
143143
if nargs > 0
144144
if isexpr(args[1], :vect, :hcat, :vcat)
145145
noproceed = _parse!(options, args[1].args)
@@ -150,9 +150,9 @@ macro did(args...)
150150
end
151151
dargs, dkwargs = _args_kwargs(didargs)
152152
if noproceed
153-
return esc(:(StatsSpec(valid_didargs(parse_didargs($(dargs...); $(dkwargs...)))...)))
153+
return :(StatsSpec(valid_didargs(parse_didargs($(esc.(dargs)...); $(esc.(dkwargs)...)))...))
154154
else
155-
return esc(:(StatsSpec(valid_didargs(parse_didargs($(dargs...); $(dkwargs...)))...)(; $options...)))
155+
return :(StatsSpec(valid_didargs(parse_didargs($(esc.(dargs)...); $(esc.(dkwargs)...)))...)(; $(esc(options))...))
156156
end
157157
end
158158

test/StatsProcedures.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using DiffinDiffsBase: _f, _get, groupargs,
1+
using DiffinDiffsBase: _f, _get, groupargs, _get_default,
22
_sharedby, _show_args, _args_kwargs, _parse!, pool, proceed
33
import DiffinDiffsBase: required, default, transformed, combinedargs
44

@@ -26,20 +26,20 @@ combinedargs(::TestCombineStep, ntargs) = [nt.b for nt in ntargs]
2626

2727
testinvalidstep(a::String, b::String) = b, false
2828
const TestInvalidStep = StatsStep{:TestInvalidStep, typeof(testinvalidstep)}
29-
default(::TestInvalidStep) = (a="a",b="b")
29+
default(::TestInvalidStep) = (a="a", b="b")
3030

3131
const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testinvalidstep)}
3232

3333
@testset "StatsStep" begin
3434
@testset "_get" begin
35-
@test _get((), NamedTuple()) == ()
36-
@test _get((:a,), (a=1, b=2)) == (1,)
37-
@test_throws ErrorException _get((:a,), (b=2,))
35+
@test _get(NamedTuple(), ()) == ()
36+
@test _get((a=1, b=2), (:a,)) == (1,)
37+
@test_throws ErrorException _get((b=2,), (:a,))
3838

3939
@test _get(NamedTuple(), NamedTuple()) == ()
40-
@test _get((a=1,), (b=2,)) == (1,)
41-
@test _get((a=1,), (a=1, b=2)) == (1,)
42-
@test _get((a=1, b=2), (a=2,)) == (2, 2)
40+
@test _get((a=1,), (b=2,)) == (2,)
41+
@test _get((a=1,), (a=2, b=2)) == (1, 2)
42+
@test _get((a=1, b=2), (a=2,)) == (1,)
4343
end
4444

4545
@testset "args" begin
@@ -126,6 +126,12 @@ const ep = EP()
126126
NullProcedure (TestProcedure with 0 step)"""
127127
end
128128

129+
@testset "_get_default" begin
130+
@test _get_default(rp, NamedTuple()) == (a="a", b="b")
131+
@test _get_default(rp, (a="a1",)) == (a="a1", b="b")
132+
@test _get_default(rp, (c="c",)) == (a="a", b="b", c="c")
133+
end
134+
129135
@testset "SharedStatsStep" begin
130136
s1 = SharedStatsStep(TestRegStep(), 1)
131137
s2 = SharedStatsStep(TestRegStep(), [3,2])

0 commit comments

Comments
 (0)