From 39807e874a619aa648d2991f8b8ccf05bd0372c8 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 31 Aug 2023 11:04:02 -0600 Subject: [PATCH] feat: extract trace, splice --- examples/introduction.clj | 2 +- src/gen/dynamic.cljc | 102 +++++++++++----------- src/gen/dynamic/trace.cljc | 143 +++++++++++++++++-------------- test/gen/dynamic/trace_test.cljc | 10 +-- 4 files changed, 130 insertions(+), 127 deletions(-) diff --git a/examples/introduction.clj b/examples/introduction.clj index ae3e164..5374942 100644 --- a/examples/introduction.clj +++ b/examples/introduction.clj @@ -140,7 +140,7 @@ ;; then these values are sufficient to answer any question using executions of ;; the function, because all states in the execution of the function are ;; deterministic given the random choices. We will call the record of all the -;; random choies a **trace**. In order to store all the random choices in the +;; random choices a **trace**. In order to store all the random choices in the ;; trace, we need to come up with a unique name or **address** for each random ;; choice. diff --git a/src/gen/dynamic.cljc b/src/gen/dynamic.cljc index c0735b2..42c5f38 100644 --- a/src/gen/dynamic.cljc +++ b/src/gen/dynamic.cljc @@ -9,55 +9,48 @@ #?(:cljs (:require-macros [gen.dynamic]))) +(defrecord GenerateMap [constraints trace weight] + dynamic.trace/ITrace + (-splice [state gf args] + (let [{subtrace :trace + weight :weight} + (gf/generate gf args constraints)] + [(-> state + (update :trace dynamic.trace/merge-subtraces subtrace) + (update :weight + weight)) + (trace/retval subtrace)])) + + (-trace [state k gf args] + (dynamic.trace/validate-empty! trace k) + (let [{subtrace :trace :as ret} + (if-let [k-constraints (get (choice-map/submaps constraints) k)] + (gf/generate gf args k-constraints) + (gf/generate gf args))] + [(dynamic.trace/combine state k ret) + (trace/retval subtrace)]))) + (defrecord DynamicDSLFunction [clojure-fn] gf/Simulate (simulate [gf args] - (let [trace (atom (dynamic.trace/trace gf args))] - (binding [dynamic.trace/*splice* - (fn [gf args] - (let [subtrace (gf/simulate gf args)] - (swap! trace dynamic.trace/merge-subtraces subtrace) - (trace/retval subtrace))) - - dynamic.trace/*trace* - (fn [k gf args] - (dynamic.trace/validate-empty! @trace k) - (let [subtrace (gf/simulate gf args)] - (swap! trace dynamic.trace/assoc-subtrace k subtrace) - (trace/retval subtrace)))] - (let [retval (apply clojure-fn args)] - (swap! trace dynamic.trace/with-retval retval) - @trace)))) + (let [!trace (atom (dynamic.trace/trace gf args)) + retval (binding [dynamic.trace/*active* !trace] + (apply clojure-fn args)) + trace @!trace] + (dynamic.trace/with-retval trace retval))) gf/Generate (generate [gf args] (let [trace (gf/simulate gf args)] {:trace trace :weight (math/log 1)})) (generate [gf args constraints] - (let [state (atom {:trace (dynamic.trace/trace gf args) - :weight 0})] - (binding [dynamic.trace/*splice* - (fn [gf args] - (let [{subtrace :trace - weight :weight} - (gf/generate gf args constraints)] - (swap! state update :trace dynamic.trace/merge-subtraces subtrace) - (swap! state update :weight + weight) - (trace/retval subtrace))) - - dynamic.trace/*trace* - (fn [k gf args] - (dynamic.trace/validate-empty! (:trace @state) k) - (let [{subtrace :trace :as ret} - (if-let [k-constraints (get (choice-map/submaps constraints) k)] - (gf/generate gf args k-constraints) - (gf/generate gf args))] - (swap! state dynamic.trace/combine k ret) - (trace/retval subtrace)))] - (let [retval (apply clojure-fn args) - trace (:trace @state)] - {:trace (dynamic.trace/with-retval trace retval) - :weight (:weight @state)})))) + (let [!state (atom (->GenerateMap + constraints + (dynamic.trace/trace gf args) + 0)) + retval (binding [dynamic.trace/*active* !state] + (apply clojure-fn args)) + state @!state] + (update state :trace dynamic.trace/with-retval retval))) #?@(:clj [clojure.lang.IFn @@ -154,19 +147,20 @@ `(->DynamicDSLFunction (fn ~@(when name [name]) ~params - ~@(walk/postwalk (fn [form] - (cond (trace-form? form) - (if-not (valid-trace-form? form) - (throw (ex-info "Malformed trace expression." {:form form})) - (let [[addr [gf & args]] (rest form)] - `((dynamic.trace/active-trace) ~addr ~gf ~(vec args)))) + ~@(walk/postwalk + (fn [form] + (cond (trace-form? form) + (if-not (valid-trace-form? form) + (throw (ex-info "Malformed trace expression." {:form form})) + (let [[addr [gf & args]] (rest form)] + `(dynamic.trace/trace! ~addr ~gf ~(vec args)))) - (splice-form? form) - (if-not (valid-splice-form? form) - (throw (ex-info "Malformed splice expression." {:form form})) - (let [[[gf & args]] (rest form)] - `((dynamic.trace/active-splice) ~gf ~(vec args)))) + (splice-form? form) + (if-not (valid-splice-form? form) + (throw (ex-info "Malformed splice expression." {:form form})) + (let [[[gf & args]] (rest form)] + `(dynamic.trace/splice! ~gf ~(vec args)))) - :else - form)) - body))))) + :else + form)) + body))))) diff --git a/src/gen/dynamic/trace.cljc b/src/gen/dynamic/trace.cljc index 35de1cb..47ff064 100644 --- a/src/gen/dynamic/trace.cljc +++ b/src/gen/dynamic/trace.cljc @@ -12,46 +12,39 @@ (:import (clojure.lang Associative IFn IObj IMapIterable Seqable)))) -(defn no-op - ([gf args] - (apply gf args)) - ([_k gf args] - (apply gf args))) - -(def ^:dynamic *trace* - "Applies the generative function gf to args. Dynamically rebound by functions - like `gf/simulate`, `gf/generate`, `trace/update`, etc." - no-op) - -(def ^:dynamic *splice* - "Applies the generative function gf to args. Dynamically rebound by functions - like `gf/simulate`, `gf/generate`, `trace/update`, etc." - no-op) - -(defn active-trace - "Returns the currently-active tracing function, bound to [[*trace*]]. - - NOTE: Prefer `([[active-trace]])` to `[[*trace*]]`, as direct access to - `[[*trace*]]` won't reflect new bindings when accessed inside of an SCI - environment." - [] *trace*) - -(defn active-splice - "Returns the currently-active tracing function, bound to [[*splice*]]. - - NOTE: Prefer `([[active-splice]])` to `[[*splice*]]`, as direct access to - `[[*splice*]]` won't reflect new bindings when accessed inside of an SCI - environment." - [] - *splice*) +(defprotocol ITrace + (-splice [this gf args]) + (-trace [this addr gf args])) + +(defrecord NoOp [] + ITrace + (-splice [this gf args] + [this (apply gf args)]) + (-trace [this _k gf args] + [this (apply gf args)])) + +(def no-op (NoOp.)) + +(def ^:dynamic *active* (atom no-op)) + +(defn active [] *active*) + +(defn splice! [gf args] + (let [[new-state ret] (-splice @*active* gf args)] + (swap! *active* (fn [_] new-state)) + ret)) + +(defn trace! [k gf args] + (let [[new-state ret] (-trace @*active* k gf args)] + (swap! *active* (fn [_] new-state)) + ret)) (defmacro without-tracing [& body] - `(binding [*trace* no-op - *splice* no-op] + `(binding [*active* (atom no-op)] ~@body)) -(declare assoc-subtrace update-trace trace =) +(declare assoc-subtrace merge-subtraces update-trace validate-empty! trace =) (deftype Trace [gf args subtraces retval] trace/Args @@ -79,6 +72,18 @@ (update [this constraints] (update-trace this constraints)) + ITrace + (-splice [this gf args] + (let [subtrace (gf/simulate gf args)] + [(merge-subtraces this subtrace) + (trace/retval subtrace)])) + + (-trace [this k gf args] + (validate-empty! this k) + (let [subtrace (gf/simulate gf args)] + [(assoc-subtrace this k subtrace) + (trace/retval subtrace)])) + #?@(:cljs [Object (equiv [this other] (-equiv this other)) @@ -193,9 +198,9 @@ [^Trace t addr subt] (validate-empty! t addr) (->Trace (.-gf t) - (.-args t) - (assoc (.-subtraces t) addr subt) - (.-retval t))) + (.-args t) + (assoc (.-subtraces t) addr subt) + (.-retval t))) (defn merge-subtraces [^Trace t1 ^Trace t2] @@ -211,34 +216,42 @@ (update :weight + weight) (cond-> discard (update :discard assoc k discard)))) +;; TODO: this does NOT feel like the right data structure. In fact I think +;; updates should be able to shuffle over the unused stuff from update to +;; update, instead of having to do that final update at the very end. +;; +;; Then each update step could shuffling from the constraints over to the end. +(defrecord UpdateMap [this constraints trace weight discard] + ITrace + (-splice [_ _ _] + (throw (ex-info "Not yet implemented." {}))) + + (-trace [state k gf args] + (validate-empty! trace k) + (let [k-constraints (get (choice-map/submaps constraints) k) + {subtrace :trace :as ret} + (if-let [prev-subtrace (get (.-subtraces ^Trace this) k)] + (trace/update prev-subtrace k-constraints) + (gf/generate gf args k-constraints))] + [(combine state k ret) + (trace/retval subtrace)]))) + (defn update-trace [this constraints] - (let [gf (trace/gf this) - state (atom {:trace (trace gf (trace/args this)) - :weight 0 - :discard (cm/choice-map)})] - (binding [*splice* - (fn [& _] - (throw (ex-info "Not yet implemented." {}))) - - *trace* - (fn [k gf args] - (validate-empty! (:trace @state) k) - (let [k-constraints (get (choice-map/submaps constraints) k) - {subtrace :trace :as ret} - (if-let [prev-subtrace (get (.-subtraces this) k)] - (trace/update prev-subtrace k-constraints) - (gf/generate gf args k-constraints))] - (swap! state combine k ret) - (trace/retval subtrace)))] - (let [retval (apply (:clojure-fn gf) (trace/args this)) - {:keys [trace weight discard]} @state - unvisited (apply dissoc - (trace/choices this) - (keys (trace/choices trace)))] - - {:trace (with-retval trace retval) - :weight weight - :discard (merge discard unvisited)})))) + (let [gf (trace/gf this) + !state (atom (->UpdateMap + this constraints + (trace gf (trace/args this)) + 0 + (cm/choice-map))) + retval (binding [*active* !state] + (apply (:clojure-fn gf) (trace/args this))) + {:keys [trace weight discard]} @!state + unvisited (apply dissoc + (trace/choices this) + (keys (trace/choices trace)))] + {:trace (with-retval trace retval) + :weight weight + :discard (merge discard unvisited)})) ;; ## Primitive Trace ;; diff --git a/test/gen/dynamic/trace_test.cljc b/test/gen/dynamic/trace_test.cljc index 3e129ee..d559449 100644 --- a/test/gen/dynamic/trace_test.cljc +++ b/test/gen/dynamic/trace_test.cljc @@ -10,13 +10,9 @@ (deftest binding-tests (letfn [(f [_] "hi!")] - (binding [dynamic.trace/*trace* f - dynamic.trace/*splice* f] - (is (= f (dynamic.trace/active-trace)) - "active-trace reflects dynamic bindings") - - (is (= f (dynamic.trace/active-splice)) - "active-splice reflects dynamic bindings")))) + (binding [dynamic.trace/*active* f] + (is (= f (dynamic.trace/active)) + "active reflects dynamic bindings")))) (defn choice-trace [x]