From 23f79a8592f16a0275c2c6a959d17f7edae91090 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:24:50 -0400 Subject: [PATCH 01/17] first draft of core functionality --- src/choice_map2/array_interface.jl | 92 ++++++++++ src/choice_map2/choice_map.jl | 247 ++++++++++++++++++++++++++ src/choice_map2/dynamic_choice_map.jl | 153 ++++++++++++++++ src/choice_map2/nested_view.jl | 81 +++++++++ src/choice_map2/static_choice_map.jl | 131 ++++++++++++++ 5 files changed, 704 insertions(+) create mode 100644 src/choice_map2/array_interface.jl create mode 100644 src/choice_map2/choice_map.jl create mode 100644 src/choice_map2/dynamic_choice_map.jl create mode 100644 src/choice_map2/nested_view.jl create mode 100644 src/choice_map2/static_choice_map.jl diff --git a/src/choice_map2/array_interface.jl b/src/choice_map2/array_interface.jl new file mode 100644 index 000000000..f88c5b116 --- /dev/null +++ b/src/choice_map2/array_interface.jl @@ -0,0 +1,92 @@ +### interface for to_array and fill_array ### + +""" + arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} + +Populate an array with values of choices in the given assignment. + +It is an error if each of the values cannot be coerced into a value of the +given type. + +Implementation + +The default implmentation of `fill_array` will populate the array by sorting +the addresses of the choicemap using the `sort` function, then iterating over +each submap in this order and filling the array for that submap. + +To override the default implementation of `to_array`, +a concrete subtype `T <: ChoiceMap` should implement the following method: + + n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Populate `arr` with values from the given assignment, starting at `start_idx`, +and return the number of elements in `arr` that were populated. + +(This is for performance; it is more efficient to fill in values in a preallocated array +by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) +""" +function to_array(choices::ChoiceMap, ::Type{T}) where {T} + arr = Vector{T}(undef, 32) + n = _fill_array!(choices, arr, 1) + @assert n <= length(arr) + resize!(arr, n) + arr +end + +function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} + if length(arr) <: start_idx + resize!(arr, 2 * start_idx) + end + arr[start_idx] = get_value(c) + 1 +end + +# default _fill_array! implementation +function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + key_to_submap = collect(get_submaps_shallow(choices)) + sort!(key_to_submap, by = ((key, submap),) -> key) + idx = start_idx + for (key, submap) in key_to_submap + n_written = _fill_array!(submap, arr, idx) + idx += n_written + end + idx - start_idx +end + +""" + choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) + +Return an assignment with the same address structure as a prototype +assignment, but with values read off from the given array. + +It is an error if the number of choices in the prototype assignment +is not equal to the length the array. + +The order in which addresses are populated with values from the array +should match the order in which the array is populated with values +in a call to `to_array(proto_choices, T)`. By default, +this means sorting the top-level addresses for `proto_choices` +and then filling in the submaps depth-first in this order. + +# Implementation + +To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +the following method: + + (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Return an assignment with the same address structure as a prototype assignment, +but with values read off from `arr`, starting at position `start_idx`. Return the +number of elements read from `arr`. +""" +function from_array(proto_choices::ChoiceMap, arr::Vector) + (n, choices) = _from_array(proto_choices, arr, 1) + if n != length(arr) + error("Dimension mismatch: $n, $(length(arr))") + end + choices +end + +function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) + ValueChoiceMap(arr[start_idx]) +end \ No newline at end of file diff --git a/src/choice_map2/choice_map.jl b/src/choice_map2/choice_map.jl new file mode 100644 index 000000000..d7e7101fe --- /dev/null +++ b/src/choice_map2/choice_map.jl @@ -0,0 +1,247 @@ +######################### +# choice map interface # +######################### + +""" + get_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for each top-level address associated with `choices`. +(This includes `ValueChoiceMap`s.) +""" +function get_submaps_shallow end + +""" + get_submap(choices::ChoiceMap, addr) + +Return the submap at the given address, or `EmptyChoiceMap` +if there is no submap at the given address. +""" +function get_submap end + +# provide _get_submap so when users overwrite get_submap(choices::CustomChoiceMap, addr::Pair) +# they can just call _get_submap for convenience if they want +@inline function _get_submap(choices::ChoiceMap, addr::Pair) + (first, rest) = addr + submap = get_submap(choices, first) + get_submap(submap, rest) +end +@inline get_submap(choices::ChoiceMap, addr::Pair) = _get_submap(choices, addr) + +""" + has_value(choices::ChoiceMap) + +Returns true if `choices` is a `ValueChoiceMap`. + + has_value(choices::ChoiceMap, addr) + +Returns true if `choices` has a value stored at address `addr`. +""" +function has_value end +@inline has_value(::ChoiceMap) = false +@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) + +""" + get_value(choices::ChoiceMap) + +Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; +throws a `KeyError` if `choices` is not a `ValueChoiceMap`. + + get_value(choices::ChoiceMap, addr) +Returns the value stored in the submap with address `addr` or throws +a `KeyError` if no value exists at this address. + +A syntactic sugar is `Base.getindex`: + + value = choices[addr] +""" +function get_value end +get_value(::ChoiceMap) = throw(KeyError(nothing)) +get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) + +# get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow +""" + get_values_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, value)` +for each value stored at a top-level address in `choices`. +""" +function get_values_shallow(choices::ChoiceMap) + ( + (addr, get_value(submap)) + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) + ) +end + +""" + get_nonvalue_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for every top-level submap stored in `choices` which is +not a `ValueChoiceMap`. +""" +function get_nonvalue_submaps_shallow(choices::ChoiceMap) + filter(! ∘ has_value, get_submaps_shallow(choices)) +end + +# a choicemap is empty if it has no submaps and no value +Base.isempty(c::ChoiceMap) = isempty(get_submaps_shallow(c)) && !has_value(c) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + +""" + EmptyChoiceMap + +A choicemap with no submaps or values. +""" +struct EmptyChoiceMap <: ChoiceMap end + +@inline has_value(::EmptyChoiceMap, addr...) = false +@inline get_value(::EmptyChoiceMap) = throw(KeyError(nothing)) +@inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() +@inline Base.isempty(::EmptyChoiceMap) = true +@inline get_submaps_shallow(::EmptyChoiceMap) = () + +""" + ValueChoiceMap + +A leaf-node choicemap. Stores a single value. +""" +struct ValueChoiceMap{T} <: ChoiceMap + val::T +end + +@inline has_value(choices::ValueChoiceMap) = true +@inline get_value(choices::ValueChoiceMap) = choices.val +@inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() +@inline get_submaps_shallow(choices::ValueChoiceMap) = () +Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) + +""" + choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + +Merge two choice maps. + +It is an error if the choice maps both have values at the same address, or if +one choice map has a value at an address that is the prefix of the address of a +value in the other choice map. +""" +function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + choices = DynamicChoiceMap() + for (key, submap) in get_submaps_shallow(choices1) + set_submap!(choices, key, merge(submap, get_submap(choices2, key))) + end + choices +end +Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ChoiceMap) = c +Base.merge(c::ValueChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ValueChoiceMap) = c +Base.merge(::ValueChoiceMap, ::ChoiceMap) = error("ValueChoiceMaps cannot be merged") +Base.merge(::ChoiceMap, ::ValueChoiceMap) = error("ValueChoiceMaps cannot be merged") + +""" +Variadic merge of choice maps. +""" +function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) + reduce(Base.merge, choices_rest; init=choices1) +end + +function Base.:(==)(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if get_submap(b, addr) != submap + return false + end + end + return true +end + +function Base.isapprox(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if !isapprox(get_submap(b, addr), submap) + return false + end + end + return true +end + +""" + selected_choices = get_selected(choices::ChoiceMap, selection::Selection) + +Filter the choice map to include only choices in the given selection. + +Returns a new choice map. +""" +function get_selected( + choices::ChoiceMap, selection::Selection) + # TODO: return a `FilteringChoiceMap` which does this filtering lazily! + output = choicemap() + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) && addr in selection + output[addr] = get_value(submap) + else + subselection = selection[addr] + set_submap!(output, addr, get_selected(submap, subselection)) + end + end + output +end + +function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + indent_vert_str = join(indent_vert) + indent_vert_last_str = join(indent_vert_last) + indent_str = join(indent) + indent_last_str = join(indent_last) + key_and_values = collect(get_values_shallow(choices)) + key_and_submaps = collect(get_nonvalue_submaps_shallow(choices)) + n = length(key_and_values) + length(key_and_submaps) + cur = 1 + for (key, value) in key_and_values + # For strings, `print` is what we want; `Base.show` includes quote marks. + # https://docs.julialang.org/en/v1/base/io-network/#Base.print + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + for (key, submap) in key_and_submaps + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") + _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) + _show_pretty(io, choices, 0, ()) +end + +export ChoiceMap, ValueChoiceMap, EmptyChoiceMap +export get_submap, get_submaps_shallow +export get_value, has_value +export get_values_shallow, get_nonvalue_submaps_shallow + +include("array_interface.jl") +include("dynamic_choice_map.jl") +include("static_choice_map.jl") +include("nested_view.jl") \ No newline at end of file diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map2/dynamic_choice_map.jl new file mode 100644 index 000000000..a93a49021 --- /dev/null +++ b/src/choice_map2/dynamic_choice_map.jl @@ -0,0 +1,153 @@ +####################### +# dynamic assignment # +####################### + +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, <:ChoiceMap} +end + +""" + struct DynamicChoiceMap <: ChoiceMap .. end + +A mutable map from arbitrary hierarchical addresses to values. + + choices = DynamicChoiceMap() + +Construct an empty map. + + choices = DynamicChoiceMap(tuples...) + +Construct a map containing each of the given (addr, value) tuples. +""" +function DynamicChoiceMap() + DynamicChoiceMap(Dict()) +end + +function DynamicChoiceMap(tuples...) + choices = DynamicChoiceMap() + for (addr, value) in tuples + choices[addr] = value + end + choices +end + +""" + choices = DynamicChoiceMap(other::ChoiceMap) + +Copy a choice map, returning a mutable choice map. +""" +function DynamicChoiceMap(other::ChoiceMap) + choices = DynamicChoiceMap() + for (addr, submap) in get_submaps_shallow(other) + if choices isa ValueChoiceMap + set_submap!(choices, addr, submap) + else + set_submap!(choices, addr, DynamicChoiceMap(submap)) + end + end +end + +DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") + +""" + choices = choicemap() + +Construct an empty mutable choice map. +""" +function choicemap() + DynamicChoiceMap() +end + +""" + choices = choicemap(tuples...) + +Construct a mutable choice map initialized with given address, value tuples. +""" +function choicemap(tuples...) + DynamicChoiceMap(tuples...) +end + +get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +function get_submap(choices::DynamicChoiceMap, addr) + if haskey(choices.submaps, addr) + choices.submaps[addr] + else + EmptyChoiceMap() + end +end +get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) + +# mutation (not part of the assignment interface) + +""" + set_value!(choices::DynamicChoiceMap, addr, value) + +Set the given value for the given address. + +Will cause any previous value or sub-assignment at this address to be deleted. +It is an error if there is already a value present at some prefix of the given address. + +The following syntactic sugar is provided: + + choices[addr] = value +""" +function set_value!(choices::DynamicChoiceMap, addr, value) + delete!(choices.submaps, addr) + choices.submaps[addr] = ValueChoiceMap(value) +end + +function set_value!(choices::DynamicChoiceMap, addr::Pair, value) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_value!(choices.submaps[first], rest, value) +end + +""" + set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) + +Replace the sub-assignment rooted at the given address with the given sub-assignment. +Set the given value for the given address. + +Will cause any previous value or sub-assignment at the given address to be deleted. +It is an error if there is already a value present at some prefix of address. +""" +function set_submap!(choices::DynamicChoiceMap, addr, new_node) + delete!(choices.submaps, addr) + if !isempty(new_node) + choices.submaps[addr] = new_node + end +end + +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_submap!(choices.submaps[first], rest, new_node) +end + +Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) + +function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + choices = DynamicChoiceMap() + keys_sorted = sort(collect(keys(choices.submaps))) + idx = start_idx + for key in keys_sorted + (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) + idx += n_read + choices.submaps[key] = submap + end + (idx - start_idx, choices) +end + +export DynamicChoiceMap +export choicemap +export set_value! +export set_submap! \ No newline at end of file diff --git a/src/choice_map2/nested_view.jl b/src/choice_map2/nested_view.jl new file mode 100644 index 000000000..6693234fb --- /dev/null +++ b/src/choice_map2/nested_view.jl @@ -0,0 +1,81 @@ +############################################ +# Nested-dict–like accessor for choicemaps # +############################################ + +""" +Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than +the default syntax which looks like a flat dict of full keypaths. + +```jldoctest +julia> using Gen +julia> c = choicemap((:a, 1), + (:b => :c, 2)); +julia> cv = nested_view(c); +julia> c[:a] == cv[:a] +true +julia> c[:b => :c] == cv[:b][:c] +true +julia> length(cv) +2 +julia> length(cv[:b]) +1 +julia> sort(collect(keys(cv))) +[:a, :b] +julia> sort(collect(keys(cv[:b]))) +[:c] +``` +""" +struct ChoiceMapNestedView + choice_map::ChoiceMap +end + +ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) +ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") + +function Base.getindex(choices::ChoiceMapNestedView, addr) + ChoiceMapNestedView(get_submap(choices, addr)) +end + +function Base.iterate(c::ChoiceMapNestedView) + itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) + r = Base.iterate(itr) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +function Base.iterate(c::ChoiceMapNestedView, state) + (itr, st) = state + r = Base.iterate(itr, st) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +# TODO: Allow different implementations of this method depending on the +# concrete type of the `ChoiceMap`, so that an already-existing data structure +# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it +# exists. +Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) + +function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) + a.choice_map = b.choice_map +end +function Base.length(cv::ChoiceMapNestedView) + length(collect(get_submaps_shallow(cv.choice_map))) +end +function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) + Base.show(io, MIME"text/plain"(), c.choice_map) +end + +nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) + +# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling +# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and +# aux data together. + +export nested_view \ No newline at end of file diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map2/static_choice_map.jl new file mode 100644 index 000000000..e5e2d89e2 --- /dev/null +++ b/src/choice_map2/static_choice_map.jl @@ -0,0 +1,131 @@ +###################### +# static assignment # +###################### + +struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap + submaps::NamedTuple{Addrs, SubmapTypes} +end + +@inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) +@inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) + +# TODO: profiling! +@generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} + if A in Addrs + quote choices.submaps[A] end + else + quote EmptyChoiceMap() end + end +end + +static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) + +# convert a nonvalue choicemap all of whose top-level-addresses +# are symbols into a staticchoicemap at the top level +function StaticChoiceMap(other::ChoiceMap) + keys_and_nodes = get_submaps_shallow(other) + (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + StaticChoiceMap(NamedTuple{addrs}(submaps)) +end +StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") + +# TODO: deep conversion to static choicemap + +""" + choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + +Return an assignment that contains `choices1` as a sub-assignment under `key1` +and `choices2` as a sub-assignment under `key2`. +""" +function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + StaticChoiceMap(NamedTuple{(key1, key2)}((choices1, choices2))) +end + +""" + (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + +Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. + +It is an error if there are any submaps at keys other than `key1` and `key2`. +""" +function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + if length(collect(get_submaps_shallow(choices))) != 2 + error("Not a pair") + end + (get_submap(choices, key1), get_submap(choices, key2)) +end + +@generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, + choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} + + addr_to_type1 = Dict{Symbol, ::Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, ::Type{<:ChoiceMap}}() + for (i, addr) in enumerate(Addrs1) + addr_to_type1[addr] = SubmapTypes1.parameters[i] + end + for (i, addr) in enumerate(Addrs2) + addr_to_type2[addr] = SubmapTypes2.parameters[i] + end + + merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) + submap_exprs = [] + + for addr in merged_addrs + type1 = get(addr_to_type1, addr, EmptyChoiceMap) + type2 = get(addr_to_type2, addr, EmptyChoiceMap) + if ((type1 <: ValueChoiceMap && type2 != EmptyChoiceMap) + || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) + error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") + end + if type1 <: ValueChoiceMap + push!(submap_exprs, + quote choices1.submaps[$addr] end + ) + elseif type2 <: ValueChoiceMap + push!(submap_exprs, + quote choices2.submaps[$addr] end + ) + else + push!(submap_exprs, + quote merge(choices1.submaps[$addr], choices2.submaps[$addr]) end + ) + end + end + + quote + StaticChoiceMap{$merged_addrs}(submap_exprs...) + end +end + +@generated function _from_array!(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, + arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} + + perm = sortperm(Addrs) + sorted_addrs = Addrs[perm] + submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) + + exprs = [quote idx = start_idx end] + + for (idx, addr) in zip(perm, sorted_addrs) + submap_var_name = gensym(addr) + submap_var_names[idx] = submap_var_name + push!(exprs, + quote + (n_read, submap_var_name = _from_array(proto_choices.submaps[$addr], arr, idx) + idx += n_read + end + ) + end + + quote + $(exprs...) + submaps = NamedTuple{Addrs}(( $(submap_var_names...) )) + choices = StaticChoiceMap{Addrs, SubmapTypes}(submaps) + (idx - start_idx, choices) + end +end + +export StaticChoiceMap +export pair, unpair +export static_get_submap, static_get_value \ No newline at end of file From c9b1d4982e5f8f4903254adc982d4d5a216c5580 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:33:00 -0400 Subject: [PATCH 02/17] add support for address schemas --- src/choice_map2/choice_map.jl | 9 +++++++++ src/choice_map2/dynamic_choice_map.jl | 2 ++ src/choice_map2/static_choice_map.jl | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/src/choice_map2/choice_map.jl b/src/choice_map2/choice_map.jl index d7e7101fe..0ebb19f09 100644 --- a/src/choice_map2/choice_map.jl +++ b/src/choice_map2/choice_map.jl @@ -60,6 +60,13 @@ get_value(::ChoiceMap) = throw(KeyError(nothing)) get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) +""" +schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} + +Return the (top-level) address schema for the given choice map. +""" +function get_address_schema end + # get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow """ get_values_shallow(choices::ChoiceMap) @@ -108,6 +115,7 @@ struct EmptyChoiceMap <: ChoiceMap end @inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() @inline Base.isempty(::EmptyChoiceMap) = true @inline get_submaps_shallow(::EmptyChoiceMap) = () +@inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() """ ValueChoiceMap @@ -124,6 +132,7 @@ end @inline get_submaps_shallow(choices::ValueChoiceMap) = () Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() """ choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map2/dynamic_choice_map.jl index a93a49021..5dfca0b55 100644 --- a/src/choice_map2/dynamic_choice_map.jl +++ b/src/choice_map2/dynamic_choice_map.jl @@ -147,6 +147,8 @@ function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx: (idx - start_idx, choices) end +get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() + export DynamicChoiceMap export choicemap export set_value! diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map2/static_choice_map.jl index e5e2d89e2..3508762d7 100644 --- a/src/choice_map2/static_choice_map.jl +++ b/src/choice_map2/static_choice_map.jl @@ -126,6 +126,10 @@ end end end +function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} + StaticAddressSchema(set(Addrs)) +end + export StaticChoiceMap export pair, unpair export static_get_submap, static_get_value \ No newline at end of file From 1e0a58997d4717eb687aba6306b8b556108475bb Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sun, 17 May 2020 12:42:52 -0400 Subject: [PATCH 03/17] update choicemap docs --- docs/src/ref/choice_maps.md | 43 +++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index c065b1b32..8d3f4200e 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -8,13 +8,20 @@ ChoiceMap Choice maps are constructed by users to express observations and/or constraints on the traces of generative functions. Choice maps are also returned by certain Gen inference methods, and are used internally by various Gen inference methods. +A choicemap a tree, whose leaf nodes store a single value, and whose internal nodes provide addresses +for sub-choicemaps. Leaf nodes have type: +```@docs +ValueChoiceMap +``` + Choice maps provide the following methods: ```@docs +get_submap +get_submaps_shallow has_value get_value -get_submap get_values_shallow -get_submaps_shallow +get_nonvalue_submaps_shallow to_array from_array get_selected @@ -50,3 +57,35 @@ choicemap set_value! set_submap! ``` + +## Implementing custom choicemap types + +To implement a custom choicemap, one must implement +`get_submap` and `get_submaps_shallow`. +To avoid method ambiguity with the default +`get_submap(::ChoiceMap, ::Pair)`, one must implement both +```julia +get_submap(::CustomChoiceMap, addr) +``` +and +```julia +get_submap(::CustomChoiceMap, addr::Pair) +``` +To use the default implementation of `get_submap(_, ::Pair)`, +one may define +```julia +get_submap(c::CustomChoiceMap, addr::Pair) = _get_choicemap(c, addr) +``` + +Once `get_submap` and `get_submaps_shallow` are defined, default +implementations are provided for: +- `has_value` +- `get_value` +- `get_values_shallow` +- `get_nonvalue_submaps_shallow` +- `to_array` +- `get_selected` + +If one wishes to support `from_array`, they must implement +`_from_array`, as described in the documentation for +[`from_array`](@ref). \ No newline at end of file From 623bc8fcba7fc81eecb039a13d861baf06102d57 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 16:57:46 -0400 Subject: [PATCH 04/17] refactoring and tests --- docs/src/ref/choice_maps.md | 2 +- src/Gen.jl | 2 +- src/choice_map.jl | 1009 ----------------- .../array_interface.jl | 20 +- src/{choice_map2 => choice_map}/choice_map.jl | 57 +- .../dynamic_choice_map.jl | 20 +- .../nested_view.jl | 7 +- .../static_choice_map.jl | 52 +- src/dynamic/dynamic.jl | 31 +- src/dynamic/generate.jl | 2 +- src/dynamic/trace.jl | 36 +- src/dynamic/update.jl | 26 +- src/inference/kernel_dsl.jl | 11 +- src/modeling_library/call_at/call_at.jl | 5 +- src/modeling_library/choice_at/choice_at.jl | 4 +- src/modeling_library/recurse/recurse.jl | 18 +- src/modeling_library/vector.jl | 4 - src/static_ir/backprop.jl | 21 +- src/static_ir/trace.jl | 75 +- src/static_ir/update.jl | 7 +- test/assignment.jl | 224 ++-- test/benchmark.md | 21 + test/dynamic_dsl.jl | 14 +- test/modeling_library/call_at.jl | 26 +- test/modeling_library/choice_at.jl | 26 +- test/modeling_library/recurse.jl | 4 +- test/modeling_library/unfold.jl | 6 +- test/optional_args.jl | 2 +- test/runtests.jl | 2 +- test/static_ir/static_ir.jl | 10 +- test/tilde_sugar.jl | 2 +- 31 files changed, 347 insertions(+), 1399 deletions(-) delete mode 100644 src/choice_map.jl rename src/{choice_map2 => choice_map}/array_interface.jl (83%) rename src/{choice_map2 => choice_map}/choice_map.jl (82%) rename src/{choice_map2 => choice_map}/dynamic_choice_map.jl (93%) rename src/{choice_map2 => choice_map}/nested_view.jl (93%) rename src/{choice_map2 => choice_map}/static_choice_map.jl (68%) create mode 100644 test/benchmark.md diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 8d3f4200e..6c445df6f 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -30,7 +30,7 @@ Note that none of these methods mutate the choice map. Choice maps also implement: -- `Base.isempty`, which tests of there are no random choices in the choice map +- `Base.isempty`, which returns `false` if the choicemap contains no value or submaps, and `true` otherwise. - `Base.merge`, which takes two choice maps, and returns a new choice map containing all random choices in either choice map. It is an error if the choice maps both have values at the same address, or if one choice map has a value at an address that is the prefix of the address of a value in the other choice map. diff --git a/src/Gen.jl b/src/Gen.jl index 9f3da9e3a..fa2393596 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -37,7 +37,7 @@ include("backprop.jl") include("address.jl") # abstract and built-in concrete choice map data types -include("choice_map.jl") +include("choice_map/choice_map.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") diff --git a/src/choice_map.jl b/src/choice_map.jl deleted file mode 100644 index b7891b40a..000000000 --- a/src/choice_map.jl +++ /dev/null @@ -1,1009 +0,0 @@ -######################### -# choice map interface # -######################### - -""" - schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} - -Return the (top-level) address schema for the given choice map. -""" -function get_address_schema end - -""" - submap = get_submap(choices::ChoiceMap, addr) - -Return the sub-assignment containing all choices whose address is prefixed by addr. - -It is an error if the assignment contains a value at the given address. If -there are no choices whose address is prefixed by addr then return an -`EmptyChoiceMap`. -""" -function get_submap end - -""" - value = get_value(choices::ChoiceMap, addr) - -Return the value at the given address in the assignment, or throw a KeyError if -no value exists. A syntactic sugar is `Base.getindex`: - - value = choices[addr] -""" -function get_value end - -""" - key_submap_iterable = get_submaps_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, submap::ChoiceMap)` for each top-level key -that has a non-empty sub-assignment. -""" -function get_submaps_shallow end - -""" - has_value(choices::ChoiceMap, addr) - -Return true if there is a value at the given address. -""" -function has_value end - -""" - key_submap_iterable = get_values_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, value)` for each -top-level key associated with a value. -""" -function get_values_shallow end - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end - -""" - Base.isempty(choices::ChoiceMap) - -Return true if there are no values in the assignment. -""" -function Base.isempty(::ChoiceMap) - true -end - -@inline get_submap(choices::ChoiceMap, addr) = EmptyChoiceMap() -@inline has_value(choices::ChoiceMap, addr) = false -@inline get_value(choices::ChoiceMap, addr) = throw(KeyError(addr)) -@inline Base.getindex(choices::ChoiceMap, addr) = get_value(choices, addr) - -@inline function _has_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - has_value(submap, rest) -end - -@inline function _get_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_value(submap, rest) -end - -@inline function _get_submap(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_submap(submap, rest) -end - -function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - indent_vert_str = join(indent_vert) - indent_vert_last_str = join(indent_vert_last) - indent_str = join(indent) - indent_last_str = join(indent_last) - key_and_values = collect(get_values_shallow(choices)) - key_and_submaps = collect(get_submaps_shallow(choices)) - n = length(key_and_values) + length(key_and_submaps) - cur = 1 - for (key, value) in key_and_values - # For strings, `print` is what we want; `Base.show` includes quote marks. - # https://docs.julialang.org/en/v1/base/io-network/#Base.print - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - for (key, submap) in key_and_submaps - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") - _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) - _show_pretty(io, choices, 0, ()) -end - -# assignments that have static address schemas should also support faster -# accessors, which make the address explicit in the type (Val(:foo) instaed of -# :foo) -function static_get_value end -function static_get_submap end - -function _fill_array! end -function _from_array end - -""" - arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} - -Populate an array with values of choices in the given assignment. - -It is an error if each of the values cannot be coerced into a value of the -given type. - -# Implementation - -To support `to_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Populate `arr` with values from the given assignment, starting at `start_idx`, -and return the number of elements in `arr` that were populated. -""" -function to_array(choices::ChoiceMap, ::Type{T}) where {T} - arr = Vector{T}(undef, 32) - n = _fill_array!(choices, arr, 1) - @assert n <= length(arr) - resize!(arr, n) - arr -end - -function _fill_array!(value::T, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx - resize!(arr, 2 * start_idx) - end - arr[start_idx] = value - 1 -end - -function _fill_array!(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx + length(value) - resize!(arr, 2 * (start_idx + length(value))) - end - arr[start_idx:start_idx+length(value)-1] = value - length(value) -end - - -""" - choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) - -Return an assignment with the same address structure as a prototype -assignment, but with values read off from the given array. - -The order in which addresses are populated is determined by the prototype -assignment. It is an error if the number of choices in the prototype assignment -is not equal to the length the array. - -# Implementation - -To support `from_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - - (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Return an assignment with the same address structure as a prototype assignment, -but with values read off from `arr`, starting at position `start_idx`, and the -number of elements read from `arr`. -""" -function from_array(proto_choices::ChoiceMap, arr::Vector) - (n, choices) = _from_array(proto_choices, arr, 1) - if n != length(arr) - error("Dimension mismatch: $n, $(length(arr))") - end - choices -end - -function _from_array(::T, arr::Vector{T}, start_idx::Int) where {T} - (1, arr[start_idx]) -end - -function _from_array(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - n_read = length(value) - (n_read, arr[start_idx:start_idx+n_read-1]) -end - - -""" - choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - -Merge two choice maps. - -It is an error if the choice maps both have values at the same address, or if -one choice map has a value at an address that is the prefix of the address of a -value in the other choice map. -""" -function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - choices = DynamicChoiceMap() - for (key, value) in get_values_shallow(choices1) - choices.leaf_nodes[key] = value - end - for (key, node1) in get_submaps_shallow(choices1) - node2 = get_submap(choices2, key) - node = merge(node1, node2) - choices.internal_nodes[key] = node - end - for (key, value) in get_values_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has leaf node at $key") - end - if haskey(choices.internal_nodes, key) - error("choices1 has internal node at $key and choices2 has leaf node at $key") - end - choices.leaf_nodes[key] = value - end - for (key, node) in get_submaps_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has internal node at $key") - end - if !haskey(choices.internal_nodes, key) - # otherwise it should already be included - choices.internal_nodes[key] = node - end - end - return choices -end - -""" -Variadic merge of choice maps. -""" -function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) - reduce(Base.merge, choices_rest; init=choices1) -end - -function Base.:(==)(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || (get_value(b, addr) != value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || (get_value(a, addr) != value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if submap != get_submap(b, addr) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if submap != get_submap(a, addr) - return false - end - end - return true -end - -function Base.isapprox(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || !isapprox(get_value(b, addr), value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || !isapprox(get_value(a, addr), value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if !isapprox(submap, get_submap(b, addr)) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if !isapprox(submap, get_submap(a, addr)) - return false - end - end - return true -end - - -export ChoiceMap -export get_address_schema -export get_submap -export get_value -export has_value -export get_submaps_shallow -export get_values_shallow -export static_get_value -export static_get_submap -export to_array, from_array - - -###################### -# static assignment # -###################### - -struct StaticChoiceMap{R,S,T,U} <: ChoiceMap - leaf_nodes::NamedTuple{R,S} - internal_nodes::NamedTuple{T,U} - isempty::Bool -end - -function StaticChoiceMap{R,S,T,U}(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - -function StaticChoiceMap(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - - -# invariant: all internal_nodes are nonempty - -function get_address_schema(::Type{StaticChoiceMap{R,S,T,U}}) where {R,S,T,U} - keys = Set{Symbol}() - for (key, _) in zip(R, S.parameters) - push!(keys, key) - end - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - StaticAddressSchema(keys) -end - -function Base.isempty(choices::StaticChoiceMap) - choices.isempty -end - -get_values_shallow(choices::StaticChoiceMap) = pairs(choices.leaf_nodes) -get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.internal_nodes) -has_value(choices::StaticChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::StaticChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) - -# NOTE: there is no static_has_value because this is known from the static -# address schema - -## has_value ## - -function has_value(choices::StaticChoiceMap, key::Symbol) - haskey(choices.leaf_nodes, key) -end - -## get_submap ## - -function get_submap(choices::StaticChoiceMap, key::Symbol) - if haskey(choices.internal_nodes, key) - choices.internal_nodes[key] - elseif haskey(choices.leaf_nodes, key) - throw(KeyError(key)) - else - EmptyChoiceMap() - end -end - -function static_get_submap(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.internal_nodes[A] -end - -## get_value ## - -function get_value(choices::StaticChoiceMap, key::Symbol) - choices.leaf_nodes[key] -end - -function static_get_value(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.leaf_nodes[A] -end - -# convert from any other schema that has only Val{:foo} addresses -function StaticChoiceMap(other::ChoiceMap) - leaf_keys_and_nodes = collect(get_values_shallow(other)) - internal_keys_and_nodes = collect(get_submaps_shallow(other)) - if length(leaf_keys_and_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(leaf_keys_and_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) - end - if length(internal_keys_and_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(internal_keys_and_nodes...)) - else - (internal_keys, internal_nodes) = ((), ()) - end - StaticChoiceMap( - NamedTuple{leaf_keys}(leaf_nodes), - NamedTuple{internal_keys}(internal_nodes), - isempty(other)) -end - -""" - choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - -Return an assignment that contains `choices1` as a sub-assignment under `key1` -and `choices2` as a sub-assignment under `key2`. -""" -function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - StaticChoiceMap(NamedTuple(), NamedTuple{(key1,key2)}((choices1, choices2)), - isempty(choices1) && isempty(choices2)) -end - -""" - (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - -Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. - -It is an error if there are any top-level values, or any non-empty top-level -sub-assignments at keys other than `key1` and `key2`. -""" -function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - if !isempty(get_values_shallow(choices)) || length(collect(get_submaps_shallow(choices))) > 2 - error("Not a pair") - end - a = get_submap(choices, key1) - b = get_submap(choices, key2) - (a, b) -end - -# TODO make a generated function? -function _fill_array!(choices::StaticChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for value in choices.leaf_nodes - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for node in choices.internal_nodes - n_written = _fill_array!(node, arr, idx) - idx += n_written - end - idx - start_idx -end - -@generated function _from_array( - proto_choices::StaticChoiceMap{R,S,T,U}, arr::Vector{V}, start_idx::Int) where {R,S,T,U,V} - leaf_node_keys = proto_choices.parameters[1] - leaf_node_types = proto_choices.parameters[2].parameters - internal_node_keys = proto_choices.parameters[3] - internal_node_types = proto_choices.parameters[4].parameters - - exprs = [quote idx = start_idx end] - leaf_node_names = [] - internal_node_names = [] - - # leaf nodes - for key in leaf_node_keys - value = gensym() - push!(leaf_node_names, value) - push!(exprs, quote - (n_read, $value) = _from_array(proto_choices.leaf_nodes.$key, arr, idx) - idx += n_read - end) - end - - # internal nodes - for key in internal_node_keys - node = gensym() - push!(internal_node_names, node) - push!(exprs, quote - (n_read, $node) = _from_array(proto_choices.internal_nodes.$key, arr, idx) - idx += n_read - end) - end - - quote - $(exprs...) - leaf_nodes_field = NamedTuple{R,S}(($(leaf_node_names...),)) - internal_nodes_field = NamedTuple{T,U}(($(internal_node_names...),)) - choices = StaticChoiceMap{R,S,T,U}(leaf_nodes_field, internal_nodes_field) - (idx - start_idx, choices) - end -end - -@generated function Base.merge(choices1::StaticChoiceMap{R,S,T,U}, - choices2::StaticChoiceMap{W,X,Y,Z}) where {R,S,T,U,W,X,Y,Z} - - # unpack first assignment type parameters - leaf_node_keys1 = choices1.parameters[1] - leaf_node_types1 = choices1.parameters[2].parameters - internal_node_keys1 = choices1.parameters[3] - internal_node_types1 = choices1.parameters[4].parameters - keys1 = (leaf_node_keys1..., internal_node_keys1...,) - - # unpack second assignment type parameters - leaf_node_keys2 = choices2.parameters[1] - leaf_node_types2 = choices2.parameters[2].parameters - internal_node_keys2 = choices2.parameters[3] - internal_node_types2 = choices2.parameters[4].parameters - keys2 = (leaf_node_keys2..., internal_node_keys2...,) - - # leaf vs leaf collision is an error - colliding_leaf_leaf_keys = intersect(leaf_node_keys1, leaf_node_keys2) - if !isempty(colliding_leaf_leaf_keys) - error("choices1 and choices2 both have leaf nodes at key(s): $colliding_leaf_leaf_keys") - end - - # leaf vs internal collision is an error - colliding_leaf_internal_keys = intersect(leaf_node_keys1, internal_node_keys2) - if !isempty(colliding_leaf_internal_keys) - error("choices1 has leaf node and choices2 has internal node at key(s): $colliding_leaf_internal_keys") - end - - # internal vs leaf collision is an error - colliding_internal_leaf_keys = intersect(internal_node_keys1, leaf_node_keys2) - if !isempty(colliding_internal_leaf_keys) - error("choices1 has internal node and choices2 has leaf node at key(s): $colliding_internal_leaf_keys") - end - - # internal vs internal collision is not an error, recursively call merge - colliding_internal_internal_keys = (intersect(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys1_exclusive = (setdiff(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys2_exclusive = (setdiff(internal_node_keys2, internal_node_keys1)...,) - - # leaf nodes named tuple - leaf_node_keys = (leaf_node_keys1..., leaf_node_keys2...,) - leaf_node_types = map(QuoteNode, (leaf_node_types1..., leaf_node_types2...,)) - leaf_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys1]..., - [Expr(:(.), :(choices2.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys2]...) - leaf_nodes = Expr(:call, - Expr(:curly, :NamedTuple, - QuoteNode(leaf_node_keys), - Expr(:curly, :Tuple, leaf_node_types...)), - leaf_node_values) - - # internal nodes named tuple - internal_node_keys = (internal_node_keys1_exclusive..., - internal_node_keys2_exclusive..., - colliding_internal_internal_keys...) - internal_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)) - for key in internal_node_keys1_exclusive]..., - [Expr(:(.), :(choices2.internal_nodes), QuoteNode(key)) - for key in internal_node_keys2_exclusive]..., - [Expr(:call, :merge, - Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)), - Expr(:(.), :(choices2.internal_nodes), QuoteNode(key))) - for key in colliding_internal_internal_keys]...) - internal_nodes = Expr(:call, - Expr(:curly, :NamedTuple, QuoteNode(internal_node_keys)), - internal_node_values) - - # construct assignment from named tuples - Expr(:call, :StaticChoiceMap, leaf_nodes, internal_nodes) -end - -export StaticChoiceMap -export pair, unpair - -####################### -# dynamic assignment # -####################### - -struct DynamicChoiceMap <: ChoiceMap - leaf_nodes::Dict{Any,Any} - internal_nodes::Dict{Any,Any} - function DynamicChoiceMap(leaf_nodes::Dict{Any,Any}, internal_nodes::Dict{Any,Any}) - new(leaf_nodes, internal_nodes) - end -end - -# invariant: all internal nodes are nonempty - -""" - struct DynamicChoiceMap <: ChoiceMap .. end - -A mutable map from arbitrary hierarchical addresses to values. - - choices = DynamicChoiceMap() - -Construct an empty map. - - choices = DynamicChoiceMap(tuples...) - -Construct a map containing each of the given (addr, value) tuples. -""" -function DynamicChoiceMap() - DynamicChoiceMap(Dict(), Dict()) -end - -function DynamicChoiceMap(tuples...) - choices = DynamicChoiceMap() - for (addr, value) in tuples - choices[addr] = value - end - choices -end - -""" - choices = DynamicChoiceMap(other::ChoiceMap) - -Copy a choice map, returning a mutable choice map. -""" -function DynamicChoiceMap(other::ChoiceMap) - choices = DynamicChoiceMap() - for (addr, val) in get_values_shallow(other) - choices[addr] = val - end - for (addr, submap) in get_submaps_shallow(other) - set_submap!(choices, addr, DynamicChoiceMap(submap)) - end - choices -end - -""" - choices = choicemap() - -Construct an empty mutable choice map. -""" -function choicemap() - DynamicChoiceMap() -end - -""" - choices = choicemap(tuples...) - -Construct a mutable choice map initialized with given address, value tuples. -""" -function choicemap(tuples...) - DynamicChoiceMap(tuples...) -end - -get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() - -get_values_shallow(choices::DynamicChoiceMap) = choices.leaf_nodes - -get_submaps_shallow(choices::DynamicChoiceMap) = choices.internal_nodes - -has_value(choices::DynamicChoiceMap, addr::Pair) = _has_value(choices, addr) - -get_value(choices::DynamicChoiceMap, addr::Pair) = _get_value(choices, addr) - -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.internal_nodes, addr) - choices.internal_nodes[addr] - elseif haskey(choices.leaf_nodes, addr) - throw(KeyError(addr)) - else - EmptyChoiceMap() - end -end - -has_value(choices::DynamicChoiceMap, addr) = haskey(choices.leaf_nodes, addr) - -get_value(choices::DynamicChoiceMap, addr) = choices.leaf_nodes[addr] - -function Base.isempty(choices::DynamicChoiceMap) - isempty(choices.leaf_nodes) && isempty(choices.internal_nodes) -end - -# mutation (not part of the assignment interface) - -""" - set_value!(choices::DynamicChoiceMap, addr, value) - -Set the given value for the given address. - -Will cause any previous value or sub-assignment at this address to be deleted. -It is an error if there is already a value present at some prefix of the given address. - -The following syntactic sugar is provided: - - choices[addr] = value -""" -function set_value!(choices::DynamicChoiceMap, addr, value) - delete!(choices.internal_nodes, addr) - choices.leaf_nodes[addr] = value -end - -function set_value!(choices::DynamicChoiceMap, addr::Pair, value) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - node = choices.internal_nodes[first] - set_value!(node, rest, value) -end - -""" - set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) - -Replace the sub-assignment rooted at the given address with the given sub-assignment. -Set the given value for the given address. - -Will cause any previous value or sub-assignment at the given address to be deleted. -It is an error if there is already a value present at some prefix of address. -""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node) - delete!(choices.leaf_nodes, addr) - delete!(choices.internal_nodes, addr) - if !isempty(new_node) - choices.internal_nodes[addr] = new_node - end -end - -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - set_submap!(node, rest, new_node) -end - -Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) - -function _fill_array!(choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - leaf_keys_sorted = sort(collect(keys(choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - value = choices.leaf_nodes[key] - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for key in internal_node_keys_sorted - n_written = _fill_array!(get_submap(choices, key), arr, idx) - idx += n_written - end - idx - start_idx -end - -function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - @assert length(arr) >= start_idx - choices = DynamicChoiceMap() - leaf_keys_sorted = sort(collect(keys(proto_choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(proto_choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - (n_read, value) = _from_array(proto_choices.leaf_nodes[key], arr, idx) - idx += n_read - choices.leaf_nodes[key] = value - end - for key in internal_node_keys_sorted - (n_read, node) = _from_array(get_submap(proto_choices, key), arr, idx) - idx += n_read - choices.internal_nodes[key] = node - end - (idx - start_idx, choices) -end - -export DynamicChoiceMap -export choicemap -export set_value! -export set_submap! - - -####################################### -## vector combinator for assignments # -####################################### - -# TODO implement LeafVectorChoiceMap, which stores a vector of leaf nodes - -struct InternalVectorChoiceMap{T} <: ChoiceMap - internal_nodes::Vector{T} - is_empty::Bool -end - -function vectorize_internal(nodes::Vector{T}) where {T} - is_empty = all(map(isempty, nodes)) - InternalVectorChoiceMap(nodes, is_empty) -end - -# note some internal nodes may be empty - -get_address_schema(::Type{InternalVectorChoiceMap}) = VectorAddressSchema() - -Base.isempty(choices::InternalVectorChoiceMap) = choices.is_empty -has_value(choices::InternalVectorChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::InternalVectorChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::InternalVectorChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::InternalVectorChoiceMap, addr::Int) - if addr > 0 && addr <= length(choices.internal_nodes) - choices.internal_nodes[addr] - else - EmptyChoiceMap() - end -end - -function get_submaps_shallow(choices::InternalVectorChoiceMap) - ((i, choices.internal_nodes[i]) - for i=1:length(choices.internal_nodes) - if !isempty(choices.internal_nodes[i])) -end - -get_values_shallow(::InternalVectorChoiceMap) = () - -function _fill_array!(choices::InternalVectorChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for key=1:length(choices.internal_nodes) - n = _fill_array!(choices.internal_nodes[key], arr, idx) - idx += n - end - idx - start_idx -end - -function _from_array(proto_choices::InternalVectorChoiceMap{U}, arr::Vector{T}, start_idx::Int) where {T,U} - @assert length(arr) >= start_idx - nodes = Vector{U}(undef, length(proto_choices.internal_nodes)) - idx = start_idx - for key=1:length(proto_choices.internal_nodes) - (n_read, nodes[key]) = _from_array(proto_choices.internal_nodes[key], arr, idx) - idx += n_read - end - choices = InternalVectorChoiceMap(nodes, proto_choices.is_empty) - (idx - start_idx, choices) -end - -export InternalVectorChoiceMap -export vectorize_internal - - -#################### -# empty assignment # -#################### - -struct EmptyChoiceMap <: ChoiceMap end - -Base.isempty(::EmptyChoiceMap) = true -get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() -get_submaps_shallow(::EmptyChoiceMap) = () -get_values_shallow(::EmptyChoiceMap) = () - -_fill_array!(::EmptyChoiceMap, arr::Vector, start_idx::Int) = 0 -_from_array(::EmptyChoiceMap, arr::Vector, start_idx::Int) = (0, EmptyChoiceMap()) - -export EmptyChoiceMap - -############################################ -# Nested-dict–like accessor for choicemaps # -############################################ - -""" -Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than -the default syntax which looks like a flat dict of full keypaths. - -```jldoctest -julia> using Gen -julia> c = choicemap((:a, 1), - (:b => :c, 2)); -julia> cv = nested_view(c); -julia> c[:a] == cv[:a] -true -julia> c[:b => :c] == cv[:b][:c] -true -julia> length(cv) -2 -julia> length(cv[:b]) -1 -julia> sort(collect(keys(cv))) -[:a, :b] -julia> sort(collect(keys(cv[:b]))) -[:c] -``` -""" -struct ChoiceMapNestedView - choice_map::ChoiceMap -end - -function Base.getindex(choices::ChoiceMapNestedView, addr) - if has_value(choices.choice_map, addr) - return get_value(choices.choice_map, addr) - end - submap = get_submap(choices.choice_map, addr) - if isempty(submap) - throw(KeyError(addr)) - end - ChoiceMapNestedView(submap) -end - -function Base.iterate(c::ChoiceMapNestedView) - inner_iterator = Base.Iterators.flatten(( - get_values_shallow(c.choice_map), - ((k, ChoiceMapNestedView(v)) - for (k, v) in get_submaps_shallow(c.choice_map)))) - r = Base.iterate(inner_iterator) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -function Base.iterate(c::ChoiceMapNestedView, state) - (inner_iterator, inner_state) = state - r = Base.iterate(inner_iterator, inner_state) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -# TODO: Allow different implementations of this method depending on the -# concrete type of the `ChoiceMap`, so that an already-existing data structure -# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it -# exists. -Base.keys(cv::Gen.ChoiceMapNestedView) = (k for (k, v) in cv) - -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map == b.choice_map -end - -# Length of a `ChoiceMapNestedView` is number of leaf values + number of -# submaps. Motivation: This matches what `length` would return for the -# equivalent nested dict. -function Base.length(cv::ChoiceMapNestedView) - +(get_values_shallow(cv.choice_map) |> collect |> length, - get_submaps_shallow(cv.choice_map) |> collect |> length) -end - -function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) - Base.show(io, MIME"text/plain"(), c.choice_map) -end - -nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) - -# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling -# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and -# aux data together. - -export nested_view - -""" - selected_choices = get_selected(choices::ChoiceMap, selection::Selection) - -Filter the choice map to include only choices in the given selection. - -Returns a new choice map. -""" -function get_selected( - choices::ChoiceMap, selection::Selection) - output = choicemap() - for (key, value) in get_values_shallow(choices) - if (key in selection) - output[key] = value - end - end - for (key, submap) in get_submaps_shallow(choices) - subselection = selection[key] - set_submap!(output, key, get_selected(submap, subselection)) - end - output -end - -export get_selected diff --git a/src/choice_map2/array_interface.jl b/src/choice_map/array_interface.jl similarity index 83% rename from src/choice_map2/array_interface.jl rename to src/choice_map/array_interface.jl index f88c5b116..cf9d0bd03 100644 --- a/src/choice_map2/array_interface.jl +++ b/src/choice_map/array_interface.jl @@ -34,12 +34,20 @@ function to_array(choices::ChoiceMap, ::Type{T}) where {T} end function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) <: start_idx + if length(arr) < start_idx resize!(arr, 2 * start_idx) end arr[start_idx] = get_value(c) 1 end +function _fill_array!(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + value = get_value(c) + if length(arr) < start_idx + length(value) + resize!(arr, 2 * (start_idx + length(value))) + end + arr[start_idx:start_idx+length(value)-1] = value + length(value) +end # default _fill_array! implementation function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} @@ -88,5 +96,11 @@ function from_array(proto_choices::ChoiceMap, arr::Vector) end function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) - ValueChoiceMap(arr[start_idx]) -end \ No newline at end of file + (1, ValueChoiceMap(arr[start_idx])) +end +function _from_array(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + n_read = length(get_value(c)) + (n_read, ValueChoiceMap(arr[start_idx:start_idx+n_read-1])) +end + +export to_array, from_array \ No newline at end of file diff --git a/src/choice_map2/choice_map.jl b/src/choice_map/choice_map.jl similarity index 82% rename from src/choice_map2/choice_map.jl rename to src/choice_map/choice_map.jl index 0ebb19f09..402cefa37 100644 --- a/src/choice_map2/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -2,6 +2,22 @@ # choice map interface # ######################### +""" + ChoiceMapGetValueError + +The error returned when a user attempts to call `get_value` +on an choicemap for an address which does not contain a value in that choicemap. +""" +struct ChoiceMapGetValueError <: Exception end +showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + """ get_submaps_shallow(choices::ChoiceMap) @@ -26,7 +42,6 @@ function get_submap end submap = get_submap(choices, first) get_submap(submap, rest) end -@inline get_submap(choices::ChoiceMap, addr::Pair) = _get_submap(choices, addr) """ has_value(choices::ChoiceMap) @@ -45,18 +60,18 @@ function has_value end get_value(choices::ChoiceMap) Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; -throws a `KeyError` if `choices` is not a `ValueChoiceMap`. +throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. get_value(choices::ChoiceMap, addr) Returns the value stored in the submap with address `addr` or throws -a `KeyError` if no value exists at this address. +a `ChoiceMapGetValueError` if no value exists at this address. A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -get_value(::ChoiceMap) = throw(KeyError(nothing)) +get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) @@ -73,6 +88,8 @@ function get_address_schema end Returns an iterable collection of tuples `(address, value)` for each value stored at a top-level address in `choices`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) """ function get_values_shallow(choices::ChoiceMap) ( @@ -88,20 +105,15 @@ end Returns an iterable collection of tuples `(address, submap)` for every top-level submap stored in `choices` which is not a `ValueChoiceMap`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) """ function get_nonvalue_submaps_shallow(choices::ChoiceMap) - filter(! ∘ has_value, get_submaps_shallow(choices)) + (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) end # a choicemap is empty if it has no submaps and no value -Base.isempty(c::ChoiceMap) = isempty(get_submaps_shallow(c)) && !has_value(c) - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end +Base.isempty(c::ChoiceMap) = all(((addr, submap),) -> isempty(submap), get_submaps_shallow(c)) && !has_value(c) """ EmptyChoiceMap @@ -111,11 +123,14 @@ A choicemap with no submaps or values. struct EmptyChoiceMap <: ChoiceMap end @inline has_value(::EmptyChoiceMap, addr...) = false -@inline get_value(::EmptyChoiceMap) = throw(KeyError(nothing)) +@inline get_value(::EmptyChoiceMap) = throw(ChoiceMapGetValueError()) @inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() @inline Base.isempty(::EmptyChoiceMap) = true @inline get_submaps_shallow(::EmptyChoiceMap) = () @inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() +@inline Base.:(==)(::EmptyChoiceMap, ::EmptyChoiceMap) = true +@inline Base.:(==)(::ChoiceMap, ::EmptyChoiceMap) = false +@inline Base.:(==)(::EmptyChoiceMap, ::ChoiceMap) = false """ ValueChoiceMap @@ -148,6 +163,11 @@ function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) for (key, submap) in get_submaps_shallow(choices1) set_submap!(choices, key, merge(submap, get_submap(choices2, key))) end + for (key, submap) in get_submaps_shallow(choices2) + if isempty(get_submap(choices1, key)) + set_submap!(choices, key, submap) + end + end choices end Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c @@ -170,6 +190,11 @@ function Base.:(==)(a::ChoiceMap, b::ChoiceMap) return false end end + for (addr, submap) in get_submaps_shallow(b) + if get_submap(a, addr) != submap + return false + end + end return true end @@ -246,9 +271,11 @@ function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) end export ChoiceMap, ValueChoiceMap, EmptyChoiceMap -export get_submap, get_submaps_shallow +export _get_submap, get_submap, get_submaps_shallow export get_value, has_value export get_values_shallow, get_nonvalue_submaps_shallow +export get_address_schema, get_selected +export ChoiceMapGetValueError include("array_interface.jl") include("dynamic_choice_map.jl") diff --git a/src/choice_map2/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl similarity index 93% rename from src/choice_map2/dynamic_choice_map.jl rename to src/choice_map/dynamic_choice_map.jl index 5dfca0b55..a3403307c 100644 --- a/src/choice_map2/dynamic_choice_map.jl +++ b/src/choice_map/dynamic_choice_map.jl @@ -2,10 +2,6 @@ # dynamic assignment # ####################### -struct DynamicChoiceMap <: ChoiceMap - submaps::Dict{Any, <:ChoiceMap} -end - """ struct DynamicChoiceMap <: ChoiceMap .. end @@ -19,8 +15,11 @@ Construct an empty map. Construct a map containing each of the given (addr, value) tuples. """ -function DynamicChoiceMap() - DynamicChoiceMap(Dict()) +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, ChoiceMap} + function DynamicChoiceMap() + new(Dict()) + end end function DynamicChoiceMap(tuples...) @@ -39,12 +38,13 @@ Copy a choice map, returning a mutable choice map. function DynamicChoiceMap(other::ChoiceMap) choices = DynamicChoiceMap() for (addr, submap) in get_submaps_shallow(other) - if choices isa ValueChoiceMap + if submap isa ValueChoiceMap set_submap!(choices, addr, submap) else set_submap!(choices, addr, DynamicChoiceMap(submap)) end end + choices end DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") @@ -116,14 +116,14 @@ Set the given value for the given address. Will cause any previous value or sub-assignment at the given address to be deleted. It is an error if there is already a value present at some prefix of address. """ -function set_submap!(choices::DynamicChoiceMap, addr, new_node) +function set_submap!(choices::DynamicChoiceMap, addr, new_node::ChoiceMap) delete!(choices.submaps, addr) if !isempty(new_node) choices.submaps[addr] = new_node end end -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node::ChoiceMap) (first, rest) = addr if !haskey(choices.submaps, first) choices.submaps[first] = DynamicChoiceMap() @@ -137,7 +137,7 @@ Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, add function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} choices = DynamicChoiceMap() - keys_sorted = sort(collect(keys(choices.submaps))) + keys_sorted = sort(collect(keys(proto_choices.submaps))) idx = start_idx for key in keys_sorted (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) diff --git a/src/choice_map2/nested_view.jl b/src/choice_map/nested_view.jl similarity index 93% rename from src/choice_map2/nested_view.jl rename to src/choice_map/nested_view.jl index 6693234fb..68add0a05 100644 --- a/src/choice_map2/nested_view.jl +++ b/src/choice_map/nested_view.jl @@ -33,7 +33,7 @@ ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") function Base.getindex(choices::ChoiceMapNestedView, addr) - ChoiceMapNestedView(get_submap(choices, addr)) + ChoiceMapNestedView(get_submap(choices.choice_map, addr)) end function Base.iterate(c::ChoiceMapNestedView) @@ -62,9 +62,8 @@ end # exists. Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map = b.choice_map -end +Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map + function Base.length(cv::ChoiceMapNestedView) length(collect(get_submaps_shallow(cv.choice_map))) end diff --git a/src/choice_map2/static_choice_map.jl b/src/choice_map/static_choice_map.jl similarity index 68% rename from src/choice_map2/static_choice_map.jl rename to src/choice_map/static_choice_map.jl index 3508762d7..1f75b3bca 100644 --- a/src/choice_map2/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -4,13 +4,21 @@ struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap submaps::NamedTuple{Addrs, SubmapTypes} + function StaticChoiceMap(submaps::NamedTuple{Addrs, SubmapTypes}) where {Addrs, SubmapTypes <: NTuple{n, ChoiceMap} where n} + new{Addrs, SubmapTypes}(submaps) + end +end + +function StaticChoiceMap(;addrs_to_vals_and_maps...) + addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_maps) + maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : ValueChoiceMap(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) + StaticChoiceMap(NamedTuple{addrs}(maps)) end @inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) @inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) @inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) -# TODO: profiling! @generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} if A in Addrs quote choices.submaps[A] end @@ -18,17 +26,25 @@ end quote EmptyChoiceMap() end end end +static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level function StaticChoiceMap(other::ChoiceMap) - keys_and_nodes = get_submaps_shallow(other) - (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + keys_and_nodes = collect(get_submaps_shallow(other)) + if length(keys_and_nodes) > 0 + (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + else + addrs = () + submaps = () + end StaticChoiceMap(NamedTuple{addrs}(submaps)) end StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") +StaticChoiceMap(::NamedTuple{(),Tuple{}}) = EmptyChoiceMap() # TODO: deep conversion to static choicemap @@ -58,9 +74,9 @@ end @generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} - - addr_to_type1 = Dict{Symbol, ::Type{<:ChoiceMap}}() - addr_to_type2 = Dict{Symbol, ::Type{<:ChoiceMap}}() + + addr_to_type1 = Dict{Symbol, Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, Type{<:ChoiceMap}}() for (i, addr) in enumerate(Addrs1) addr_to_type1[addr] = SubmapTypes1.parameters[i] end @@ -78,30 +94,30 @@ end || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") end - if type1 <: ValueChoiceMap + if type1 <: EmptyChoiceMap push!(submap_exprs, - quote choices1.submaps[$addr] end + quote choices2.submaps.$addr end ) - elseif type2 <: ValueChoiceMap + elseif type2 <: EmptyChoiceMap push!(submap_exprs, - quote choices2.submaps[$addr] end + quote choices1.submaps.$addr end ) else push!(submap_exprs, - quote merge(choices1.submaps[$addr], choices2.submaps[$addr]) end + quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end ) end end quote - StaticChoiceMap{$merged_addrs}(submap_exprs...) + StaticChoiceMap(NamedTuple{$merged_addrs}(($(submap_exprs...),))) end end -@generated function _from_array!(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, +@generated function _from_array(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} - perm = sortperm(Addrs) + perm = sortperm(collect(Addrs)) sorted_addrs = Addrs[perm] submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) @@ -112,7 +128,7 @@ end submap_var_names[idx] = submap_var_name push!(exprs, quote - (n_read, submap_var_name = _from_array(proto_choices.submaps[$addr], arr, idx) + (n_read, $submap_var_name) = _from_array(proto_choices.submaps.$addr, arr, idx) idx += n_read end ) @@ -120,14 +136,14 @@ end quote $(exprs...) - submaps = NamedTuple{Addrs}(( $(submap_var_names...) )) - choices = StaticChoiceMap{Addrs, SubmapTypes}(submaps) + submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) + choices = StaticChoiceMap(submaps) (idx - start_idx, choices) end end function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} - StaticAddressSchema(set(Addrs)) + StaticAddressSchema(Set(Addrs)) end export StaticChoiceMap diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index c6f09374c..73f22159a 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -124,42 +124,33 @@ function visit!(visitor::AddressVisitor, addr) push!(visitor.visited, addr) end +all_visited(::Selection, ::ValueChoiceMap) = false +all_visited(::AllSelection, ::ValueChoiceMap) = true function all_visited(visited::Selection, choices::ChoiceMap) - allvisited = true - for (key, _) in get_values_shallow(choices) - allvisited = allvisited && (key in visited) - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - allvisited = allvisited && all_visited(subvisited, submap) + if !all_visited(visited[key], submap) + return false end end - allvisited + return true end +get_unvisited(::Selection, v::ValueChoiceMap) = v +get_unvisited(::AllSelection, v::ValueChoiceMap) = EmptyChoiceMap() function get_unvisited(visited::Selection, choices::ChoiceMap) unvisited = choicemap() - for (key, _) in get_values_shallow(choices) - if !(key in visited) - set_value!(unvisited, key, get_value(choices, key)) - end - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - sub_unvisited = get_unvisited(subvisited, submap) - set_submap!(unvisited, key, sub_unvisited) - end + sub_unvisited = get_unvisited(visited[key], submap) + set_submap!(unvisited, key, sub_unvisited) end unvisited end get_visited(visitor) = visitor.visited -function check_no_submap(constraints::ChoiceMap, addr) +function check_is_empty(constraints::ChoiceMap, addr) if !isempty(get_submap(constraints, addr)) - error("Expected a value at address $addr but found a sub-assignment") + error("Expected a value or EmptyChoiceMap at address $addr but found a sub-assignment") end end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index df6a5f465..970dac42d 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -20,7 +20,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # get return value if constrained diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..882297e43 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -119,9 +119,6 @@ struct DynamicDSLChoiceMap <: ChoiceMap end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() -Base.isempty(::DynamicDSLChoiceMap) = false # TODO not necessarily true -has_value(choices::DynamicDSLChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::DynamicDSLChoiceMap, addr::Pair) = _get_value(choices, addr) get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_submap(choices::DynamicDSLChoiceMap, addr) @@ -130,9 +127,10 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) # leaf node, must be a call call = trie[addr] if call.is_choice - throw(KeyError(addr)) + ValueChoiceMap(call.subtrace_or_retval) + else + get_choices(call.subtrace_or_retval) end - get_choices(call.subtrace_or_retval) elseif has_internal_node(trie, addr) # internal node subtrie = get_internal_node(trie, addr) @@ -142,32 +140,12 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) end end -function has_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - has_leaf_node(trie, addr) && trie[addr].is_choice -end - -function get_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - choice = trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - choice.subtrace_or_retval -end - -function get_values_shallow(choices::DynamicDSLChoiceMap) - ((key, choice.subtrace_or_retval) - for (key, choice) in get_leaf_nodes(choices.trie) - if choice.is_choice) -end - function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ((key, get_choices(call.subtrace_or_retval)) + calls_iter = ( + (key, call.is_choice ? ValueChoiceMap(call.subtrace_or_retval) : get_choices(call.subtrace_or_retval)) for (key, call) in get_leaf_nodes(choices.trie) - if !call.is_choice) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) - for (key, trie) in get_internal_nodes(choices.trie)) + ) + internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) Iterators.flatten((calls_iter, internal_nodes_iter)) end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 24e023f24..7acc16302 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -35,7 +35,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # record the previous value as discarded if it is replaced if constrained && has_previous @@ -149,32 +149,22 @@ end function add_unvisited_to_discard!(discard::DynamicChoiceMap, visited::DynamicSelection, prev_choices::ChoiceMap) - for (key, value) in get_values_shallow(prev_choices) + for (key, submap) in get_submaps_shallow(prev_choices) + # if key IS in visited, + # the recursive call to update already handled the discard + # for this entire submap; else we need to handle it if !(key in visited) - @assert !has_value(discard, key) @assert isempty(get_submap(discard, key)) - set_value!(discard, key, value) - end - end - for (key, submap) in get_submaps_shallow(prev_choices) - @assert !has_value(discard, key) - if key in visited - # the recursive call to update already handled the discard - # for this entire submap - continue - else subvisited = visited[key] if isempty(subvisited) # none of this submap was visited, so we discard the whole thing - @assert isempty(get_submap(discard, key)) set_submap!(discard, key, submap) else subdiscard = get_submap(discard, key) - add_unvisited_to_discard!( - isempty(subdiscard) ? choicemap() : subdiscard, - subvisited, submap) + subdiscard = isempty(subdiscard) ? choicemap() : subdiscard + add_unvisited_to_discard!(subdiscard, subvisited, submap) set_submap!(discard, key, subdiscard) - end + end end end end diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index a231f03a7..d662dbb75 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -1,12 +1,13 @@ import MacroTools function check_observations(choices::ChoiceMap, observations::ChoiceMap) - for (key, value) in get_values_shallow(observations) - !has_value(choices, key) && error("Check failed: observed choice at $key not found") - choices[key] != value && error("Check failed: value of observed choice at $key changed") - end for (key, submap) in get_submaps_shallow(observations) - check_observations(get_submap(choices, key), submap) + if has_value(submap) + !has_value(choices, key) && error("Check failed: observed choice at $key not found") + choices[key] != value && error("Check failed: value of observed choice at $key changed") + else + check_observations(get_submap(choices, key), submap) + end end end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 234116976..f17d061f8 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -14,10 +14,7 @@ function get_submap(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} end get_submap(choices::CallAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -get_value(choices::CallAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::CallAtChoiceMap, addr::Pair) = _has_value(choices, addr) get_submaps_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) -get_values_shallow(::CallAtChoiceMap) = () # TODO optimize CallAtTrace using type parameters @@ -69,7 +66,7 @@ unpack_call_at_args(args) = (args[end], args[1:end-1]) function assess(gen_fn::CallAtCombinator, args::Tuple, choices::ChoiceMap) (key, kernel_args) = unpack_call_at_args(args) - if length(get_submaps_shallow(choices)) > 1 || length(get_values_shallow(choices)) > 0 + if length(get_submaps_shallow(choices)) > 1 error("Not all constraints were consumed") end submap = get_submap(choices, key) diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index 69bb4851a..f38758956 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -25,10 +25,12 @@ function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} end get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) +get_submap(choices::ChoiceAtChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} choices.key == addr ? choices.value : throw(KeyError(choices, addr)) end -get_submaps_shallow(choices::ChoiceAtChoiceMap) = () +get_submap(choices::ChoiceAtChoiceMap, addr) = addr == choices.key ? ValueChoiceMap(choices.value) : EmptyChoiceMap() +get_submaps_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, ValueChoiceMap(choices.value)),) get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 715800737..1f1017251 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -84,17 +84,7 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, addr::Pair) - _get_submap(choices, addr) -end - -function has_value(choices::RecurseTraceChoiceMap, addr::Pair) - _has_value(choices, addr) -end - -function get_value(choices::RecurseTraceChoiceMap, addr::Pair) - _get_value(choices, addr) -end +get_submap(choices::RecurseTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) get_values_shallow(choices::RecurseTraceChoiceMap) = () @@ -333,6 +323,9 @@ function recurse_unpack_constraints(constraints::ChoiceMap) production_constraints = Dict{Int, Any}() aggregation_constraints = Dict{Int, Any}() for (addr, node) in get_submaps_shallow(constraints) + if has_value(node) + error("Unknown address: $(addr)") + end idx::Int = addr[1] if addr[2] == Val(:production) production_constraints[idx] = node @@ -342,9 +335,6 @@ function recurse_unpack_constraints(constraints::ChoiceMap) error("Unknown address: $addr") end end - if length(get_values_shallow(constraints)) > 0 - error("Unknown address: $(first(get_values_shallow(constraints))[1])") - end return (production_constraints, aggregation_constraints) end diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 9b0eb763a..3af416ef8 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -92,10 +92,6 @@ end end @inline get_submap(choices::VectorTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_value(choices::VectorTraceChoiceMap, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::VectorTraceChoiceMap, addr::Pair) = _has_value(choices, addr) -@inline get_values_shallow(::VectorTraceChoiceMap) = () - ############################################ # code shared by vector-shaped combinators # diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 7a0fe384e..b352d3ca2 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -330,21 +330,22 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec) - leaf_gradients = map((node) -> gradient_var(node), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(trace.$(get_value_fieldname(node)))), selected_choices_vec) + leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) - internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), + internal_value_choicemaps = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) - internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec) + internal_gradient_choicemaps = map((node) -> gradient_trie_var(node), selected_calls_vec) + + quoted_all_keys = Iterators.flatten((quoted_leaf_keys, quoted_internal_keys)) + all_value_choicemaps = Iterators.flatten((leaf_value_choicemaps, internal_value_choicemaps)) + all_gradient_choicemaps = Iterators.flatten((leaf_gradient_choicemaps, internal_gradient_choicemaps)) + quote - $value_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),))) - $gradient_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),))) + $value_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_value_choicemaps...),))) + $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_gradient_choicemaps...),))) end end diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 713c0863a..5ac3ced16 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -9,25 +9,8 @@ end function get_schema end @inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T) - @inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace) - -@inline static_has_value(choices::StaticIRTraceAssmt, key) = false - -@inline function get_value(choices::StaticIRTraceAssmt, key::Symbol) - static_get_value(choices, Val(key)) -end - -@inline function has_value(choices::StaticIRTraceAssmt, key::Symbol) - static_has_value(choices, Val(key)) -end - -@inline function get_submap(choices::StaticIRTraceAssmt, key::Symbol) - static_get_submap(choices, Val(key)) -end - -@inline get_value(choices::StaticIRTraceAssmt, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::StaticIRTraceAssmt, addr::Pair) = _has_value(choices, addr) +@inline get_submap(choices::StaticIRTraceAssmt, key::Symbol) = static_get_submap(choices, Val(key)) @inline get_submap(choices::StaticIRTraceAssmt, addr::Pair) = _get_submap(choices, addr) ######################### @@ -36,16 +19,13 @@ end abstract type StaticIRTrace <: Trace end -@inline function static_get_subtrace(trace::StaticIRTrace, addr) - error("Not implemented") -end +@inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) -@inline function Base.getindex(trace::StaticIRTrace, addr) - Gen.static_getindex(trace, Val(addr)) -end +@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) first, rest = addr return Gen.static_get_subtrace(trace, Val(first))[rest] @@ -161,21 +141,13 @@ function generate_get_choices(trace_struct_name::Symbol) :($(QuoteNode(EmptyChoiceMap))()))) end -function generate_get_values_shallow(ir::StaticIR, trace_struct_name::Symbol) +function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.choice_nodes addr = node.addr value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), $value))) + push!(elements, :(($(QuoteNode(addr)), ValueChoiceMap($value)))) end - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_values_shallow)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), - Expr(:block, Expr(:tuple, elements...))) -end - -function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) - elements = [] for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) @@ -224,30 +196,6 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] end -function generate_static_get_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(choices.trace.$(get_value_fieldname(node)))))) - end - methods -end - -function generate_static_has_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_has_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(true)))) - end - methods -end - function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) methods = Expr[] for node in ir.call_nodes @@ -259,13 +207,13 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) end - # throw a KeyError if get_submap is run on an address containing a value + # return a ValueChoiceMap if get_submap is run on an address containing a value for node in ir.choice_nodes push!(methods, Expr(:function, Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(throw(KeyError($(QuoteNode(node.addr)))))))) + Expr(:block, :(ValueChoiceMap(choices.trace.$(get_value_fieldname(node))))))) end methods end @@ -290,18 +238,13 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_retval_expr = generate_get_retval(ir, trace_struct_name) get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) - get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) - static_get_value_exprs = generate_static_get_value(ir, trace_struct_name) - static_has_value_exprs = generate_static_has_value(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) getindex_exprs = generate_getindex(ir, trace_struct_name) exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_values_shallow_expr, - get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) + get_choices_expr, get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index dc4fddf31..c806bba3a 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -454,9 +454,10 @@ function generate_discard!(stmts::Vector{Expr}, end leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - expr = :($(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(leaf_keys...),)}(($(leaf_nodes...),)), - $(QuoteNode(NamedTuple)){($(internal_keys...),)}(($(internal_nodes...),)))) + all_keys = (leaf_keys..., internal_keys...) + all_nodes = ([:($(QuoteNode(ValueChoiceMap))($node)) for node in leaf_nodes]..., internal_nodes...) + expr = quote $(QuoteNode(StaticChoiceMap))( + $(QuoteNode(NamedTuple)){($(all_keys...),)}(($(all_nodes...),))) end push!(stmts, :($discard = $expr)) end diff --git a/test/assignment.jl b/test/assignment.jl index 1bba754af..1d7e48a80 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -1,6 +1,46 @@ +@testset "ValueChoiceMap" begin + vcm1 = ValueChoiceMap(2) + vcm2 = ValueChoiceMap(2.) + vcm3 = ValueChoiceMap([1,2]) + @test vcm1 isa ValueChoiceMap{Int} + @test vcm2 isa ValueChoiceMap{Float64} + @test vcm3 isa ValueChoiceMap{Vector{Int}} + + @test !isempty(vcm1) + @test has_value(vcm1) + @test get_value(vcm1) == 2 + @test vcm1 == vcm2 + @test isempty(get_submaps_shallow(vcm1)) + @test isempty(get_values_shallow(vcm1)) + @test isempty(get_nonvalue_submaps_shallow(vcm1)) + @test to_array(vcm1, Int) == [2] + @test from_array(vcm1, [4]) == ValueChoiceMap(4) + @test from_array(vcm3, [4, 5]) == ValueChoiceMap([4, 5]) + @test_throws Exception merge(vcm1, vcm2) + @test_throws Exception merge(vcm1, choicemap(:a, 5)) + @test merge(vcm1, EmptyChoiceMap()) == vcm1 + @test merge(EmptyChoiceMap(), vcm1) == vcm1 + @test get_submap(vcm1, :addr) == EmptyChoiceMap() + @test_throws ChoiceMapGetValueError get_value(vcm1, :addr) + @test !has_value(vcm1, :addr) + @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) + @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) + @test get_address_schema(typeof(vcm1)) == EmptyAddressSchema() + @test get_address_schema(ValueChoiceMap) == EmptyAddressSchema() + @test nested_view(vcm1) == 2 +end + +@testset "static choicemap constructor" begin + @test StaticChoiceMap((a=ValueChoiceMap(5), b=ValueChoiceMap(6))) == StaticChoiceMap(a=5, b=6) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + @test submap == StaticChoiceMap((a=ValueChoiceMap(1.), b=ValueChoiceMap([2., 2.5]))) + outer = StaticChoiceMap(c=3, d=submap, e=submap) + @test outer == StaticChoiceMap((c=ValueChoiceMap(3), d=submap, e=submap)) +end + @testset "static assignment to/from array" begin - submap = StaticChoiceMap((a=1., b=[2., 2.5]),NamedTuple()) - outer = StaticChoiceMap((c=3.,), (d=submap, e=submap)) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + outer = StaticChoiceMap(c=3., d=submap, e=submap) arr = to_array(outer, Float64) @test to_array(outer, Float64) == Float64[3.0, 1.0, 2.0, 2.5, 1.0, 2.0, 2.5] @@ -11,14 +51,16 @@ @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment to/from array" begin @@ -39,14 +81,18 @@ end @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test get_submap(choices, :c) == ValueChoiceMap(1.0) + @test get_submap(choices, :d => :b) == ValueChoiceMap([3.0, 4.0]) + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment copy constructor" begin @@ -64,25 +110,6 @@ end @test choices[:u => :w] == 4 end -@testset "internal vector assignment to/from array" begin - inner = choicemap() - set_value!(inner, :a, 1.) - set_value!(inner, :b, 2.) - outer = vectorize_internal([inner, inner, inner]) - - arr = to_array(outer, Float64) - @test to_array(outer, Float64) == Float64[1, 2, 1, 2, 1, 2] - - choices = from_array(outer, Float64[1, 2, 3, 4, 5, 6]) - @test choices[1 => :a] == 1.0 - @test choices[1 => :b] == 2.0 - @test choices[2 => :a] == 3.0 - @test choices[2 => :b] == 4.0 - @test choices[3 => :a] == 5.0 - @test choices[3 => :b] == 6.0 - @test length(collect(get_submaps_shallow(choices))) == 3 -end - @testset "dynamic assignment merge" begin submap = choicemap() set_value!(submap, :x, 1) @@ -107,7 +134,7 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @@ -125,8 +152,8 @@ end set_value!(submap, :x, 1) submap2 = choicemap() set_value!(submap2, :y, 4.) - choices1 = StaticChoiceMap((a=1., b=2.), (c=submap, shared=submap)) - choices2 = StaticChoiceMap((d=3.,), (e=submap, f=submap, shared=submap2)) + choices1 = StaticChoiceMap(a=1., b=2., c=submap, shared=submap) + choices2 = StaticChoiceMap(d=3., e=submap, f=submap, shared=submap2) choices = merge(choices1, choices2) @test choices[:a] == 1. @test choices[:b] == 2. @@ -136,124 +163,91 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @testset "static assignment variadic merge" begin - choices1 = StaticChoiceMap((a=1,), NamedTuple()) - choices2 = StaticChoiceMap((b=2,), NamedTuple()) - choices3 = StaticChoiceMap((c=3,), NamedTuple()) - choices_all = StaticChoiceMap((a=1, b=2, c=3), NamedTuple()) + choices1 = StaticChoiceMap(a=1) + choices2 = StaticChoiceMap(b=2) + choices3 = StaticChoiceMap(c=3) + choices_all = StaticChoiceMap(a=1, b=2, c=3) @test merge(choices1) == choices1 @test merge(choices1, choices2, choices3) == choices_all end +# TODO: in changing a lot of these to reflect the new behavior of choicemap, +# they are mostly not error checks, but instead checks for returning `EmptyChoiceMap`; +# should we relabel this testset? @testset "static assignment errors" begin + # get_choices on an address that returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x) == ValueChoiceMap(1) + + # static_get_submap on an address that contains a value returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test static_get_submap(choices, Val(:x)) == ValueChoiceMap(1) - # get_choices on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw - - # static_get_submap on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # get_submap on an address whose prefix contains a value returns EmptyChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) + choices = StaticChoiceMap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # static_get_choices on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_choices on an address that contains nothing returns an EmptyChoiceMap + choices = StaticChoiceMap() + @test static_get_submap(choices, Val(:x)) == EmptyChoiceMap() - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # static_get_value on an address that contains a submap throws a KeyError + # static_get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) + + # get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) + + # static_get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) end @testset "dynamic assignment errors" begin - - # get_choices on an address that contains a value throws a KeyError + # get_choices on an address that contains a value returns a ValueChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(1) - # get_choices on an address whose prefix contains a value throws a KeyError + # get_choices on an address whose prefix contains a value returns EmptyChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment choices = choicemap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError choices = choicemap() choices[:x => :y] = 1 - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # get_value on an address that contains nothing throws a KeyError + # get_value on an address that contains nothing throws a ChoiceMapGetValueError choices = choicemap() - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) end @testset "dynamic assignment overwrite" begin @@ -276,9 +270,7 @@ end choices = choicemap() choices[:x => :y] = 1 choices[:x] = 2 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(2) @test choices[:x] == 2 # overwrite subassignment with a subassignment @@ -293,17 +285,13 @@ end # illegal set value under existing value choices = choicemap() choices[:x] = 1 - threw = false - try set_value!(choices, :x => :y, 2) catch KeyError threw = true end - @test threw + @test_throws Exception set_value!(choices, :x => :y, 2) # illegal set submap under existing value choices = choicemap() choices[:x] = 1 submap = choicemap(); choices[:z] = 2 - threw = false - try set_submap!(choices, :x => :y, submap) catch KeyError threw = true end - @test threw + @test_throws Exception set_submap!(choices, :x => :y, submap) end @testset "dynamic assignment constructor" begin diff --git a/test/benchmark.md b/test/benchmark.md new file mode 100644 index 000000000..adabb8a58 --- /dev/null +++ b/test/benchmark.md @@ -0,0 +1,21 @@ +NEW version: +static choicemap nonnested lookup: + 0.728112 seconds (149.59 k allocations: 4.259 MiB) + 0.785652 seconds (100.00 k allocations: 1.526 MiB) + 0.693433 seconds (100.00 k allocations: 1.526 MiB) + 0.660211 seconds (100.00 k allocations: 1.526 MiB) +static choicemap nested lookup: + 0.680497 seconds (49.59 k allocations: 2.732 MiB) + 0.665768 seconds (1 allocation: 32 bytes) + 0.666708 seconds (1 allocation: 32 bytes) + 0.671009 seconds (1 allocation: 32 bytes) +static gen function choicemap nonnested lookup: + 0.701754 seconds (62.76 k allocations: 3.415 MiB) + 0.662916 seconds + 0.659019 seconds + 0.663398 seconds +static gen function choicemap nested lookup: + 1.338034 seconds (172.13 k allocations: 5.352 MiB) + 1.311123 seconds (100.00 k allocations: 1.526 MiB) + 1.311800 seconds (100.00 k allocations: 1.526 MiB) + 1.310289 seconds (100.00 k allocations: 1.526 MiB) \ No newline at end of file diff --git a/test/dynamic_dsl.jl b/test/dynamic_dsl.jl index 35f81703d..5561ae549 100644 --- a/test/dynamic_dsl.jl +++ b/test/dynamic_dsl.jl @@ -119,7 +119,7 @@ end @test get_value(discard, :x) == x @test get_value(discard, :u => :a) == a @test length(collect(get_values_shallow(discard))) == 2 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 # test new trace new_assignment = get_choices(new_trace) @@ -127,7 +127,7 @@ end @test get_value(new_assignment, :y) == y @test get_value(new_assignment, :v => :b) == b @test length(collect(get_values_shallow(new_assignment))) == 2 - @test length(collect(get_submaps_shallow(new_assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(new_assignment))) == 1 # test score and weight prev_score = ( @@ -242,7 +242,7 @@ end @test !isempty(get_submap(assignment, :v)) end @test length(collect(get_values_shallow(assignment))) == 2 - @test length(collect(get_submaps_shallow(assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(assignment))) == 1 # test weight if assignment[:branch] == prev_assignment[:branch] @@ -332,11 +332,11 @@ end @test get_value(choices, :out) == out @test get_value(choices, :bar => :z) == z @test !has_value(choices, :b) # was not selected - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(collect(get_values_shallow(choices))) == 2 # check gradient trie - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(gradients))) == 2 @test !has_value(gradients, :b) # was not selected @test isapprox(get_value(gradients, :bar => :z), @@ -431,14 +431,14 @@ end @test choices[:x => 2] == 2 @test choices[:x => 3 => :z] == 3 @test length(collect(get_values_shallow(choices))) == 1 # :y - @test length(collect(get_submaps_shallow(choices))) == 1 # :x + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # :x submap = get_submap(choices, :x) @test submap[1] == 1 @test submap[2] == 2 @test submap[3 => :z] == 3 @test length(collect(get_values_shallow(submap))) == 2 # 1, 2 - @test length(collect(get_submaps_shallow(submap))) == 1 # 3 + @test length(collect(get_nonvalue_submaps_shallow(submap))) == 1 # 3 bar_submap = get_submap(submap, 3) @test bar_submap[:z] == 3 diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130d..607eb61fd 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -20,7 +20,7 @@ y = choices[3 => :y] @test isapprox(weight, logpdf(normal, y, 0.4, 1)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end @testset "generate" begin @@ -32,7 +32,7 @@ y = choices[3 => :y] @test get_retval(trace) == 0.4 + y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # with constraints y = 1.234 @@ -44,7 +44,7 @@ @test get_retval(trace) == 0.4 + y @test isapprox(weight, logpdf(normal, y, 0.4, 1.)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end function get_trace() @@ -71,7 +71,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isempty(discard) @@ -86,12 +86,12 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) # change kernel_args, different key, with constraint @@ -103,12 +103,12 @@ choices = get_choices(new_trace) @test choices[4 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) end @@ -121,7 +121,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isapprox(get_score(new_trace), logpdf(normal, y, 0.2, 1)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) y_new = choices[3 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -144,7 +144,7 @@ choices = get_choices(new_trace) y_new = choices[4 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -171,9 +171,9 @@ @test choices[3 => :y] == y @test isapprox(gradients[3 => :y], logpdf_grad(normal, y, 0.4, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 0 - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(input_grads) == 2 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.4, 1.0)[2] + retval_grad) @test input_grads[2] == nothing # the key has no gradient diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b461..4f5241381 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -15,7 +15,7 @@ @test isapprox(weight, value ? log(0.4) : log(0.6)) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end @testset "generate" begin @@ -27,7 +27,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 # with constraints constraints = choicemap() @@ -39,7 +39,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end function get_trace() @@ -65,7 +65,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isempty(discard) @@ -78,12 +78,12 @@ choices = get_choices(new_trace) @test choices[3] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 # change kernel_args, different key, with constraint constraints = choicemap() @@ -93,12 +93,12 @@ choices = get_choices(new_trace) @test choices[4] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 end @testset "regenerate" begin @@ -110,7 +110,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isapprox(get_score(new_trace), log(0.2)) @@ -122,7 +122,7 @@ choices = get_choices(new_trace) value = choices[3] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) value = choices[4] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -163,9 +163,9 @@ @test choices[3] == y @test isapprox(gradients[3], logpdf_grad(normal, y, 0.0, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 1 - @test length(collect(get_submaps_shallow(gradients))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 0 @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test length(input_grads) == 3 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.0, 1.0)[2]) @test isapprox(input_grads[2], logpdf_grad(normal, y, 0.0, 1.0)[3]) diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 46954e3be..b440a44fa 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -197,9 +197,9 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false @test discard[(3, Val(:aggregation)) => :prefix] == true - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 @test length(collect(get_values_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 1 @test retdiff == UnknownChange() diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index ba748453b..0f3a56180 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -28,7 +28,7 @@ x3 = trace[3 => :x] choices = get_choices(trace) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 expected_score = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x2, x1 * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -55,7 +55,7 @@ @test choices[1 => :x] == x1 @test choices[3 => :x] == x3 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x2 = choices[2 => :x] expected_weight = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -77,7 +77,7 @@ beta = 0.3 (choices, weight, retval) = propose(foo, (3, x_init, alpha, beta)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x1 = choices[1 => :x] x2 = choices[2 => :x] x3 = choices[3 => :x] diff --git a/test/optional_args.jl b/test/optional_args.jl index fd6c4ea71..b0fb821bd 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -1,4 +1,4 @@ -using Gen +#using Gen @testset "optional positional args (calling + GFI)" begin diff --git a/test/runtests.jl b/test/runtests.jl index a67a5f782..749236037 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,4 +74,4 @@ include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 91c6c3202..9e2cecf3b 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -1,4 +1,6 @@ using Gen: generate_generative_function +using Test +using Gen @testset "static IR" begin @@ -362,12 +364,12 @@ end @test get_value(value_trie, :out) == out @test get_value(value_trie, :bar => :z) == z @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(value_trie))) == 1 + @test length(collect(get_values_shallow(value_trie))) == 2 # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(gradient_trie))) == 1 + @test length(collect(get_values_shallow(gradient_trie))) == 2 @test !has_value(gradient_trie, :b) # was not selected @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index fbd528b76..8396fe517 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,4 +1,4 @@ -using Gen +using .Gen import MacroTools normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) From 83349c7d4a320e028c9b24e26da4c3b44066fce9 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:13:03 -0400 Subject: [PATCH 05/17] performance improvements and benchmarking --- src/choice_map/choice_map.jl | 8 ++--- src/choice_map/static_choice_map.jl | 6 ++-- src/static_ir/trace.jl | 15 +++------ test/static_choicemap_benchmark.jl | 50 +++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 18 deletions(-) create mode 100644 test/static_choicemap_benchmark.jl diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl index 402cefa37..213bc5f80 100644 --- a/src/choice_map/choice_map.jl +++ b/src/choice_map/choice_map.jl @@ -71,8 +71,8 @@ A syntactic sugar is `Base.getindex`: value = choices[addr] """ function get_value end -get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) -get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) @inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) """ @@ -145,8 +145,8 @@ end @inline get_value(choices::ValueChoiceMap) = choices.val @inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() @inline get_submaps_shallow(choices::ValueChoiceMap) = () -Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val -Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +@inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) @inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() """ diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 1f75b3bca..58ef57d37 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -26,10 +26,10 @@ end quote EmptyChoiceMap() end end end -static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() +@inline static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() -static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) -static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) +@inline static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +@inline static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 5ac3ced16..168ccf50e 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -23,7 +23,7 @@ abstract type StaticIRTrace <: Trace end @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false - Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) +@inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) @inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) @@ -31,6 +31,8 @@ abstract type StaticIRTrace <: Trace end return Gen.static_get_subtrace(trace, Val(first))[rest] end +@inline get_choices(trace::T) where {T <: StaticIRTrace} = StaticIRTraceAssmt{T}(trace) + const arg_prefix = gensym("arg") const choice_value_prefix = gensym("choice_value") const choice_score_prefix = gensym("choice_score") @@ -133,14 +135,6 @@ function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:block, :(trace.$return_value_fieldname))) end -function generate_get_choices(trace_struct_name::Symbol) - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_choices)), :(trace::$trace_struct_name)), - Expr(:if, :(!isempty(trace)), - :($(QuoteNode(StaticIRTraceAssmt))(trace)), - :($(QuoteNode(EmptyChoiceMap))()))) -end - function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.choice_nodes @@ -236,7 +230,6 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) get_retval_expr = generate_get_retval(ir, trace_struct_name) - get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) @@ -244,7 +237,7 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) + get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/test/static_choicemap_benchmark.jl b/test/static_choicemap_benchmark.jl new file mode 100644 index 000000000..1e62b9a8e --- /dev/null +++ b/test/static_choicemap_benchmark.jl @@ -0,0 +1,50 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +scm = StaticChoiceMap(a=1, b=StaticChoiceMap(c=2)) + +println("static choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(scm) +end + +println("static choicemap nested lookup:") +for _=1:4 + @time many_nested(scm) +end + +@gen (static) function inner() + c ~ normal(0, 1) +end +@gen (static) function outer() + a ~ normal(0, 1) + b ~ inner() +end + +load_generated_functions() + +tr, _ = generate(outer, ()) +choices = get_choices(tr) + +println("static gen function choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(choices) +end + +println("static gen function choicemap nested lookup:") +for _=1:4 + @time many_nested(choices) +end From b9b5312e990fc49b08611b7077b7c6f3aa5d99ee Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:21:45 -0400 Subject: [PATCH 06/17] benchmark for dynamic choicemap lookups --- test/dynamic_choicemap_benchmark.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 test/dynamic_choicemap_benchmark.jl diff --git a/test/dynamic_choicemap_benchmark.jl b/test/dynamic_choicemap_benchmark.jl new file mode 100644 index 000000000..3724e44de --- /dev/null +++ b/test/dynamic_choicemap_benchmark.jl @@ -0,0 +1,27 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +cm = choicemap((:a, 1), (:b => :c, 2)) + +println("dynamic choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(cm) +end + +println("dynamic choicemap nested lookup:") +for _=1:4 + @time many_nested(cm) +end \ No newline at end of file From bce5e7724db64175bf2fd0f15fe25a4dc68af13e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Mon, 18 May 2020 23:30:40 -0400 Subject: [PATCH 07/17] inline dynamicchoicemap methods --- src/choice_map/dynamic_choice_map.jl | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/choice_map/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl index a3403307c..0f27c89d7 100644 --- a/src/choice_map/dynamic_choice_map.jl +++ b/src/choice_map/dynamic_choice_map.jl @@ -67,16 +67,10 @@ function choicemap(tuples...) DynamicChoiceMap(tuples...) end -get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.submaps, addr) - choices.submaps[addr] - else - EmptyChoiceMap() - end -end -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) -Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) +@inline get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +@inline get_submap(choices::DynamicChoiceMap, addr) = get(choices.submaps, addr, EmptyChoiceMap()) +@inline get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) # mutation (not part of the assignment interface) From a985f9bd3dc3f8806e2da1e7c81fbe891334bac9 Mon Sep 17 00:00:00 2001 From: georgematheos Date: Tue, 19 May 2020 09:13:32 -0400 Subject: [PATCH 08/17] remove old version benchmark file --- test/benchmark.md | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 test/benchmark.md diff --git a/test/benchmark.md b/test/benchmark.md deleted file mode 100644 index adabb8a58..000000000 --- a/test/benchmark.md +++ /dev/null @@ -1,21 +0,0 @@ -NEW version: -static choicemap nonnested lookup: - 0.728112 seconds (149.59 k allocations: 4.259 MiB) - 0.785652 seconds (100.00 k allocations: 1.526 MiB) - 0.693433 seconds (100.00 k allocations: 1.526 MiB) - 0.660211 seconds (100.00 k allocations: 1.526 MiB) -static choicemap nested lookup: - 0.680497 seconds (49.59 k allocations: 2.732 MiB) - 0.665768 seconds (1 allocation: 32 bytes) - 0.666708 seconds (1 allocation: 32 bytes) - 0.671009 seconds (1 allocation: 32 bytes) -static gen function choicemap nonnested lookup: - 0.701754 seconds (62.76 k allocations: 3.415 MiB) - 0.662916 seconds - 0.659019 seconds - 0.663398 seconds -static gen function choicemap nested lookup: - 1.338034 seconds (172.13 k allocations: 5.352 MiB) - 1.311123 seconds (100.00 k allocations: 1.526 MiB) - 1.311800 seconds (100.00 k allocations: 1.526 MiB) - 1.310289 seconds (100.00 k allocations: 1.526 MiB) \ No newline at end of file From 1f5029cfc1637d4d3ac257cd46835312131c6ee2 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 09:15:55 -0400 Subject: [PATCH 09/17] minor testing cleanup --- test/optional_args.jl | 2 +- test/static_inference_benchmark.jl | 23 +++++++++++++++++++++++ test/static_ir/static_ir.jl | 2 -- test/tilde_sugar.jl | 2 +- 4 files changed, 25 insertions(+), 4 deletions(-) create mode 100644 test/static_inference_benchmark.jl diff --git a/test/optional_args.jl b/test/optional_args.jl index b0fb821bd..fd6c4ea71 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -1,4 +1,4 @@ -#using Gen +using Gen @testset "optional positional args (calling + GFI)" begin diff --git a/test/static_inference_benchmark.jl b/test/static_inference_benchmark.jl new file mode 100644 index 000000000..b70d08be2 --- /dev/null +++ b/test/static_inference_benchmark.jl @@ -0,0 +1,23 @@ +using Gen + +@gen (static, diffs) function foo() + a ~ normal(0, 1) + b ~ normal(a, 1) + c ~ normal(b, 1) +end + +@load_generated_functions + +observations = StaticChoiceMap(choicemap((:b,2), (:c,1.5))) +tr, _ = generate(foo, (), observations) + +function run_inference(trace) + tr = trace + for _=1:10^3 + tr, acc = mh(tr, select(:a)) + end +end + +for _=1:4 + @time run_inference(tr) +end \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 9e2cecf3b..1b594d39d 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -1,6 +1,4 @@ using Gen: generate_generative_function -using Test -using Gen @testset "static IR" begin diff --git a/test/tilde_sugar.jl b/test/tilde_sugar.jl index 8396fe517..fbd528b76 100644 --- a/test/tilde_sugar.jl +++ b/test/tilde_sugar.jl @@ -1,4 +1,4 @@ -using .Gen +using Gen import MacroTools normalize(ex) = MacroTools.prewalk(MacroTools.rmlines, ex) From eb6adf7a76c5975fa20d7567a560175588aafed4 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:33:16 -0400 Subject: [PATCH 10/17] ensure valuechoicemap[] syntax works --- test/assignment.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/assignment.jl b/test/assignment.jl index 1d7e48a80..69890297f 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -5,6 +5,8 @@ @test vcm1 isa ValueChoiceMap{Int} @test vcm2 isa ValueChoiceMap{Float64} @test vcm3 isa ValueChoiceMap{Vector{Int}} + @test vcm1[] == 2 + @test vcm1[] == get_value(vcm1) @test !isempty(vcm1) @test has_value(vcm1) From eef941776857c50d8ad93ead2ee0d164d60f737e Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:43:49 -0400 Subject: [PATCH 11/17] provide some examples in the documentation --- docs/src/ref/choice_maps.md | 38 ++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 6c445df6f..bf742f904 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -14,6 +14,42 @@ for sub-choicemaps. Leaf nodes have type: ValueChoiceMap ``` +### Example Usage Overview + +Choicemaps store values nested in a tree where each node posesses an address for each subtree. +A leaf-node choicemap simply contains a value, and has it's value looked up via: +```julia +value = choicemap[] +``` +If a choicemap has a value choicemap at address `:a`, it is looked up via: +```julia +value = choicemap[:a] +``` +And a choicemap may also have a non-value choicemap stored at a value. For instance, +if a choicemap has another choicemap stored at address `:a`, and this internal choicemap +has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: +```julia +value1 = choicemap[:a => :b] +value2 = choicemap[:a => :c] +``` +Nesting can be arbitrarily deep, and the keys can be arbitrary values; for instance +choicemaps can be constructed with values at the following nested addresses: +```julia +value = choicemap[:a => :b => :c => 4 => 1.63 => :e] +value = choicemap[:a => :b => :a => 2 => "alphabet" => :e] +``` +To get a sub-choicemap, use `get_submap`: +```julia +value1 = choicemap[:a => :b] +submap = get_submap(choicemap, :a) +value1 == submap[:b] # is true + +value_submap = get_submap(choicemap, :a => :b) +value_submap[] == value1 # is true +``` + +### Interface + Choice maps provide the following methods: ```@docs get_submap @@ -58,7 +94,7 @@ set_value! set_submap! ``` -## Implementing custom choicemap types +### Implementing custom choicemap types To implement a custom choicemap, one must implement `get_submap` and `get_submaps_shallow`. From a83adfbc2d02bed4e9c0a78163151c742cc660f8 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:50:25 -0400 Subject: [PATCH 12/17] fix some typos --- docs/src/ref/choice_maps.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index bf742f904..2963d304a 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -17,15 +17,15 @@ ValueChoiceMap ### Example Usage Overview Choicemaps store values nested in a tree where each node posesses an address for each subtree. -A leaf-node choicemap simply contains a value, and has it's value looked up via: +A leaf-node choicemap simply contains a value, and has its value looked up via: ```julia value = choicemap[] ``` -If a choicemap has a value choicemap at address `:a`, it is looked up via: +If a choicemap has a value choicemap at address `:a`, the value it stores is looked up via: ```julia value = choicemap[:a] ``` -And a choicemap may also have a non-value choicemap stored at a value. For instance, +A choicemap may also have a non-value choicemap stored at an address. For instance, if a choicemap has another choicemap stored at address `:a`, and this internal choicemap has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: ```julia From 1bd705f101bb7c783aedad30fe442f864bcec625 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Tue, 19 May 2020 12:54:25 -0400 Subject: [PATCH 13/17] add phrase 'nesting level zero' to docs --- docs/src/ref/choice_maps.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index 2963d304a..4a23b7cfa 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -47,6 +47,8 @@ value1 == submap[:b] # is true value_submap = get_submap(choicemap, :a => :b) value_submap[] == value1 # is true ``` +One can think of `ValueChoiceMap`s at storing being a choicemap which has a value at "nesting level zero", +while other choicemaps have values at "nesting level" one or higher. ### Interface From 972d4555907813ec7fd77a2b202fdfdacf4d5f79 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 09:29:06 -0400 Subject: [PATCH 14/17] default static_get_submap = EmptyChoiceMap --- src/static_ir/trace.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 168ccf50e..a79ed539b 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -20,6 +20,7 @@ function get_schema end abstract type StaticIRTrace <: Trace end @inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() @inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false From fd1991ff3df029224ddc464642abb0ec15c5ead3 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 10:53:28 -0400 Subject: [PATCH 15/17] minor performance improvement --- src/choice_map/static_choice_map.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 58ef57d37..1aa40d4f3 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -36,7 +36,7 @@ end function StaticChoiceMap(other::ChoiceMap) keys_and_nodes = collect(get_submaps_shallow(other)) if length(keys_and_nodes) > 0 - (addrs::NTuple{n, Symbol} where {n}, submaps) = collect(zip(keys_and_nodes...)) + (addrs::NTuple{n, Symbol} where {n}, submaps) = zip(keys_and_nodes...) else addrs = () submaps = () From c3d5db029e57d7bcb381a113dca1fa3659983296 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Thu, 18 Jun 2020 12:35:10 -0400 Subject: [PATCH 16/17] performance improvement related to zip bug --- src/choice_map/static_choice_map.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl index 1aa40d4f3..ff8c01a7e 100644 --- a/src/choice_map/static_choice_map.jl +++ b/src/choice_map/static_choice_map.jl @@ -34,9 +34,10 @@ end # convert a nonvalue choicemap all of whose top-level-addresses # are symbols into a staticchoicemap at the top level function StaticChoiceMap(other::ChoiceMap) - keys_and_nodes = collect(get_submaps_shallow(other)) + keys_and_nodes = get_submaps_shallow(other) if length(keys_and_nodes) > 0 - (addrs::NTuple{n, Symbol} where {n}, submaps) = zip(keys_and_nodes...) + addrs = Tuple(key for (key, _) in keys_and_nodes) + submaps = Tuple(submap for (_, submap) in keys_and_nodes) else addrs = () submaps = () From c34c60b0246cb1fc51ee5fabf5b3d1444debd25a Mon Sep 17 00:00:00 2001 From: George Matheos Date: Sat, 20 Jun 2020 11:41:08 -0400 Subject: [PATCH 17/17] bug fix --- src/static_ir/trace.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index a79ed539b..3c7016a1f 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -21,7 +21,7 @@ abstract type StaticIRTrace <: Trace end @inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") @inline static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() -@inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_submap(trace, v)) +@inline static_get_value(assmt::StaticIRTraceAssmt, v::Val) = get_value(static_get_submap(assmt, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false @inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key))