From c34ba2df6625a48e378ec52e8496a72a4919ec6b Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Tue, 22 Oct 2024 18:21:54 -0400 Subject: [PATCH 01/86] Reorg head of file New plan is as follows. * Load map (&clutters) exclusively, and visualize it. * Visualize robot poses, and ideal sensors. * Introduce robot programs, and ideal (physical) motion. * Carry on with the rest of the tutorial, minus the above items. The code is now ready for independent extension in both single-pose and goal-inference branches. --- probcomp-localization-tutorial.jl | 354 ++++++++++++++------------ robot_program.json | 87 +++++++ example_20_program.json => world.json | 89 +------ 3 files changed, 279 insertions(+), 251 deletions(-) create mode 100644 robot_program.json rename example_20_program.json => world.json (94%) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 60d60fa..d040347 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -38,30 +38,17 @@ using Gen mkpath("imgs"); # %% [markdown] -# ## The "real world" +# ## Ojbects of reasoning # -# We assume given -# * a map of a space, together with -# * some clutters that sometimes unexpectedly exist in that space. +# ### The map # -# We also assume given a description of a robot's behavior via -# * an estimated initial pose (= position + heading), and -# * a program of controls (= advance distance, followed by rotate heading). -# -# *In addition to the uncertainty in the initial pose, we are uncertain about the true execution of the motion of the robot.* +# The tutorial will revolve around modeling the activity of a robot within some space. A large simplifying assumption, which could be lifted with more effort, is that we have been given a *map* of the space, to which the robot will have access. # -# Below, we will also introduce sensors. - -# %% [markdown] -# ### Load map and robot data -# -# Generally speaking, we keep general code and specific examples in separate cells, as signposted here. +# The code below loads such a map, along with other data for later use. Generally speaking, we keep general code and specific examples in separate cells, as signposted here. # %% # General code here -norm(v) = sqrt(sum(v.^2)) - struct Segment p1 :: Vector{Float64} p2 :: Vector{Float64} @@ -70,6 +57,84 @@ struct Segment end Base.show(io :: IO, s :: Segment) = Base.show(io, "Segment($(s.p1), $(s.p2))") +function create_segments(verts; loop_around=false) + segs = [Segment(p1, p2) for (p1, p2) in zip(verts[1:end-1], verts[2:end])] + if loop_around; push!(segs, Segment(verts[end], verts[1])) end + return segs +end + +function make_world(walls_vec, clutters_vec, start, controls; args...) + walls = create_segments(walls_vec; args...) + clutters = [create_segments(clutter; args...) for clutter in clutters_vec] + all_points = [walls_vec ; clutters_vec...] + x_min, x_max = extrema(first, all_points) + y_min, y_max = extrema(last, all_points) + bounding_box = (x_min, x_max, y_min, y_max) + box_size = max(x_max - x_min, y_max - y_min) + center_point = [(x_min + x_max) / 2.0, (y_min + y_max) / 2.0] + return (; walls, clutters, bounding_box, box_size, center_point) +end + +function load_world(file_name; args...) + data = parsefile(file_name) + walls_vec = Vector{Vector{Float64}}(data["wall_verts"]) + clutters_vec = Vector{Vector{Vector{Float64}}}(data["clutter_vert_groups"]) + return make_world(walls_vec, clutters_vec; args...) 
+end; + +# %% +# Specific example code here + +world = load_world("world.json"); + +# %% [markdown] +# ### Plotting +# +# It is crucial to picture what we are doing at all times, so we develop plotting code early and often. + +# %% +function plot_list!(list; label=nothing, args...) + if !isempty(list) + plt = plot!(list[1]; label=label, args...) + for item in list[2:end]; plot!(item; label=nothing, args...) end + return plt + end +end + +Plots.plot!(seg :: Segment; args...) = plot!([seg.p1[1], seg.p2[1]], [seg.p1[2], seg.p2[2]]; args...) +Plots.plot!(segs :: Vector{Segment}; args...) = plot_list!(segs; args...) +Plots.plot!(seg_groups :: Vector{Vector{Segment}}; args...) = plot_list!(seg_groups; args...) + +function plot_world(world, title; show=()) + border = world.box_size * (3.)/19. + the_plot = plot( + size = (500, 500), + aspect_ratio = :equal, + grid = false, + xlim = (world.bounding_box[1]-border, world.bounding_box[2]+border), + ylim = (world.bounding_box[3]-border, world.bounding_box[4]+border), + title = title, + legend = :bottomleft) + (walls_label, clutter_label) = :label in show ? ("walls", "clutters") : (nothing, nothing) + plot!(world.walls; c=:black, label=walls_label) + if :clutters in show; plot!(world.clutters; c=:magenta, label=clutter_label) end + return the_plot +end; + +# %% [markdown] +# Following this initial display of the given data, we *suppress the clutters* until much later in the notebook. + +# %% +plot_world(world, "Given data", show=(:label, :clutters=true)) + +# %% [markdown] +# ### Robot poses +# +# We will model the robot's physical state as a *pose* (or mathematically speaking a ray), defined to be a *position* (2D point relative to the map) plus a *heading* (angle from -$\pi$ to $\pi$). +# +# These will be visualized using arrows whose tip is at the position, and whose direction indicates the heading. + +# %% struct Pose p :: Vector{Float64} hd :: Float64 @@ -82,50 +147,126 @@ Base.show(io :: IO, p :: Pose) = Base.show(io, "Pose($(p.p), $(p.hd))") step_along_pose(p, s) = p.p + s * p.dp rotate_pose(p, a) = Pose(p.p, p.hd + a) -# A value `c :: Control` corresponds to the robot *first* advancing in its present direction by `c.ds`, *then* rotating by `c.dhd`. -struct Control - ds :: Float64 - dhd :: Float64 +Plots.plot!(p :: Pose; r=0.5, args...) = plot!(Segment(p.p, step_along_pose(p, r)); arrow=true, args...) +Plots.plot!(ps :: Vector{Pose}; args...) = plot_list!(ps; args...); + +# %% +the_plot = plot_world(world, "Given data") +plot!(Pose([1., 1.], 0.); color=:green3, label="a pose") +plot!(Pose([2., 3.], pi/2.); color=:green4, label="another pose") +the_plot + +# %% [markdown] +# ### Ideal sensors +# +# The robot will need to reason about its location on the map, on the basis of LIDAR-like sensor data. + +# %% +# A general algorithm to find the interection of a ray and a line segment. + +norm(v) = sqrt(sum(v.^2)) + +# Return unique s, t such that p + s*u == q + t*v. 
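+# Rearranging p + s*u == q + t*v into the 2x2 linear system  s*u - t*v = q - p  and applying Cramer's rule
+# (with pq = p - q and det = u[1]*v[2] - u[2]*v[1]) gives
+#     s = (v[1]*pq[2] - v[2]*pq[1]) / det,    t = (u[1]*pq[2] - u[2]*pq[1]) / det.
+# A near-zero determinant means the ray and segment are (nearly) parallel, in which case no unique solution exists.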
+function solve_lines(p, u, q, v; PARALLEL_TOL=1.0e-10) + det = u[1] * v[2] - u[2] * v[1] + if abs(det) < PARALLEL_TOL + return nothing, nothing + else + pq = p - q + s = (v[1] * pq[2] - v[2] * pq[1]) / det + t = (u[1] * pq[2] - u[2] * pq[1]) / det + return s, t + end end -function create_segments(verts; loop_around=false) - segs = [Segment(p1, p2) for (p1, p2) in zip(verts[1:end-1], verts[2:end])] - if loop_around; push!(segs, Segment(verts[end], verts[1])) end - return segs +function distance(p :: Pose, seg :: Segment) + s, t = solve_lines(p.p, p.dp, seg.p1, seg.dp) + # Solving failed (including, by fiat, if pose is parallel to segment) iff isnothing(s). + # Pose is oriented away from segment iff s < 0. + # Point of intersection lies on segment (as opposed to the infinite line) iff 0 <= t <= 1. + return (isnothing(s) || s < 0. || !(0. <= t <= 1.)) ? Inf : s +end; + +# %% [markdown] +# An ideal sensor reports the exact distance cast to a wall. (It is capped off at a max value in case of error.) + +# %% +function sensor_distance(pose, walls, box_size) + d = minimum(distance(pose, seg) for seg in walls) + # Capping to a finite value avoids issues below. + return isinf(d) ? 2. * box_size : d +end; + +sensor_angle(sensor_settings, j) = + sensor_settings.fov * (j - (sensor_settings.num_angles - 1) / 2.) / (sensor_settings.num_angles - 1) + +function ideal_sensor(pose, walls, sensor_settings) + readings = Vector{Float64}(undef, sensor_settings.num_angles) + for j in 1:sensor_settings.num_angles + sensor_pose = rotate_pose(pose, sensor_angle(sensor_settings, j)) + readings[j] = sensor_distance(sensor_pose, walls, sensor_settings.box_size) + end + return readings +end; + +# %% +# Plot sensor data. + +function plot_sensors!(pose, color, readings, label, sensor_settings) + plot!([pose.p[1]], [pose.p[2]]; color=color, label=nothing, seriestype=:scatter, markersize=3, markerstrokewidth=0) + projections = [step_along_pose(rotate_pose(pose, sensor_angle(sensor_settings, j)), s) for (j, s) in enumerate(readings)] + plot!(first.(projections), last.(projections); + color=:blue, label=label, seriestype=:scatter, markersize=3, markerstrokewidth=1, alpha=0.25) + plot!([Segment(pose.p, pr) for pr in projections]; color=:blue, label=nothing, alpha=0.25) end -function make_world(walls_vec, clutters_vec, start, controls; loop_around=false) - walls = create_segments(walls_vec; loop_around=loop_around) - clutters = [create_segments(clutter; loop_around=loop_around) for clutter in clutters_vec] - walls_clutters = [walls ; clutters...] - all_points = [walls_vec ; clutters_vec... 
; [start.p]] - x_min = minimum(p[1] for p in all_points) - x_max = maximum(p[1] for p in all_points) - y_min = minimum(p[2] for p in all_points) - y_max = maximum(p[2] for p in all_points) - bounding_box = (x_min, x_max, y_min, y_max) - box_size = max(x_max - x_min, y_max - y_min) - center_point = [(x_min + x_max) / 2.0, (y_min + y_max) / 2.0] - T = length(controls) - return (walls=walls, clutters=clutters, walls_clutters=walls_clutters, - bounding_box=bounding_box, box_size=box_size, center_point=center_point), - (start=start, controls=controls), - T +function frame_from_sensors(world, title, poses, poses_color, poses_label, pose, readings, readings_label, sensor_settings; show_clutters=false) + the_plot = plot_world(world, title; show_clutters=show_clutters) + plot!(poses; color=poses_color, label=poses_label) + plot_sensors!(pose, poses_color, readings, readings_label, sensor_settings) + return the_plot +end; + +# %% +sensor_settings = (fov = 2π*(2/3), num_angles = 41, box_size = world.box_size) + +ani = Animation() +for pose in path_integrated + frame_plot = frame_from_sensors( + world, "Ideal sensor distances", + path_integrated, :green2, "some path", + pose, ideal_sensor(pose, world.walls, sensor_settings), "ideal sensors", + sensor_settings) + frame(ani, frame_plot) +end +gif(ani, "imgs/ideal_distances.gif", fps=1) + +# %% [markdown] +# ### Robot programs +# +# We also assume given a description of a robot's movement via +# * an estimated initial pose (= position + heading), and +# * a program of controls (= advance distance, followed by rotate heading). + +# %% +# A value `c :: Control` corresponds to the robot *first* advancing in its present direction by `c.ds`, *then* rotating by `c.dhd`. +struct Control + ds :: Float64 + dhd :: Float64 end -function load_world(file_name; loop_around=false) +function load_program(file_name) data = parsefile(file_name) - walls_vec = Vector{Vector{Float64}}(data["wall_verts"]) - clutters_vec = Vector{Vector{Vector{Float64}}}(data["clutter_vert_groups"]) start = Pose(Vector{Float64}(data["start_pose"]["p"]), Float64(data["start_pose"]["hd"])) controls = Vector{Control}([Control(control["ds"], control["dhd"]) for control in data["program_controls"]]) - return make_world(walls_vec, clutters_vec, start, controls; loop_around=loop_around) + return (; start, controls), length(controls) end; # %% -# Specific example code here +robot_inputs, T = load_program("robot_program.json") -world, robot_inputs, T = load_world("example_20_program.json"); +# %% [markdown] +# Before we can visualize such a program, we will need to model robot motion. # %% [markdown] # ### Integrate a path from a starting pose and controls @@ -150,26 +291,6 @@ end; # We employ the following simple physics: when the robot's forward step through a control comes into contact with a wall, that step is interrupted and the robot instead "bounces" a fixed distance from the point of contact in the normal direction to the wall. # %% -# Return unique s, t such that p + s*u == q + t*v. -function solve_lines(p, u, q, v; PARALLEL_TOL=1.0e-10) - det = u[1] * v[2] - u[2] * v[1] - if abs(det) < PARALLEL_TOL - return nothing, nothing - else - s = (v[1] * (p[2]-q[2]) - v[2] * (p[1]-q[1])) / det - t = (u[2] * (q[1]-p[1]) - u[1] * (q[2]-p[2])) / det - return s, t - end -end - -function distance(p, seg) - s, t = solve_lines(p.p, p.dp, seg.p1, seg.dp) - # Solving failed (including, by fiat, if pose is parallel to segment) iff isnothing(s). - # Pose is oriented away from segment iff s < 0. 
- # Point of intersection lies on segment (as opposed to the infinite line) iff 0 <= t <= 1. - return (isnothing(s) || s < 0. || !(0. <= t <= 1.)) ? Inf : s -end - function physical_step(p1, p2, hd, world_inputs) step_pose = Pose(p1, p2 - p1) (s, i) = findmin(w -> distance(step_pose, w), world_inputs.walls) @@ -205,43 +326,6 @@ world_inputs = (walls = world.walls, bounce = 0.1) path_integrated = integrate_controls(robot_inputs, world_inputs); -# %% [markdown] -# ### Plot such data - -# %% -function plot_list!(list; label=nothing, args...) - if isempty(list); return end - plt = plot!(list[1]; label=label, args...) - for item in list[2:end]; plot!(item; label=nothing, args...) end - return plt -end - -Plots.plot!(seg :: Segment; args...) = plot!([seg.p1[1], seg.p2[1]], [seg.p1[2], seg.p2[2]]; args...) -Plots.plot!(segs :: Vector{Segment}; args...) = plot_list!(segs; args...) -Plots.plot!(seg_groups :: Vector{Vector{Segment}}; args...) = plot_list!(seg_groups; args...) - -Plots.plot!(p :: Pose; r=0.5, args...) = plot!(Segment(p.p, step_along_pose(p, r)); arrow=true, args...) -Plots.plot!(ps :: Vector{Pose}; args...) = plot_list!(ps; args...) - -function plot_world(world, title; label_world=false, show_clutters=false) - border = world.box_size * (3.)/19. - the_plot = plot( - size = (500, 500), - aspect_ratio = :equal, - grid = false, - xlim = (world.bounding_box[1]-border, world.bounding_box[2]+border), - ylim = (world.bounding_box[3]-border, world.bounding_box[4]+border), - title = title, - legend = :bottomleft) - (walls_label, clutter_label) = label_world ? ("walls", "clutters") : (nothing, nothing) - plot!(world.walls; c=:black, label=walls_label) - if show_clutters; plot!(world.clutters; c=:magenta, label=clutter_label) end - return the_plot -end; - -# %% [markdown] -# Following this initial display of the given data, we *suppress the clutters* until much later in the notebook. - # %% the_plot = plot_world(world, "Given data", label_world=true, show_clutters=true) plot!(robot_inputs.start; color=:green3, label="given start pose") @@ -610,64 +694,6 @@ function integrate_controls_noisy(robot_inputs, world_inputs, motion_settings) return get_path(simulate(path_model, (length(robot_inputs.controls), robot_inputs, world_inputs, motion_settings))) end; -# %% [markdown] -# ### Ideal sensors -# -# We now, additionally, assume the robot is equipped with sensors that cast rays upon the environment at certain angles relative to the given pose, and return the distance to a hit. -# -# We first describe the ideal case, where the sensors return the true distances to the walls. - -# %% -function sensor_distance(pose, walls, box_size) - d = minimum(distance(pose, seg) for seg in walls) - # Capping to a finite value avoids issues below. - return isinf(d) ? 2. * box_size : d -end; - -sensor_angle(sensor_settings, j) = - sensor_settings.fov * (j - (sensor_settings.num_angles - 1) / 2.) / (sensor_settings.num_angles - 1) - -function ideal_sensor(pose, walls, sensor_settings) - readings = Vector{Float64}(undef, sensor_settings.num_angles) - for j in 1:sensor_settings.num_angles - sensor_pose = rotate_pose(pose, sensor_angle(sensor_settings, j)) - readings[j] = sensor_distance(sensor_pose, walls, sensor_settings.box_size) - end - return readings -end; - -# %% -# Plot sensor data. 
- -function plot_sensors!(pose, color, readings, label, sensor_settings) - plot!([pose.p[1]], [pose.p[2]]; color=color, label=nothing, seriestype=:scatter, markersize=3, markerstrokewidth=0) - projections = [step_along_pose(rotate_pose(pose, sensor_angle(sensor_settings, j)), s) for (j, s) in enumerate(readings)] - plot!(first.(projections), last.(projections); - color=:blue, label=label, seriestype=:scatter, markersize=3, markerstrokewidth=1, alpha=0.25) - plot!([Segment(pose.p, pr) for pr in projections]; color=:blue, label=nothing, alpha=0.25) -end - -function frame_from_sensors(world, title, poses, poses_color, poses_label, pose, readings, readings_label, sensor_settings; show_clutters=false) - the_plot = plot_world(world, title; show_clutters=show_clutters) - plot!(poses; color=poses_color, label=poses_label) - plot_sensors!(pose, poses_color, readings, readings_label, sensor_settings) - return the_plot -end; - -# %% -sensor_settings = (fov = 2π*(2/3), num_angles = 41, box_size = world.box_size) - -ani = Animation() -for pose in path_integrated - frame_plot = frame_from_sensors( - world, "Ideal sensor distances", - path_integrated, :green2, "some path", - pose, ideal_sensor(pose, world.walls, sensor_settings), "ideal sensors", - sensor_settings) - frame(ani, frame_plot) -end -gif(ani, "imgs/ideal_distances.gif", fps=1) - # %% [markdown] # ### Noisy sensors # @@ -2653,7 +2679,7 @@ the_plot # At the beginning of the notebook, we illustrated the data of "clutters", or extra boxes left inside the environment. These would impact the motion and the sensory *observation data* of a run of the robot, but are not accounted for in the above *model* when attempting to infer its path. How well does the inference process work in the presence of such discrepancies? 
# %% -world_inputs_cluttered = (world_inputs..., walls=world.walls_clutters) +world_inputs_cluttered = (world_inputs..., walls=[world.walls ; world.clutters...]) trace_cluttered = simulate(full_model, (T, robot_inputs, world_inputs_cluttered, full_settings_low_dev)) path_cluttered = get_path(trace_cluttered) observations_cluttered = get_sensors(trace_cluttered) diff --git a/robot_program.json b/robot_program.json new file mode 100644 index 0000000..9064dab --- /dev/null +++ b/robot_program.json @@ -0,0 +1,87 @@ +{ + "start_pose": { + "hd": 0.08090409915523009, + "p": [ + 1.8437380952380948, + 16.669857142857147 + ] + }, + "program_controls": [ + { + "dhd": -0.6276929400444784, + "ds": 1.66692863383373 + }, + { + "dhd": 0.05284747196966921, + "ds": 1.2090977174113897 + }, + { + "dhd": 0.49394136891957907, + "ds": 1.3260217281714377 + }, + { + "dhd": -0.6578886051822093, + "ds": 0.08980952380952267 + }, + { + "dhd": -0.7330542218202096, + "ds": 1.2484815966013108 + }, + { + "dhd": -0.17985349979247767, + "ds": 1.0041010013249054 + }, + { + "dhd": 0.2110933332227467, + "ds": 0.718476190476192 + }, + { + "dhd": 0.14189705460416402, + "ds": 1.2858766916828193 + }, + { + "dhd": 0.8230148192682236, + "ds": 0.9092521284586457 + }, + { + "dhd": 0.09190625132479169, + "ds": 0.5837619047619033 + }, + { + "dhd": 0.2671857556956463, + "ds": 0.7527409762489785 + }, + { + "dhd": -0.002743477341864045, + "ds": 1.2581349485459954 + }, + { + "dhd": 0.03844259002118827, + "ds": 1.1683870435119994 + }, + { + "dhd": 0.1973955598498823, + "ds": 1.1226190476190485 + }, + { + "dhd": 1.1760052070951348, + "ds": 0.915881028822475 + }, + { + "dhd": -0.026081041290752838, + "ds": 0.6869107716168559 + }, + { + "dhd": -0.3967728848421944, + "ds": 1.0130976061748402 + }, + { + "dhd": 0.05352026845932667, + "ds": 0.7725707358877328 + }, + { + "dhd": 0.0, + "ds": 0.5854864636291519 + } + ] +} diff --git a/example_20_program.json b/world.json similarity index 94% rename from example_20_program.json rename to world.json index aa71745..45f14d6 100644 --- a/example_20_program.json +++ b/world.json @@ -1,12 +1,5 @@ { - "start_pose": { - "hd": 0.08090409915523009, - "p": [ - 1.8437380952380948, - 16.669857142857147 - ] - }, - "wall_verts": [ + "wall_verts": [ [ 13.24, 0.1 @@ -2201,83 +2194,5 @@ 9.863761904761908 ] ] - ], - "program_controls": [ - { - "dhd": -0.6276929400444784, - "ds": 1.66692863383373 - }, - { - "dhd": 0.05284747196966921, - "ds": 1.2090977174113897 - }, - { - "dhd": 0.49394136891957907, - "ds": 1.3260217281714377 - }, - { - "dhd": -0.6578886051822093, - "ds": 0.08980952380952267 - }, - { - "dhd": -0.7330542218202096, - "ds": 1.2484815966013108 - }, - { - "dhd": -0.17985349979247767, - "ds": 1.0041010013249054 - }, - { - "dhd": 0.2110933332227467, - "ds": 0.718476190476192 - }, - { - "dhd": 0.14189705460416402, - "ds": 1.2858766916828193 - }, - { - "dhd": 0.8230148192682236, - "ds": 0.9092521284586457 - }, - { - "dhd": 0.09190625132479169, - "ds": 0.5837619047619033 - }, - { - "dhd": 0.2671857556956463, - "ds": 0.7527409762489785 - }, - { - "dhd": -0.002743477341864045, - "ds": 1.2581349485459954 - }, - { - "dhd": 0.03844259002118827, - "ds": 1.1683870435119994 - }, - { - "dhd": 0.1973955598498823, - "ds": 1.1226190476190485 - }, - { - "dhd": 1.1760052070951348, - "ds": 0.915881028822475 - }, - { - "dhd": -0.026081041290752838, - "ds": 0.6869107716168559 - }, - { - "dhd": -0.3967728848421944, - "ds": 1.0130976061748402 - }, - { - "dhd": 0.05352026845932667, - "ds": 0.7725707358877328 - }, - { - "dhd": 0.0, - 
"ds": 0.5854864636291519 - } ] -} +} \ No newline at end of file From c8fc3414ae2e1f71271c79018dbeb086874cad7a Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 23 Oct 2024 15:59:32 -0400 Subject: [PATCH 02/86] Fix typo --- probcomp-localization-tutorial.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index d040347..db67265 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -38,7 +38,7 @@ using Gen mkpath("imgs"); # %% [markdown] -# ## Ojbects of reasoning +# ## Sensing a robot's location on a map # # ### The map # From d71404fd8b9034ccf0a7ecab366c68a0d08109e5 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 23 Oct 2024 16:00:15 -0400 Subject: [PATCH 03/86] Add viz plans --- probcomp-localization-tutorial.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index db67265..0ff5de2 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -127,6 +127,9 @@ end; # %% plot_world(world, "Given data", show=(:label, :clutters=true)) +# %% [markdown] +# POSSIBLE VIZ GOAL: user-editable map, clutters, etc. + # %% [markdown] # ### Robot poses # @@ -156,6 +159,9 @@ plot!(Pose([1., 1.], 0.); color=:green3, label="a pose") plot!(Pose([2., 3.], pi/2.); color=:green4, label="another pose") the_plot +# %% [markdown] +# POSSIBLE VIZ GOAL: user can manipulate a pose. (Unconstrained vs. map for now.) + # %% [markdown] # ### Ideal sensors # @@ -241,6 +247,9 @@ for pose in path_integrated end gif(ani, "imgs/ideal_distances.gif", fps=1) +# %% [markdown] +# POSSIBLE VIZ GOAL: as user manipulates pose, sensors get updated. + # %% [markdown] # ### Robot programs # @@ -268,6 +277,9 @@ robot_inputs, T = load_program("robot_program.json") # %% [markdown] # Before we can visualize such a program, we will need to model robot motion. +# %% [markdown] +# POSSIBLE VIZ GOAL: user can manipulate a pose, and independently a control (vecor-like relative to it), with new pose in shadow. + # %% [markdown] # ### Integrate a path from a starting pose and controls # @@ -285,6 +297,9 @@ function integrate_controls_unphysical(robot_inputs) return path end; +# %% [markdown] +# POSSIBLE VIZ GOAL: user can manipulate a whole path, still ignoring walls. + # %% [markdown] # This code has the problem that it is **unphysical**: the walls in no way constrain the robot motion. # @@ -347,6 +362,7 @@ the_plot # Each piece of the model is declared as a *generative function* (GF). The `Gen` library provides two DSLs for constructing GFs: the dynamic DSL using the decorator `@gen` on a function declaration, and the static DSL similarly decorated with `@gen (static)`. The dynamic DSL allows a rather wide class of program structures, whereas the static DSL only allows those for which a certain static analysis may be performed. # # The library offers two basic constructs for use within these DSLs: primitive *distributions* such as "Bernoulli" and "normal", and the sampling operator `~`. Recursively, GFs may sample from other GFs using `~`. +# POSSIBLE VIZ GOAL: user can manipulate a whole path, now obeying walls. 
# %% [markdown] # ### Components of the motion model From 370ca48263f8dd89dc1049165826c8748346a30b Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 23 Oct 2024 16:01:40 -0400 Subject: [PATCH 04/86] Reorg noisy sensors --- probcomp-localization-tutorial.jl | 150 +++++++++++++++++++----------- 1 file changed, 94 insertions(+), 56 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 0ff5de2..45e8c89 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -250,6 +250,93 @@ gif(ani, "imgs/ideal_distances.gif", fps=1) # %% [markdown] # POSSIBLE VIZ GOAL: as user manipulates pose, sensors get updated. +# %% [markdown] +# ## First steps in modeling uncertainty using Gen +# +# The robot will need to reason about its possible location on the map using incomplete information—in a pun, it must nagivate the uncertainty. The `Gen` system facilitates programming the required probabilistic logic. We will introduce the features of Gen, starting with some simple features now, and bringing in more complex ones later. +# +# Each piece of the model is declared as a *generative function* (GF). The `Gen` library provides two DSLs for constructing GFs: the dynamic DSL using the decorator `@gen` on a function declaration, and the static DSL similarly decorated with `@gen (static)`. The dynamic DSL allows a rather wide class of program structures, whereas the static DSL only allows those for which a certain static analysis may be performed. +# +# The library offers primitive *distributions* such as "Bernoulli" and "normal", and these two DLSs offer the *sampling operator* `~`. GFs may sample from distributions and, recursively, other GFs using `~`. A generative function embodies the *joint distribution* over the latent choices indicated by the sampling operations. + +# %% [markdown] +# ### Creating noisy measurements using `Gen.propose` +# +# We have on hand two kinds of things to model: the robot's pose (and possibly its motion), and its sensor data. We tackle the sensor model first because it is simpler. +# +# Here is its declarative model in `Gen`: + +# %% +@gen function sensor_model(pose, walls, sensor_settings) + for j in 1:sensor_settings.num_angles + sensor_pose = rotate_pose(pose, sensor_angle(sensor_settings, j)) + {j => :distance} ~ normal(sensor_distance(sensor_pose, walls, sensor_settings.box_size), sensor_settings.s_noise) + end +end; + +# %% [markdown] +# This model differs from `ideal_sensor` in the following ways. The ideal sensor measurements themselves are no longer stored into an array, but are instead used as the means of Gaussian distributions (representing our uncertainty about them). *Sampling* from these distributions, using the `~` operator, occurs at the addresses `j => :distance`. +# +# Moreover, the function returns no explicit value. But there is no loss of information here: the model can be run with `Gen.propose` semantics, which performs the required draws from the sampling operations and organizes them according to their address, returning the corresponding *choice map* data structure. The method is called with the GF plus a tuple of arguments. + +# %% +sensor_settings = (sensor_settings..., s_noise = 0.10) +cm, w = propose(sensor_model, (Pose([1., 1.], pi/2.), walls, sensor_settings)) + +# For brevity, show just a subset of the choice map's addresses. 
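+# (Here `select` bundles the given addresses into a selection, and `get_selected` restricts the choice map to them.)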
+get_selected(cm, select((1:5)...)) +# To instead see the whole trace, uncomment below: +# cm + +# %% [markdown] +# With a little wrapping, one gets a function of the same type as `ideal_sensor`. + +# %% +function noisy_sensor(pose, walls, sensor_settings) + cm, _ = propose(sensor_model, (pose, walls, sensor_settings)) + return [cm[j => :distance] for j in 1:sensor_settings.num_angles] +end; + +# %% [markdown] +# Let's get a picture of the distances returned by the model: + +# %% +ani = Animation() +for pose in path_integrated + frame_plot = frame_from_sensors( + world, "Sensor model (samples)", + path_integrated, :green2, "some path", + pose, noisy_sensor(pose, world.walls, sensor_settings), "noisy sensors", + sensor_settings) + frame(ani, frame_plot) +end +gif(ani, "imgs/noisy_distances.gif", fps=1) + +# %% [markdown] +# POSSIBLE VIZ GOAL: same sensor interactive as before, now with noisy sensors. + +# %% [markdown] +# ### Weighing data with `Gen.assess` +# +# The mathematical picture is as follows. Given the parameters of a pose $z$, walls $w$, and settings $\nu$, one gets a distribution $\text{sensor}(z, w, \nu)$ over certain choice maps. The supporting choice maps are identified with vectors $o = o^{(1:J)} = (o^{(1)}, o^{(2)}, \ldots, o^{(J)})$, where $J := \nu_\text{num\_angles}$, each $o^{(j)}$ independently following a certain normal distribution (depending, notably, on a distance to a wall). Thus the density of $o$ factors into a product of the form +# $$ +# P_\text{sensor}(o) = \prod\nolimits_{j=1}^J P_\text{normal}(o^{(j)}) +# $$ +# where we begin a habit of omitting the parameters to distributions that are implied by the code. +# +# As `propose` draws a sample, it simultaneously computes this density or *score* and returns its logarithm: + +# %% +exp(w) + +# %% [markdown] +# There are many scenarios where one has on hand a full set of data, perhaps via observation, and seeks their score according to the model. One could write a program by hand to do this—but one would simply recapitulate the code for `noisy_sensor`. The difference is that the sampling operations would be replaced with density computations, and instead of storing them in a choice map it would compute their log product. +# +# The construction of a log density function is automated by the `Gen.assess` semantics for generative functions. This method is passed the GF, a tuple of arguments, and a choice map. + +# %% +exp(assess(sensor_model, (Pose([1., 1.], pi/2.), walls, sensor_settings), cm)) + # %% [markdown] # ### Robot programs # @@ -710,62 +797,6 @@ function integrate_controls_noisy(robot_inputs, world_inputs, motion_settings) return get_path(simulate(path_model, (length(robot_inputs.controls), robot_inputs, world_inputs, motion_settings))) end; -# %% [markdown] -# ### Noisy sensors -# -# We assume that the sensor readings are themselves uncertain, say, the distances only knowable up to some noise. We model this as follows. (We satisfy ourselves with writing a loop in the dynamic DSL because we will have no need for incremental recomputation within this model.) 
- -# %% -@gen function sensor_model(pose, walls, sensor_settings) - for j in 1:sensor_settings.num_angles - sensor_pose = rotate_pose(pose, sensor_angle(sensor_settings, j)) - {j => :distance} ~ normal(sensor_distance(sensor_pose, walls, sensor_settings.box_size), sensor_settings.s_noise) - end -end - -function noisy_sensor(pose, walls, sensor_settings) - trace = simulate(sensor_model, (pose, walls, sensor_settings)) - return [trace[j => :distance] for j in 1:sensor_settings.num_angles] -end; - -# %% [markdown] -# The trace contains many choices corresponding to directions of sensor reading from the input pose. To reduce notebook clutter, here we just show a subset of 5 of them: - -# %% -sensor_settings = (sensor_settings..., s_noise = 0.10) - -trace = simulate(sensor_model, (robot_inputs.start, world.walls, sensor_settings)) -get_selected(get_choices(trace), select((1:5)...)) - -# %% [markdown] -# The mathematical picture is as follows. Given the parameters of a pose $y$, walls $w$, and settings $\nu$, one gets a distribution $\text{sensor}(y, w, \nu)$ over the traces of `sensor_model`, and when $z$ is a motion model trace we set $\text{sensor}(z, w, \nu) := \text{sensor}(\text{retval}(z), w, \nu)$. Its samples are identified with vectors $o = (o^{(1)}, o^{(2)}, \ldots, o^{(J)})$, where $J := \nu_\text{num\_angles}$, each $o^{(j)}$ independently following a certain normal distribution (depending, notably, on the distance from the pose to the nearest wall). Thus the density of $o$ factors into a product of the form -# $$ -# P_\text{sensor}(o) = \prod\nolimits_{j=1}^J P_\text{normal}(o^{(j)}) -# $$ -# where we begin a habit of omitting the parameters to distributions that are implied by the code. -# -# Visualizing the traces of the model is probably more useful for orientation, so we do this now. - -# %% -function frame_from_sensors_trace(world, title, poses, poses_color, poses_label, pose, trace; show_clutters=false) - readings = [trace[j => :distance] for j in 1:sensor_settings.num_angles] - return frame_from_sensors(world, title, poses, poses_color, poses_label, pose, - readings, "trace sensors", get_args(trace)[3]; - show_clutters=show_clutters) -end; - -# %% -ani = Animation() -for pose in path_integrated - trace = simulate(sensor_model, (pose, world.walls, sensor_settings)) - frame_plot = frame_from_sensors_trace( - world, "Sensor model (samples)", - path_integrated, :green2, "some path", - pose, trace) - frame(ani, frame_plot) -end -gif(ani, "imgs/sensor_1.gif", fps=1) - # %% [markdown] # ### Full model # @@ -818,6 +849,13 @@ get_selected(get_choices(trace), selection) # By this point, visualization is essential. 
# %% +function frame_from_sensors_trace(world, title, poses, poses_color, poses_label, pose, trace; show_clutters=false) + readings = [trace[j => :distance] for j in 1:sensor_settings.num_angles] + return frame_from_sensors(world, title, poses, poses_color, poses_label, pose, + readings, "trace sensors", get_args(trace)[3]; + show_clutters=show_clutters) +end + function frames_from_full_trace(world, title, trace; show_clutters=false) T = get_args(trace)[1] robot_inputs = get_args(trace)[2] From a59259ac884a1c292276279aa927d8d78ceded00 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 23 Oct 2024 16:43:24 -0400 Subject: [PATCH 05/86] Add likelihood optimization and commentary --- probcomp-localization-tutorial.jl | 231 +++++++++++++++--------------- 1 file changed, 119 insertions(+), 112 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 45e8c89..0301ab6 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -63,7 +63,7 @@ function create_segments(verts; loop_around=false) return segs end -function make_world(walls_vec, clutters_vec, start, controls; args...) +function make_world(walls_vec, clutters_vec; args...) walls = create_segments(walls_vec; args...) clutters = [create_segments(clutter; args...) for clutter in clutters_vec] all_points = [walls_vec ; clutters_vec...] @@ -125,7 +125,7 @@ end; # Following this initial display of the given data, we *suppress the clutters* until much later in the notebook. # %% -plot_world(world, "Given data", show=(:label, :clutters=true)) +plot_world(world, "Given data", show=(:label, :clutters)) # %% [markdown] # POSSIBLE VIZ GOAL: user-editable map, clutters, etc. @@ -154,10 +154,9 @@ Plots.plot!(p :: Pose; r=0.5, args...) = plot!(Segment(p.p, step_along_pose(p, r Plots.plot!(ps :: Vector{Pose}; args...) = plot_list!(ps; args...); # %% -the_plot = plot_world(world, "Given data") +plot_world(world, "Given data") plot!(Pose([1., 1.], 0.); color=:green3, label="a pose") plot!(Pose([2., 3.], pi/2.); color=:green4, label="another pose") -the_plot # %% [markdown] # POSSIBLE VIZ GOAL: user can manipulate a pose. (Unconstrained vs. map for now.) @@ -166,6 +165,8 @@ the_plot # ### Ideal sensors # # The robot will need to reason about its location on the map, on the basis of LIDAR-like sensor data. +# +# An "ideal" sensor reports the exact distance cast to a wall. (It is capped off at a max value in case of error.) # %% # A general algorithm to find the interection of a ray and a line segment. @@ -193,9 +194,6 @@ function distance(p :: Pose, seg :: Segment) return (isnothing(s) || s < 0. || !(0. <= t <= 1.)) ? Inf : s end; -# %% [markdown] -# An ideal sensor reports the exact distance cast to a wall. (It is capped off at a max value in case of error.) - # %% function sensor_distance(pose, walls, box_size) d = minimum(distance(pose, seg) for seg in walls) @@ -216,8 +214,6 @@ function ideal_sensor(pose, walls, sensor_settings) end; # %% -# Plot sensor data. 
- function plot_sensors!(pose, color, readings, label, sensor_settings) plot!([pose.p[1]], [pose.p[2]]; color=color, label=nothing, seriestype=:scatter, markersize=3, markerstrokewidth=0) projections = [step_along_pose(rotate_pose(pose, sensor_angle(sensor_settings, j)), s) for (j, s) in enumerate(readings)] @@ -226,8 +222,8 @@ function plot_sensors!(pose, color, readings, label, sensor_settings) plot!([Segment(pose.p, pr) for pr in projections]; color=:blue, label=nothing, alpha=0.25) end -function frame_from_sensors(world, title, poses, poses_color, poses_label, pose, readings, readings_label, sensor_settings; show_clutters=false) - the_plot = plot_world(world, title; show_clutters=show_clutters) +function frame_from_sensors(world, title, poses, poses_color, poses_label, pose, readings, readings_label, sensor_settings; show=()) + the_plot = plot_world(world, title; show=show) plot!(poses; color=poses_color, label=poses_label) plot_sensors!(pose, poses_color, readings, readings_label, sensor_settings) return the_plot @@ -236,11 +232,16 @@ end; # %% sensor_settings = (fov = 2π*(2/3), num_angles = 41, box_size = world.box_size) +some_poses = [Pose([uniform(world.bounding_box[1], world.bounding_box[2]), + uniform(world.bounding_box[3], world.bounding_box[4])], + uniform(-pi,pi)) + for _ in 1:20] + ani = Animation() -for pose in path_integrated +for pose in some_poses frame_plot = frame_from_sensors( world, "Ideal sensor distances", - path_integrated, :green2, "some path", + pose, :green2, "robot pose", pose, ideal_sensor(pose, world.walls, sensor_settings), "ideal sensors", sensor_settings) frame(ani, frame_plot) @@ -281,12 +282,12 @@ end; # %% sensor_settings = (sensor_settings..., s_noise = 0.10) -cm, w = propose(sensor_model, (Pose([1., 1.], pi/2.), walls, sensor_settings)) +cm, w = propose(sensor_model, (Pose([1., 1.], pi/2.), world.walls, sensor_settings)) +cm +# %% # For brevity, show just a subset of the choice map's addresses. get_selected(cm, select((1:5)...)) -# To instead see the whole trace, uncomment below: -# cm # %% [markdown] # With a little wrapping, one gets a function of the same type as `ideal_sensor`. @@ -302,10 +303,10 @@ end; # %% ani = Animation() -for pose in path_integrated +for pose in some_poses frame_plot = frame_from_sensors( world, "Sensor model (samples)", - path_integrated, :green2, "some path", + pose, :green2, "robot pose", pose, noisy_sensor(pose, world.walls, sensor_settings), "noisy sensors", sensor_settings) frame(ani, frame_plot) @@ -332,10 +333,71 @@ exp(w) # %% [markdown] # There are many scenarios where one has on hand a full set of data, perhaps via observation, and seeks their score according to the model. One could write a program by hand to do this—but one would simply recapitulate the code for `noisy_sensor`. The difference is that the sampling operations would be replaced with density computations, and instead of storing them in a choice map it would compute their log product. # -# The construction of a log density function is automated by the `Gen.assess` semantics for generative functions. This method is passed the GF, a tuple of arguments, and a choice map. +# The construction of a log density function is automated by the `Gen.assess` semantics for generative functions. This method is passed the GF, a tuple of arguments, and a choice map, and returns the log weight plus the return value. 
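+#
+# (In particular, assessing the very choice map `cm` that `propose` returned above, at the same arguments,
+# recovers the same log weight `w`: both amount to the same product of normal densities.)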
# %% -exp(assess(sensor_model, (Pose([1., 1.], pi/2.), walls, sensor_settings), cm)) +exp(assess(sensor_model, (Pose([1., 1.], pi/2.), world.walls, sensor_settings), cm)[1]) + +# %% [markdown] +# ## First steps in probabilistic reasoning +# +# Let pick some measured noisy distance data +# +# > ***LOAD THEM FROM FILE*** +# +# and try to reason about where the robot could have taken them from. + +# %% [markdown] +# POSSIBLE VIZ GOAL: User can start from the loaded data, or move around to grab some noisy sensors. Then, user can move around a separate candiate-match pose, and `assess` the data against it with updated result somehow. + +# %% [markdown] +# What we are exploring here is in Bayesian parlance the *likelihood* of the varying pose. One gets a sense that certain poses were somehow more likely than others, and the modeling of this intuitive sense is called *inference*. +# +# The above exploration points at a strategy of finding the pose (parameter) that optimizes the likelihood (some statistic), a ubiquitous process called *variational inference*. +# +# A subtle but crucial matter must be dealt with. This act of variational inference silently adopts assumptions having highly nontrivial consequences for our inferences, having to do with the issue of *prior* over the parameter. +# +# First we must acknowledge at all, that our reasoning always approaches the question "Where is the robot?" already having some idea of where it is possible for the robot to be. We interpret new information in terms of these assumptions, and they definitely influence the inferences we make. For example, if we were utterly sure the robot were near the center of the map, we would only really consider how the sensor data around such poses, even if there were better fits elsewhere that we did not expect the robot to be. +# +# These assumptions are modeled by a distribution over poses called the *prior*. Then, according to Bayes's Law, the key quantity to examine is not the likelihood density $P_\text{sensor}(o;z)$ but rather the *posterior* density +# $$ +# P_\text{posterior}(z|o) = P_\text{sensor}(o;z) \cdot P_\text{prior}(z) / Z +# $$ +# where $Z > 0$ is a normalizing constant. Likelihood optimization amounts to assuming a prior having $P_\text{prior}(z) \equiv 1$, a so-called "uniform" prior over the parameter space. +# +# The uniform prior may appear to be a natural expression of "complete ignorance", not preferencing any parameter over another. The other thing to acknowledge is that this is not the case: the parameterization of the latents itself embodies preferences among parameter values. Different parametramizations of the latents lead to different "uniform" distributions over them. For example, parameterizing the spread of a univariate normal distribution by its standard deviation and its variance lead to different "uniform" priors over the parameter space, the square map being nonlinear. Thus likelihood optimization's second tacit assumption is a particular parametric representation of the latents space, according to which uniformity occurs. +# +# Summarizing, likelihood optimization does not lead to *intrinsic* inference conclusions, because it relies on a prior that in turn is not intrinsic, but rather depends on how the parameters are presented. Intrinsic conclusions are instead drawn by specifying the prior as a distribution, which has a consistent meaning across parameterizations. +# +# So let us be upfront that we choose the uniform prior relative to the conventional meaning of the pose parameters. 
Here is a view onto the posterior distribution over poses, given a set of sensor measurements. + +# %% [markdown] +# POSSIBLE VIZ GOAL: Gather the preceding viz into one view: alpha blend all candidate-match poses by likelihood, so only plausible things appear, with the mode highlighted. + +# %% [markdown] +# PUT HERE: expanded discussion of single-pose inference problem. +# +# From optimization/VI to sampling techniques. Reasons: +# * Note *how much information we are throwing away* when passing from the distribution to a single statistic. Something must be afoot. +# * Later inferences depend on the *whole distribution* of parameters. +# * Reducing to (Dirac measures on) the modes breaks compositional validity! +# * The modes might not even be *representative* of the posterior: +# * The mode might not even be where any mass actually accumulates, as in a high-dimensional Gaussian! +# * Mass might be distributed among multiple near-tied modes, unnaturally preferencing one of them. +# * The posterior requires clearly specifying a prior, which (as mentioned above) prevents ambiguities of parameterization. +# +# Replace `argmax` with a resampling operation (SIR). Grid vs. free choice. +# +# Compare to a NN approach. + +# %% [markdown] +# ## Modeling robot motion +# +# As said initially, we are uncertain about the true initial position and subsequent motion of the robot. In order to reason about these, we now specify a model using `Gen`. +# +# Each piece of the model is declared as a *generative function* (GF). The `Gen` library provides two DSLs for constructing GFs: the dynamic DSL using the decorator `@gen` on a function declaration, and the static DSL similarly decorated with `@gen (static)`. The dynamic DSL allows a rather wide class of program structures, whereas the static DSL only allows those for which a certain static analysis may be performed. +# +# The library offers two basic constructs for use within these DSLs: primitive *distributions* such as "Bernoulli" and "normal", and the sampling operator `~`. Recursively, GFs may sample from other GFs using `~`. # %% [markdown] # ### Robot programs @@ -359,7 +421,7 @@ function load_program(file_name) end; # %% -robot_inputs, T = load_program("robot_program.json") +robot_inputs, T = load_program("robot_program.json"); # %% [markdown] # Before we can visualize such a program, we will need to model robot motion. @@ -429,12 +491,10 @@ world_inputs = (walls = world.walls, bounce = 0.1) path_integrated = integrate_controls(robot_inputs, world_inputs); # %% -the_plot = plot_world(world, "Given data", label_world=true, show_clutters=true) +plot_world(world, "Given data", show=(:label,)) plot!(robot_inputs.start; color=:green3, label="given start pose") plot!([pose.p[1] for pose in path_integrated], [pose.p[2] for pose in path_integrated]; color=:green2, label="path from integrating controls", seriestype=:scatter, markersize=3, markerstrokewidth=0) -savefig("imgs/given_data") -the_plot # %% [markdown] # We can also visualize the behavior of the model of physical motion: @@ -442,13 +502,6 @@ the_plot # ![](imgs_stable/physical_motion.gif) # %% [markdown] -# ## Gen basics -# -# As said initially, we are uncertain about the true initial position and subsequent motion of the robot. In order to reason about these, we now specify a model using `Gen`. -# -# Each piece of the model is declared as a *generative function* (GF). 
The `Gen` library provides two DSLs for constructing GFs: the dynamic DSL using the decorator `@gen` on a function declaration, and the static DSL similarly decorated with `@gen (static)`. The dynamic DSL allows a rather wide class of program structures, whereas the static DSL only allows those for which a certain static analysis may be performed. -# -# The library offers two basic constructs for use within these DSLs: primitive *distributions* such as "Bernoulli" and "normal", and the sampling operator `~`. Recursively, GFs may sample from other GFs using `~`. # POSSIBLE VIZ GOAL: user can manipulate a whole path, now obeying walls. # %% [markdown] @@ -485,25 +538,21 @@ pose_samples = [start_pose_prior(robot_inputs.start, motion_settings) for _ in 1 std_devs_radius = 2.5 * motion_settings.p_noise -the_plot = plot_world(world, "Start pose prior (samples)") +plot_world(world, "Start pose prior (samples)") plot!(make_circle(robot_inputs.start.p, std_devs_radius); color=:red, linecolor=:red, label="95% region", seriestype=:shape, alpha=0.25) plot!(pose_samples; color=:red, label="start pose samples") -savefig("imgs/start_prior") -the_plot # %% N_samples = 50 noiseless_step = robot_inputs.start.p + robot_inputs.controls[1].ds * robot_inputs.start.dp step_samples = [step_model(robot_inputs.start, robot_inputs.controls[1], world_inputs, motion_settings) for _ in 1:N_samples] -the_plot = plot_world(world, "Motion step model model (samples)") +plot_world(world, "Motion step model model (samples)") plot!(robot_inputs.start; color=:black, label="step from here") plot!(make_circle(noiseless_step, std_devs_radius); color=:red, linecolor=:red, label="95% region", seriestype=:shape, alpha=0.25) plot!(step_samples; color=:red, label="step samples") -savefig("imgs/motion_step") -the_plot # %% [markdown] # ### Traces: choice maps @@ -663,7 +712,7 @@ get_selected(get_choices(trace), select((prefix_address(t, :pose) for t in 1:6). # As our truncation of the example trace above might suggest, visualization is an essential practice in ProbComp. We could very well pass the output of the above `integrate_controls_noisy` to the `plot!` function to have a look at it. However, we want to get started early in this notebook on a good habit: writing interpretive code for GFs in terms of their traces rather than their return values. This enables the programmer include the parameters of the model in the display for clarity. 
# %% -function frames_from_motion_trace(world, title, trace; show_clutters=false) +function frames_from_motion_trace(world, title, trace; show=()) T = get_args(trace)[1] robot_inputs = get_args(trace)[2] poses = get_path(trace) @@ -672,7 +721,7 @@ function frames_from_motion_trace(world, title, trace; show_clutters=false) std_devs_radius = 2.5 * motion_settings.p_noise plots = Vector{Plots.Plot}(undef, T+1) for t in 1:(T+1) - frame_plot = plot_world(world, title; show_clutters=show_clutters) + frame_plot = plot_world(world, title; show=show) plot!(poses[1:t-1]; color=:black, label="past poses") plot!(make_circle(noiseless_steps[t], std_devs_radius); color=:red, linecolor=:red, label="95% region", seriestype=:shape, alpha=0.25) @@ -710,11 +759,9 @@ gif(ani, "imgs/motion.gif", fps=2) trace = simulate(start_pose_prior, (robot_inputs.start, motion_settings)) rotated_trace, rotated_trace_weight_diff, _, _ = update(trace, (robot_inputs.start, motion_settings), (NoChange(), NoChange()), choicemap((:hd, π/2.))) -the_plot = plot_world(world, "Modifying a heading") +plot_world(world, "Modifying a heading") plot!(get_retval(trace); color=:green, label="some pose") plot!(get_retval(rotated_trace); color=:red, label="with heading modified") -savefig("imgs/modify_trace_1") -the_plot # %% [markdown] # The original trace was typical under the pose prior model, whereas the modified one is rather less likely. This is the log of how much unlikelier: @@ -731,11 +778,9 @@ rotated_first_step, rotated_first_step_weight_diff, _, _ = update(trace, (T, robot_inputs, world_inputs, motion_settings), (NoChange(), NoChange(), NoChange(), NoChange()), choicemap((:steps => 1 => :pose => :hd, π/2.))) -the_plot = plot_world(world, "Modifying another heading") +plot_world(world, "Modifying another heading") plot!(get_path(trace); color=:green, label="some path") plot!(get_path(rotated_first_step); color=:red, label="with heading at first step modified") -savefig("imgs/modify_trace_1") -the_plot # %% [markdown] # In the above picture, the green path is apparently missing, having been near-completely overdrawn by the red path. This is because in the execution of the model, the only change in the stochastic choices took place where we specified. In particular, the stochastic choice of pose at the second step was left unchanged. 
This choice was typical relative to the first step's heading in the old trace, and while it is not impossible relative to the first step's heading in the new trace, it is *far unlikelier* under the mulitvariate normal distribution supporting it: @@ -856,7 +901,7 @@ function frame_from_sensors_trace(world, title, poses, poses_color, poses_label, show_clutters=show_clutters) end -function frames_from_full_trace(world, title, trace; show_clutters=false) +function frames_from_full_trace(world, title, trace; show=()) T = get_args(trace)[1] robot_inputs = get_args(trace)[2] poses = get_path(trace) @@ -866,7 +911,7 @@ function frames_from_full_trace(world, title, trace; show_clutters=false) sensor_readings = get_sensors(trace) plots = Vector{Plots.Plot}(undef, 2*(T+1)) for t in 1:(T+1) - frame_plot = plot_world(world, title; show_clutters=show_clutters) + frame_plot = plot_world(world, title; show=show) plot!(poses[1:t-1]; color=:black, label="past poses") plot!(make_circle(noiseless_steps[t], std_devs_radius); color=:red, linecolor=:red, label="95% region", seriestype=:shape, alpha=0.25) @@ -876,7 +921,7 @@ function frames_from_full_trace(world, title, trace; show_clutters=false) world, title, poses[1:t], :black, nothing, poses[t], sensor_readings[t], "sampled sensors", - settings.sensor_settings; show_clutters=show_clutters) + settings.sensor_settings; show=show) end return plots end; @@ -1091,9 +1136,7 @@ traces_generated_high_deviation = [generate(full_model, (T, full_model_args...), log_likelihoods_high_deviation = [project(trace, selection) for trace in traces_generated_high_deviation] hist_high_deviation = histogram(log_likelihoods_high_deviation; label=nothing, bins=20, title="high dev data, typical paths") -the_plot = plot(hist_low_deviation, hist_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="Log density of observations under the model") -savefig("imgs/likelihoods") -the_plot +plot(hist_low_deviation, hist_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="Log density of observations under the model") # %% [markdown] # ...than the log densities of data typically produced by the complete model run in its natural manner (*compare the scale at the bottom*): @@ -1101,7 +1144,7 @@ the_plot # %% traces_typical = [simulate(full_model, (T, full_model_args...)) for _ in 1:N_samples] log_likelihoods_typical = [project(trace, selection) for trace in traces_typical] -hist_typical = histogram(log_likelihoods_typical; label=nothing, bins=20, title="Log density of observations under the model\ntypical traces") +histogram(log_likelihoods_typical; label=nothing, bins=20, title="Log density of observations under the model\ntypical traces") # %% [markdown] # ### Inference: demonstration @@ -1119,8 +1162,8 @@ include("black_box.jl") # %% # Visualize distributions over traces. 
-function frame_from_traces(world, title, path, path_label, traces, trace_label; show_clutters=false) - the_plot = plot_world(world, title; show_clutters=show_clutters) +function frame_from_traces(world, title, path, path_label, traces, trace_label; show=()) + the_plot = plot_world(world, title; show=show) if !isnothing(path); plot!(path; label=path_label, color=:brown) end for trace in traces poses = get_path(trace) @@ -1150,9 +1193,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "posterior samples") -the_plot = plot(prior_plot, posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1500,500), layout=grid(1,3), plot_title="Prior vs. approximate posteriors") -savefig("imgs/prior_posterior") -the_plot +plot(prior_plot, posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1500,500), layout=grid(1,3), plot_title="Prior vs. approximate posteriors") # %% [markdown] # All of the traces thus produced have observations constrained to the data. The log densities of the observations under their typical samples show some improvement: @@ -1168,9 +1209,7 @@ traces_posterior_high_deviation = [BlackBox.black_box_inference(full_model, full log_likelihoods_high_deviation = [project(trace, selection) for trace in traces_posterior_high_deviation] hist_high_deviation = histogram(log_likelihoods_high_deviation; label=nothing, bins=20, title="typical data under posterior: high dev data") -the_plot = plot(hist_low_deviation, hist_high_deviation; size=(1500,500), layout=grid(1,2), plot_title="Log likelihood of observations") -savefig("imgs/likelihoods") -the_plot +plot(hist_low_deviation, hist_high_deviation; size=(1500,500), layout=grid(1,2), plot_title="Log likelihood of observations") # %% [markdown] # ## Generic strategies for inference @@ -1403,9 +1442,7 @@ traces = [sampling_importance_resampling(full_model, (T_short, full_model_args.. t2 = now() println("Time elapsed per run (short path): $(value(t2 - t1) / N_samples) ms.") -the_plot = frame_from_traces(world, "SIR (short path)", path_low_deviation[1:(T_short+1)], "path to fit", traces, "SIR samples") -savefig("imgs/SIR_short") -the_plot +frame_from_traces(world, "SIR (short path)", path_low_deviation[1:(T_short+1)], "path to fit", traces, "SIR samples") # %% [markdown] # There are still problems with SIR. SIR already do not provide high-quality traces on short paths. For longer paths, the difficulty only grows, as one blindly searches for a needle in a high-dimensional haystack. And if the proposal $Q$ is unlikely to generate typical samples from the target $P$, one would need a massive number of particles to get a good approximation; in fact, the rate of convergence of SIR towards the target can be super-exponentially slow in $N \to \infty$! 
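 #
 # (Aside: Gen also ships a library routine in this spirit, `importance_resampling(model, args, observations, n_samples)`,
 # which returns a single resampled trace together with a log marginal-likelihood estimate. A hypothetical call here,
 # not used below, would look like `importance_resampling(full_model, (T, full_model_args...), constraints, n_samples)`,
 # where `constraints` stands in for the observation choice map built in the surrounding cells.)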
@@ -1421,9 +1458,7 @@ traces = [sampling_importance_resampling(full_model, (T, full_model_args...), me t2 = now() println("Time elapsed per run (low dev): $(value(t2 - t1) / N_samples) ms.") -the_plot = frame_from_traces(world, "SIR (low dev)", path_low_deviation, "path to fit", traces, "SIR samples") -savefig("imgs/SIR_final") -the_plot +frame_from_traces(world, "SIR (low dev)", path_low_deviation, "path to fit", traces, "SIR samples") # %% [markdown] # ## Sequential Monte Carlo (SMC) techniques @@ -1496,8 +1531,8 @@ end; # Let's begin by picturing the step-by-step nature of SMC: # %% -function frame_from_weighted_traces(world, title, path, path_label, traces, log_weights, trace_label; show_clutters=false, min_alpha=0.03) - the_plot = plot_world(world, title; show_clutters=show_clutters) +function frame_from_weighted_traces(world, title, path, path_label, traces, log_weights, trace_label; show=(), min_alpha=0.03) + the_plot = plot_world(world, title; show=show) if !isnothing(path) plot!(path; label=path_label, color=:brown) @@ -1518,9 +1553,9 @@ function frame_from_weighted_traces(world, title, path, path_label, traces, log_ return the_plot end -function frame_from_info(world, title, path, path_label, info, info_label; show_clutters=false, min_alpha=0.03) +function frame_from_info(world, title, path, path_label, info, info_label; show=(), min_alpha=0.03) the_plot = frame_from_weighted_traces(world, title * "\nt=$(info.t)|" * info.label, path, path_label, - info.traces, info.log_weights, info_label; show_clutters=show_clutters, min_alpha=min_alpha) + info.traces, info.log_weights, info_label; show=show, min_alpha=min_alpha) if haskey(info, :vizs) viz_label = haskey(info.vizs[1].params, :label) ? info.vizs[1].params.label : nothing for viz in info.vizs @@ -1644,9 +1679,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF+Bootstrap") -savefig("imgs/PF_bootstrap") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF+Bootstrap") # %% [markdown] # The results are already more accurate than blind SIR for only a fraction of the work. @@ -1925,9 +1958,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid") -savefig("imgs/PF_SMCP3_grid") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid") # %% [markdown] # The speed of this approach is already perhaps an issue. The performance is even worse (~15x slower) using the "exact" backwards kernel, with no discernible improvement in inference, as can be seen by uncommenting and running the code below. 
@@ -1953,9 +1984,7 @@ the_plot # println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") # posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -# the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid") -# savefig("imgs/PF_SMCP3_grid_2") -# the_plot +# plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid") # %% [markdown] # Because our rejuvenation scheme improves sample quality, perhaps we do not even need to track many particles. Let's try out *one* particle (and vacuous resampling): @@ -1981,9 +2010,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid (1pc)") -savefig("imgs/PF_SMCP3_grid_1") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid (1pc)") # %% [markdown] # Here we see some degradation in the inference quality. But since there is one particle, maybe we can spend a little more effort in the grid search. @@ -2009,9 +2036,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid (1pc)") -savefig("imgs/PF_SMCP3_grid_1_hard") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Grid (1pc)") # %% N_particles = 1 @@ -2116,9 +2141,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Drift") -savefig("imgs/PF_SMCP3_drift") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + SMCP3/Drift") # %% [markdown] # We can compromise between the grid search and jiggling. The idea is to perform a mere two-element search that compares the given point with the random one, or rather to immediately resample one from the pair. This would have a chance of improving sample quality, without spending much time searching scrupulously for the improvement. 
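# %% [markdown]
# A minimal sketch of such a two-element resampling move follows, written against Gen's generic trace interface. It assumes a helper `drift_proposal_choicemap(trace)` (not defined in this notebook) that returns a choicemap perturbing the relevant pose choices with a *symmetric* Gaussian drift, so the forward and backward proposal densities cancel; the rejuvenation kernels actually used below may be organized differently.

# %%
# Hypothetical sketch of the two-element ("Boltzmann") move, not this notebook's own implementation.
function boltzmann_drift_move(trace, drift_proposal_choicemap)
    constraints = drift_proposal_choicemap(trace)
    argdiffs = map(_ -> NoChange(), get_args(trace))
    proposed_trace, log_weight, _, _ = update(trace, get_args(trace), argdiffs, constraints)
    # `log_weight` is the log ratio of the new and old model densities; with a symmetric
    # proposal, resampling between the two traces with weights 1 : exp(log_weight)
    # (the Barker/Boltzmann rule) leaves the target invariant.
    log_weights = [0., log_weight]
    probs = exp.(log_weights .- logsumexp(log_weights))
    return categorical(probs) == 1 ? trace : proposed_trace
end;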
@@ -2185,9 +2208,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + Boltzmann/Drift") -savefig("imgs/PF_boltzmann_drift") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + Boltzmann/Drift") # %% [markdown] # Similarly, here is a detailed run, followed by the aggregate behavior, using the MH rule: @@ -2225,9 +2246,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + MH/Drift") -savefig("imgs/PF_mh_drift") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + MH/Drift") # %% [markdown] # Thus we can recover most of the inference performance of the grid search, at a fraction of the compute cost. @@ -2253,9 +2272,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + MH/Drift (1pc)") -savefig("imgs/PF_mh_drift_1") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="PF + MH/Drift (1pc)") # %% [markdown] # ### Reusable components @@ -2411,9 +2428,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="Controlled PF") -savefig("imgs/PF_controller") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="Controlled PF") # %% [markdown] # ### Backtracking @@ -2627,9 +2642,7 @@ t2 = now() println("Time elapsed per run (high dev): $(value(t2 - t1) / N_samples) ms.") posterior_plot_high_deviation = frame_from_traces(world, "High dev observations", path_high_deviation, "path to be fit", traces, "samples") -the_plot = plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="Backtracking PF") -savefig("imgs/PF_backtrack") -the_plot +plot(posterior_plot_low_deviation, posterior_plot_high_deviation; size=(1000,500), layout=grid(1,2), plot_title="Backtracking PF") # %% [markdown] # ## Improving robustness @@ -2671,9 +2684,7 @@ t1 = now() traces = [particle_filter_rejuv(full_model, T, full_model_args, constraints_askew_start, N_particles, ESS_threshold, drift_mh_kernel, drift_args_schedule) for _ in 1:N_samples] t2 = now() println("Time elapsed per run (askew start): $(value(t2 - t1) / N_samples) ms.") -the_plot = frame_from_traces(world, "Askew start", 
path_askew_start, "path to be fit", traces, "samples") -savefig("imgs/askew_start") -the_plot +frame_from_traces(world, "Askew start", path_askew_start, "path to be fit", traces, "samples") # %% [markdown] # Or how about if we "kidnapped" the robot: partway through the journey, the robot is paused, moved to another room, then resumed? @@ -2723,9 +2734,7 @@ t1 = now() traces = [particle_filter_rejuv(full_model, T, full_model_args, constraints_kidnapped, N_particles, ESS_threshold, drift_mh_kernel, drift_args_schedule) for _ in 1:N_samples] t2 = now() println("Time elapsed per run (backwards start): $(value(t2 - t1) / N_samples) ms.") -the_plot = frame_from_traces(world, "Kidnapped after t=4", path_kidnapped, "path to be fit", traces, "samples") -savefig("imgs/backwards_start") -the_plot +frame_from_traces(world, "Kidnapped after t=4", path_kidnapped, "path to be fit", traces, "samples") # %% [markdown] # For another challenge, what if our map were modestly inaccurate? @@ -2760,9 +2769,7 @@ t1 = now() traces = [particle_filter_rejuv(full_model, T, full_model_args, constraints_cluttered, N_particles, ESS_threshold, drift_mh_kernel, drift_args_schedule) for _ in 1:N_samples] t2 = now() println("Time elapsed per run (backwards start): $(value(t2 - t1) / N_samples) ms.") -the_plot = frame_from_traces(world, "Cluttered space", path_cluttered, "path to be fit", traces, "samples"; show_clutters=true) -savefig("imgs/backwards_start") -the_plot +frame_from_traces(world, "Cluttered space", path_cluttered, "path to be fit", traces, "samples"; show_clutters=true) # %% [markdown] # We take up the task of accommodating a wider range of phenomena in our modeling and inference. From d058b5b74e744cf4f311f0715bf12f36c67f88b8 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 25 Oct 2024 16:03:07 -0400 Subject: [PATCH 06/86] Cleanup --- probcomp-localization-tutorial.jl | 32 ++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 0301ab6..9f38dfa 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -144,13 +144,10 @@ struct Pose dp :: Vector{Float64} Pose(p :: Vector{Float64}, hd :: Float64) = new(p, rem2pi(hd, RoundNearest), [cos(hd), sin(hd)]) end -Pose(p :: Vector{Float64}, dp :: Vector{Float64}) = Pose(p, atan(dp[2], dp[1])) Base.show(io :: IO, p :: Pose) = Base.show(io, "Pose($(p.p), $(p.hd))") -step_along_pose(p, s) = p.p + s * p.dp -rotate_pose(p, a) = Pose(p.p, p.hd + a) -Plots.plot!(p :: Pose; r=0.5, args...) = plot!(Segment(p.p, step_along_pose(p, r)); arrow=true, args...) +Plots.plot!(p :: Pose; r=0.5, args...) = plot!(Segment(p.p, p.p + r * p.dp); arrow=true, args...) Plots.plot!(ps :: Vector{Pose}; args...) = plot_list!(ps; args...); # %% @@ -171,17 +168,17 @@ plot!(Pose([2., 3.], pi/2.); color=:green4, label="another pose") # %% # A general algorithm to find the interection of a ray and a line segment. -norm(v) = sqrt(sum(v.^2)) +det2(u, v) = u[1] * v[2] - u[2] * v[1] # Return unique s, t such that p + s*u == q + t*v. 
function solve_lines(p, u, q, v; PARALLEL_TOL=1.0e-10) - det = u[1] * v[2] - u[2] * v[1] + det = det2(u, v) if abs(det) < PARALLEL_TOL return nothing, nothing else pq = p - q - s = (v[1] * pq[2] - v[2] * pq[1]) / det - t = (u[1] * pq[2] - u[2] * pq[1]) / det + s = det2(v, pq) / det + t = det2(u, pq) / det return s, t end end @@ -207,16 +204,18 @@ sensor_angle(sensor_settings, j) = function ideal_sensor(pose, walls, sensor_settings) readings = Vector{Float64}(undef, sensor_settings.num_angles) for j in 1:sensor_settings.num_angles - sensor_pose = rotate_pose(pose, sensor_angle(sensor_settings, j)) + sensor_pose = Pose(pose.p, pose.hd + sensor_angle(sensor_settings, j)) readings[j] = sensor_distance(sensor_pose, walls, sensor_settings.box_size) end return readings end; # %% +project_sensor(pose, angle, s) = let rotated = Pose(pose.p, pose.hd + angle); rotated.p + s * rotated.dp end + function plot_sensors!(pose, color, readings, label, sensor_settings) plot!([pose.p[1]], [pose.p[2]]; color=color, label=nothing, seriestype=:scatter, markersize=3, markerstrokewidth=0) - projections = [step_along_pose(rotate_pose(pose, sensor_angle(sensor_settings, j)), s) for (j, s) in enumerate(readings)] + projections = [project_sensor(pose, sensor_angle(sensor_settings, j), s) for (j, s) in enumerate(readings)] plot!(first.(projections), last.(projections); color=:blue, label=label, seriestype=:scatter, markersize=3, markerstrokewidth=1, alpha=0.25) plot!([Segment(pose.p, pr) for pr in projections]; color=:blue, label=nothing, alpha=0.25) @@ -270,7 +269,7 @@ gif(ani, "imgs/ideal_distances.gif", fps=1) # %% @gen function sensor_model(pose, walls, sensor_settings) for j in 1:sensor_settings.num_angles - sensor_pose = rotate_pose(pose, sensor_angle(sensor_settings, j)) + sensor_pose = Pose(pose.p, pose.hd + sensor_angle(sensor_settings, j)) {j => :distance} ~ normal(sensor_distance(sensor_pose, walls, sensor_settings.box_size), sensor_settings.s_noise) end end; @@ -455,10 +454,13 @@ end; # We employ the following simple physics: when the robot's forward step through a control comes into contact with a wall, that step is interrupted and the robot instead "bounces" a fixed distance from the point of contact in the normal direction to the wall. # %% +norm(v) = sqrt(sum(v.^2)) + function physical_step(p1, p2, hd, world_inputs) - step_pose = Pose(p1, p2 - p1) - (s, i) = findmin(w -> distance(step_pose, w), world_inputs.walls) - if s > norm(p2 - p1) + p21 = p2 - p1 + step_pose = Pose(p1, atan2(p21[2], p21[1])) + s, i = findmin(w -> distance(step_pose, w), world_inputs.walls) + if s > norm(p21) # Step succeeds without contact with walls. return Pose(p2, hd) else @@ -466,7 +468,7 @@ function physical_step(p1, p2, hd, world_inputs) unit_tangent = world_inputs.walls[i].dp / norm(world_inputs.walls[i].dp) unit_normal = [-unit_tangent[2], unit_tangent[1]] # Sign of 2D cross product determines orientation of bounce. - if step_pose.dp[1] * world_inputs.walls[i].dp[2] - step_pose.dp[2] * world_inputs.walls[i].dp[1] < 0. + if det2(step_pose.dp, world_inputs.walls[i].dp) < 0. 
unit_normal = -unit_normal end return Pose(contact_point + world_inputs.bounce * unit_normal, hd) From 1fe08c2912af5a7dc712c5dd6b4f1c0c3009dca5 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 25 Oct 2024 16:09:36 -0400 Subject: [PATCH 07/86] Bug fix --- probcomp-localization-tutorial.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 9f38dfa..e8c48a2 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -458,7 +458,7 @@ norm(v) = sqrt(sum(v.^2)) function physical_step(p1, p2, hd, world_inputs) p21 = p2 - p1 - step_pose = Pose(p1, atan2(p21[2], p21[1])) + step_pose = Pose(p1, atan(p21[2], p21[1])) s, i = findmin(w -> distance(step_pose, w), world_inputs.walls) if s > norm(p21) # Step succeeds without contact with walls. From ce24f6a75b4143a6d5cb90c25bdd4eeca61cb7ff Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 28 Oct 2024 10:05:41 -0400 Subject: [PATCH 08/86] Fix performance regression --- probcomp-localization-tutorial.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index e8c48a2..561bdeb 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -176,7 +176,7 @@ function solve_lines(p, u, q, v; PARALLEL_TOL=1.0e-10) if abs(det) < PARALLEL_TOL return nothing, nothing else - pq = p - q + pq = (p[1] - q[1], p[2] - q[2]) s = det2(v, pq) / det t = det2(u, pq) / det return s, t @@ -457,7 +457,7 @@ end; norm(v) = sqrt(sum(v.^2)) function physical_step(p1, p2, hd, world_inputs) - p21 = p2 - p1 + p21 = (p2[1] - p1[1], p2[2] - p1[2]) step_pose = Pose(p1, atan(p21[2], p21[1])) s, i = findmin(w -> distance(step_pose, w), world_inputs.walls) if s > norm(p21) From ba9d0d65e1cf76a24073e8f10053ad09f5e80df0 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 28 Oct 2024 11:32:26 -0400 Subject: [PATCH 09/86] Nit --- probcomp-localization-tutorial.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 561bdeb..51c9939 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2394,8 +2394,8 @@ drift_args_schedule = [0.7^k for k=1:7] # Then try a more determined grid search. grid_n_points = [3, 3, 3] grid_sizes = [.5, .5, π/10] -grid_args_schedule = [(grid_n_points, grid_sizes .* (2/3)^j) for j=0:3] -grid_args_schedule_harder = [(grid_n_points, grid_sizes .* (2/3)^j) for j=0:6] +grid_args_schedule = [(grid_n_points, grid_sizes * (2/3)^j) for j=0:3] +grid_args_schedule_harder = [(grid_n_points, grid_sizes * (2/3)^j) for j=0:6] rejuv_schedule = [(drift_mh_kernel, drift_args_schedule), From cc6792eabd0d6c99c32256e4a691f99ed8c864e6 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 28 Oct 2024 11:45:40 -0400 Subject: [PATCH 10/86] Bug fix --- probcomp-localization-tutorial.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 51c9939..5be2288 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2751,7 +2751,7 @@ observations_cluttered = get_sensors(trace_cluttered) constraints_cluttered = [constraint_from_sensors(o...) 
for o in enumerate(observations_cluttered)] ani = Animation() -frames_cluttered = frames_from_full_trace(world, "Cluttered space", trace_cluttered; show_clutters=true) +frames_cluttered = frames_from_full_trace(world, "Cluttered space", trace_cluttered; show=(:label, :clutters)) for frame_plot in frames_cluttered[2:2:end] frame(ani, frame_plot) end @@ -2771,7 +2771,7 @@ t1 = now() traces = [particle_filter_rejuv(full_model, T, full_model_args, constraints_cluttered, N_particles, ESS_threshold, drift_mh_kernel, drift_args_schedule) for _ in 1:N_samples] t2 = now() println("Time elapsed per run (backwards start): $(value(t2 - t1) / N_samples) ms.") -frame_from_traces(world, "Cluttered space", path_cluttered, "path to be fit", traces, "samples"; show_clutters=true) +frame_from_traces(world, "Cluttered space", path_cluttered, "path to be fit", traces, "samples"; show=(:clutters,)) # %% [markdown] # We take up the task of accommodating a wider range of phenomena in our modeling and inference. From 53c6005d04612c7c3bd23a028a2117fe7a2b0486 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 28 Oct 2024 13:50:49 -0400 Subject: [PATCH 11/86] Add comment --- probcomp-localization-tutorial.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 5be2288..0a5db2b 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -52,6 +52,7 @@ mkpath("imgs"); struct Segment p1 :: Vector{Float64} p2 :: Vector{Float64} + # The quantity `p2-p1` is called upon in hot loops, so we cache it. dp :: Vector{Float64} Segment(p1 :: Vector{Float64}, p2 :: Vector{Float64}) = new(p1, p2, p2-p1) end @@ -141,6 +142,7 @@ plot_world(world, "Given data", show=(:label, :clutters)) struct Pose p :: Vector{Float64} hd :: Float64 + # The quantity `[cos(hd), sin(hd)]` is called upon in hot loops, so we cache it. 
dp :: Vector{Float64} Pose(p :: Vector{Float64}, hd :: Float64) = new(p, rem2pi(hd, RoundNearest), [cos(hd), sin(hd)]) end From f211b07fcd47d51608a0527898291877837d9b9a Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 28 Oct 2024 16:35:34 -0400 Subject: [PATCH 12/86] Clean up world map --- world.json | 2152 +++++++++++----------------------------------------- 1 file changed, 444 insertions(+), 1708 deletions(-) diff --git a/world.json b/world.json index 45f14d6..4b19733 100644 --- a/world.json +++ b/world.json @@ -1,1714 +1,450 @@ { "wall_verts": [ - [ - 13.24, - 0.1 - ], - [ - 13.23, - 0.11 - ], - [ - 13.23, - 5.67 - ], - [ - 13.31, - 5.67 - ], - [ - 13.32, - 5.68 - ], - [ - 13.32, - 5.78 - ], - [ - 13.31, - 5.79 - ], - [ - 12.24, - 5.79 - ], - [ - 12.23, - 5.8 - ], - [ - 12.23, - 9.57 - ], - [ - 14.52, - 9.57 - ], - [ - 14.53, - 9.58 - ], - [ - 14.53, - 9.68 - ], - [ - 14.52, - 9.69 - ], - [ - 8.93, - 9.69 - ], - [ - 8.93, - 9.85 - ], - [ - 8.92, - 9.86 - ], - [ - 8.82, - 9.86 - ], - [ - 8.81, - 9.85 - ], - [ - 8.81, - 9.69 - ], - [ - 5.98, - 9.69 - ], - [ - 5.97, - 9.7 - ], - [ - 5.96, - 9.7 - ], - [ - 5.95, - 9.69 - ], - [ - 5.63, - 9.69 - ], - [ - 5.62, - 9.68 - ], - [ - 5.62, - 9.58 - ], - [ - 5.63, - 9.57 - ], - [ - 5.87, - 9.57 - ], - [ - 5.87, - 5.79 - ], - [ - 1.94, - 5.79 - ], - [ - 1.93, - 5.8 - ], - [ - 1.93, - 9.07 - ], - [ - 4.51, - 9.07 - ], - [ - 4.53, - 9.09 - ], - [ - 4.53, - 9.57 - ], - [ - 4.77, - 9.57 - ], - [ - 4.78, - 9.58 - ], - [ - 4.78, - 9.68 - ], - [ - 4.77, - 9.69 - ], - [ - 4.53, - 9.69 - ], - [ - 4.53, - 10.14 - ], - [ - 4.52, - 10.15 - ], - [ - 4.42, - 10.15 - ], - [ - 4.41, - 10.14 - ], - [ - 4.41, - 9.19 - ], - [ - 1.93, - 9.19 - ], - [ - 1.93, - 11.07 - ], - [ - 4.41, - 11.07 - ], - [ - 4.41, - 11.0 - ], - [ - 4.42, - 10.99 - ], - [ - 4.52, - 10.99 - ], - [ - 4.53, - 11.0 - ], - [ - 4.53, - 11.07 - ], - [ - 5.32, - 11.07 - ], - [ - 5.33, - 11.06 - ], - [ - 5.33, - 11.05 - ], - [ - 5.34, - 11.04 - ], - [ - 5.36, - 11.04 - ], - [ - 5.37, - 11.05 - ], - [ - 5.38, - 11.05 - ], - [ - 5.39, - 11.06 - ], - [ - 5.4, - 11.06 - ], - [ - 5.41, - 11.07 - ], - [ - 5.42, - 11.07 - ], - [ - 5.43, - 11.08 - ], - [ - 5.44, - 11.08 - ], - [ - 5.45, - 11.09 - ], - [ - 5.43, - 11.11 - ], - [ - 5.43, - 11.18 - ], - [ - 5.42, - 11.19 - ], - [ - 5.42, - 13.57 - ], - [ - 6.46, - 13.57 - ], - [ - 6.47, - 13.58 - ], - [ - 6.47, - 13.68 - ], - [ - 6.46, - 13.69 - ], - [ - 4.19, - 13.69 - ], - [ - 4.18, - 13.68 - ], - [ - 4.18, - 13.58 - ], - [ - 4.19, - 13.57 - ], - [ - 5.31, - 13.57 - ], - [ - 5.31, - 11.19 - ], - [ - 1.93, - 11.19 - ], - [ - 1.93, - 13.57 - ], - [ - 3.23, - 13.57 - ], - [ - 3.24, - 13.58 - ], - [ - 3.24, - 13.68 - ], - [ - 3.23, - 13.69 - ], - [ - 1.89, - 13.69 - ], - [ - 1.88, - 13.7 - ], - [ - 1.87, - 13.7 - ], - [ - 1.85, - 13.72 - ], - [ - 1.84, - 13.72 - ], - [ - 1.82, - 13.74 - ], - [ - 1.81, - 13.74 - ], - [ - 1.79, - 13.76 - ], - [ - 1.78, - 13.76 - ], - [ - 1.76, - 13.78 - ], - [ - 1.75, - 13.78 - ], - [ - 1.73, - 13.8 - ], - [ - 1.72, - 13.8 - ], - [ - 1.71, - 13.81 - ], - [ - 1.7, - 13.81 - ], - [ - 1.68, - 13.83 - ], - [ - 1.67, - 13.83 - ], - [ - 1.65, - 13.85 - ], - [ - 1.64, - 13.85 - ], - [ - 1.62, - 13.87 - ], - [ - 1.61, - 13.87 - ], - [ - 1.59, - 13.89 - ], - [ - 1.58, - 13.89 - ], - [ - 1.56, - 13.91 - ], - [ - 1.55, - 13.91 - ], - [ - 1.53, - 13.93 - ], - [ - 1.52, - 13.93 - ], - [ - 1.5, - 13.95 - ], - [ - 1.49, - 13.95 - ], - [ - 1.47, - 13.97 - ], - [ - 1.46, - 13.97 - ], - [ - 1.44, - 13.99 - ], - [ - 1.43, - 13.99 - ], - [ - 
1.41, - 14.01 - ], - [ - 1.4, - 14.01 - ], - [ - 1.38, - 14.03 - ], - [ - 1.37, - 14.03 - ], - [ - 1.35, - 14.05 - ], - [ - 1.34, - 14.05 - ], - [ - 1.32, - 14.07 - ], - [ - 1.31, - 14.07 - ], - [ - 1.29, - 14.09 - ], - [ - 1.28, - 14.09 - ], - [ - 1.27, - 14.1 - ], - [ - 1.26, - 14.1 - ], - [ - 1.24, - 14.12 - ], - [ - 1.23, - 14.12 - ], - [ - 1.21, - 14.14 - ], - [ - 1.2, - 14.14 - ], - [ - 1.18, - 14.16 - ], - [ - 1.17, - 14.16 - ], - [ - 1.15, - 14.18 - ], - [ - 1.14, - 14.18 - ], - [ - 1.12, - 14.2 - ], - [ - 1.11, - 14.2 - ], - [ - 1.09, - 14.22 - ], - [ - 1.08, - 14.22 - ], - [ - 1.06, - 14.24 - ], - [ - 1.05, - 14.24 - ], - [ - 1.03, - 14.26 - ], - [ - 1.02, - 14.26 - ], - [ - 1.0, - 14.28 - ], - [ - 0.99, - 14.28 - ], - [ - 0.97, - 14.3 - ], - [ - 0.96, - 14.3 - ], - [ - 0.94, - 14.32 - ], - [ - 0.93, - 14.32 - ], - [ - 0.91, - 14.34 - ], - [ - 0.9, - 14.34 - ], - [ - 0.88, - 14.36 - ], - [ - 0.87, - 14.36 - ], - [ - 0.85, - 14.38 - ], - [ - 0.84, - 14.38 - ], - [ - 0.83, - 14.39 - ], - [ - 0.82, - 14.39 - ], - [ - 0.8, - 14.41 - ], - [ - 0.79, - 14.41 - ], - [ - 0.77, - 14.43 - ], - [ - 0.76, - 14.43 - ], - [ - 0.74, - 14.45 - ], - [ - 0.73, - 14.45 - ], - [ - 0.71, - 14.47 - ], - [ - 0.7, - 14.47 - ], - [ - 0.68, - 14.49 - ], - [ - 0.67, - 14.49 - ], - [ - 0.65, - 14.51 - ], - [ - 0.64, - 14.51 - ], - [ - 0.62, - 14.53 - ], - [ - 0.61, - 14.53 - ], - [ - 0.59, - 14.55 - ], - [ - 0.58, - 14.55 - ], - [ - 0.56, - 14.57 - ], - [ - 0.55, - 14.57 - ], - [ - 0.53, - 14.59 - ], - [ - 0.52, - 14.59 - ], - [ - 0.5, - 14.61 - ], - [ - 0.49, - 14.61 - ], - [ - 0.47, - 14.63 - ], - [ - 0.46, - 14.63 - ], - [ - 0.44, - 14.65 - ], - [ - 0.43, - 14.65 - ], - [ - 0.41, - 14.67 - ], - [ - 0.4, - 14.67 - ], - [ - 0.39, - 14.68 - ], - [ - 0.38, - 14.68 - ], - [ - 0.36, - 14.7 - ], - [ - 0.35, - 14.7 - ], - [ - 0.33, - 14.72 - ], - [ - 0.32, - 14.72 - ], - [ - 0.3, - 14.74 - ], - [ - 0.29, - 14.74 - ], - [ - 0.27, - 14.76 - ], - [ - 0.26, - 14.76 - ], - [ - 0.24, - 14.78 - ], - [ - 0.23, - 14.78 - ], - [ - 0.21, - 14.8 - ], - [ - 0.2, - 14.8 - ], - [ - 0.18, - 14.82 - ], - [ - 0.17, - 14.82 - ], - [ - 0.15, - 14.84 - ], - [ - 0.14, - 14.84 - ], - [ - 0.12, - 14.86 - ], - [ - 0.11, - 14.86 - ], - [ - 0.1, - 14.87 - ], - [ - 0.1, - 17.78 - ], - [ - 0.11, - 17.79 - ], - [ - 0.12, - 17.79 - ], - [ - 0.14, - 17.81 - ], - [ - 0.15, - 17.81 - ], - [ - 0.17, - 17.83 - ], - [ - 0.18, - 17.83 - ], - [ - 0.2, - 17.85 - ], - [ - 0.21, - 17.85 - ], - [ - 0.23, - 17.87 - ], - [ - 0.24, - 17.87 - ], - [ - 0.26, - 17.89 - ], - [ - 0.27, - 17.89 - ], - [ - 0.29, - 17.91 - ], - [ - 0.3, - 17.91 - ], - [ - 0.32, - 17.93 - ], - [ - 0.33, - 17.93 - ], - [ - 0.35, - 17.95 - ], - [ - 0.36, - 17.95 - ], - [ - 0.38, - 17.97 - ], - [ - 0.39, - 17.97 - ], - [ - 0.41, - 17.99 - ], - [ - 0.42, - 17.99 - ], - [ - 0.44, - 18.01 - ], - [ - 0.45, - 18.01 - ], - [ - 0.47, - 18.03 - ], - [ - 0.48, - 18.03 - ], - [ - 0.5, - 18.05 - ], - [ - 0.51, - 18.05 - ], - [ - 0.53, - 18.07 - ], - [ - 0.54, - 18.07 - ], - [ - 0.56, - 18.09 - ], - [ - 0.57, - 18.09 - ], - [ - 0.59, - 18.11 - ], - [ - 0.6, - 18.11 - ], - [ - 0.62, - 18.13 - ], - [ - 0.63, - 18.13 - ], - [ - 0.65, - 18.15 - ], - [ - 0.66, - 18.15 - ], - [ - 0.68, - 18.17 - ], - [ - 0.69, - 18.17 - ], - [ - 0.71, - 18.19 - ], - [ - 0.72, - 18.19 - ], - [ - 0.74, - 18.21 - ], - [ - 0.75, - 18.21 - ], - [ - 0.77, - 18.23 - ], - [ - 0.78, - 18.23 - ], - [ - 0.8, - 18.25 - ], - [ - 0.81, - 18.25 - ], - [ - 0.83, - 18.27 - ], - [ - 0.84, - 18.27 - ], - [ - 0.86, - 18.29 - ], - [ - 0.87, 
- 18.29 - ], - [ - 0.89, - 18.31 - ], - [ - 0.9, - 18.31 - ], - [ - 0.91, - 18.32 - ], - [ - 0.92, - 18.32 - ], - [ - 0.94, - 18.34 - ], - [ - 0.95, - 18.34 - ], - [ - 0.97, - 18.36 - ], - [ - 0.98, - 18.36 - ], - [ - 1.0, - 18.38 - ], - [ - 1.01, - 18.38 - ], - [ - 1.03, - 18.4 - ], - [ - 1.04, - 18.4 - ], - [ - 1.06, - 18.42 - ], - [ - 1.07, - 18.42 - ], - [ - 1.09, - 18.44 - ], - [ - 1.1, - 18.44 - ], - [ - 1.12, - 18.46 - ], - [ - 1.13, - 18.46 - ], - [ - 1.15, - 18.48 - ], - [ - 1.16, - 18.48 - ], - [ - 1.18, - 18.5 - ], - [ - 1.19, - 18.5 - ], - [ - 1.21, - 18.52 - ], - [ - 1.22, - 18.52 - ], - [ - 1.24, - 18.54 - ], - [ - 1.25, - 18.54 - ], - [ - 1.27, - 18.56 - ], - [ - 1.28, - 18.56 - ], - [ - 1.3, - 18.58 - ], - [ - 1.31, - 18.58 - ], - [ - 1.33, - 18.6 - ], - [ - 1.34, - 18.6 - ], - [ - 1.36, - 18.62 - ], - [ - 1.37, - 18.62 - ], - [ - 1.39, - 18.64 - ], - [ - 1.4, - 18.64 - ], - [ - 1.42, - 18.66 - ], - [ - 1.43, - 18.66 - ], - [ - 1.45, - 18.68 - ], - [ - 1.46, - 18.68 - ], - [ - 1.48, - 18.7 - ], - [ - 1.49, - 18.7 - ], - [ - 1.51, - 18.72 - ], - [ - 1.52, - 18.72 - ], - [ - 1.54, - 18.74 - ], - [ - 1.55, - 18.74 - ], - [ - 1.57, - 18.76 - ], - [ - 1.58, - 18.76 - ], - [ - 1.6, - 18.78 - ], - [ - 1.61, - 18.78 - ], - [ - 1.63, - 18.8 - ], - [ - 1.64, - 18.8 - ], - [ - 1.66, - 18.82 - ], - [ - 1.67, - 18.82 - ], - [ - 1.69, - 18.84 - ], - [ - 1.7, - 18.84 - ], - [ - 1.72, - 18.86 - ], - [ - 1.73, - 18.86 - ], - [ - 1.75, - 18.88 - ], - [ - 1.76, - 18.88 - ], - [ - 1.78, - 18.9 - ], - [ - 1.79, - 18.9 - ], - [ - 1.81, - 18.92 - ], - [ - 1.82, - 18.92 - ], - [ - 1.84, - 18.94 - ], - [ - 1.85, - 18.94 - ], - [ - 1.87, - 18.96 - ], - [ - 8.8, - 18.96 - ], - [ - 8.8, - 15.38 - ], - [ - 8.81, - 15.37 - ], - [ - 8.81, - 13.69 - ], - [ - 8.32, - 13.69 - ], - [ - 8.31, - 13.68 - ], - [ - 8.31, - 13.58 - ], - [ - 8.32, - 13.57 - ], - [ - 8.81, - 13.57 - ], - [ - 8.81, - 10.87 - ], - [ - 8.82, - 10.86 - ], - [ - 8.92, - 10.86 - ], - [ - 8.93, - 10.87 - ], - [ - 8.93, - 11.07 - ], - [ - 11.61, - 11.07 - ], - [ - 11.62, - 11.08 - ], - [ - 11.62, - 11.18 - ], - [ - 11.61, - 11.19 - ], - [ - 8.93, - 11.19 - ], - [ - 8.93, - 15.26 - ], - [ - 12.71, - 15.26 - ], - [ - 12.71, - 11.19 - ], - [ - 12.47, - 11.19 - ], - [ - 12.46, - 11.18 - ], - [ - 12.46, - 11.08 - ], - [ - 12.47, - 11.07 - ], - [ - 13.04, - 11.07 - ], - [ - 13.05, - 11.08 - ], - [ - 13.05, - 11.18 - ], - [ - 13.04, - 11.19 - ], - [ - 12.83, - 11.19 - ], - [ - 12.83, - 15.26 - ], - [ - 15.81, - 15.26 - ], - [ - 15.81, - 11.19 - ], - [ - 13.9, - 11.19 - ], - [ - 13.89, - 11.18 - ], - [ - 13.89, - 11.08 - ], - [ - 13.9, - 11.07 - ], - [ - 15.99, - 11.07 - ], - [ - 16.0, - 11.08 - ], - [ - 16.0, - 11.18 - ], - [ - 15.99, - 11.19 - ], - [ - 15.93, - 11.19 - ], - [ - 15.93, - 15.26 - ], - [ - 18.9, - 15.26 - ], - [ - 18.9, - 11.19 - ], - [ - 16.85, - 11.19 - ], - [ - 16.84, - 11.18 - ], - [ - 16.84, - 11.08 - ], - [ - 16.85, - 11.07 - ], - [ - 17.01, - 11.07 - ], - [ - 17.01, - 11.03 - ], - [ - 17.02, - 11.02 - ], - [ - 17.12, - 11.02 - ], - [ - 17.13, - 11.03 - ], - [ - 17.13, - 11.07 - ], - [ - 18.9, - 11.07 - ], - [ - 18.9, - 9.69 - ], - [ - 17.13, - 9.69 - ], - [ - 17.13, - 10.17 - ], - [ - 17.12, - 10.18 - ], - [ - 17.02, - 10.18 - ], - [ - 17.01, - 10.17 - ], - [ - 17.01, - 9.69 - ], - [ - 16.85, - 9.69 - ], - [ - 16.84, - 9.68 - ], - [ - 16.84, - 9.58 - ], - [ - 16.85, - 9.57 - ], - [ - 18.9, - 9.57 - ], - [ - 18.9, - 5.79 - ], - [ - 15.73, - 5.79 - ], - [ - 15.73, - 9.57 - ], - [ - 15.99, - 9.57 - ], - [ - 16.0, - 9.58 - ], - [ 
- 16.0, - 9.68 - ], - [ - 15.99, - 9.69 - ], - [ - 15.38, - 9.69 - ], - [ - 15.37, - 9.68 - ], - [ - 15.37, - 9.58 - ], - [ - 15.38, - 9.57 - ], - [ - 15.61, - 9.57 - ], - [ - 15.61, - 7.49 - ], - [ - 12.85, - 7.49 - ], - [ - 12.84, - 7.48 - ], - [ - 12.84, - 7.39 - ], - [ - 12.85, - 7.38 - ], - [ - 15.61, - 7.38 - ], - [ - 15.61, - 5.79 - ], - [ - 14.17, - 5.79 - ], - [ - 14.16, - 5.78 - ], - [ - 14.16, - 5.68 - ], - [ - 14.17, - 5.67 - ], - [ - 18.9, - 5.67 - ], - [ - 18.9, - 0.1 + [ + 13.23, + 0.1 + ], + [ + 13.23, + 5.67 + ], + [ + 13.32, + 5.67 + ], + [ + 13.32, + 5.79 + ], + [ + 12.23, + 5.79 + ], + [ + 12.23, + 9.57 + ], + [ + 14.53, + 9.57 + ], + [ + 14.53, + 9.69 + ], + [ + 8.93, + 9.69 + ], + [ + 8.93, + 9.86 + ], + [ + 8.81, + 9.86 + ], + [ + 8.81, + 9.69 + ], + [ + 5.62, + 9.69 + ], + [ + 5.62, + 9.57 + ], + [ + 5.87, + 9.57 + ], + [ + 5.87, + 5.79 + ], + [ + 1.93, + 5.79 + ], + [ + 1.93, + 9.07 + ], + [ + 4.53, + 9.07 + ], + [ + 4.53, + 9.57 + ], + [ + 4.78, + 9.57 + ], + [ + 4.78, + 9.69 + ], + [ + 4.53, + 9.69 + ], + [ + 4.53, + 10.15 + ], + [ + 4.41, + 10.15 + ], + [ + 4.41, + 9.19 + ], + [ + 1.93, + 9.19 + ], + [ + 1.93, + 11.07 + ], + [ + 4.41, + 11.07 + ], + [ + 4.41, + 10.99 + ], + [ + 4.53, + 10.99 + ], + [ + 4.53, + 11.07 + ], + [ + 5.42, + 11.07 + ], + [ + 5.42, + 13.57 + ], + [ + 6.47, + 13.57 + ], + [ + 6.47, + 13.69 + ], + [ + 4.18, + 13.69 + ], + [ + 4.18, + 13.57 + ], + [ + 5.31, + 13.57 + ], + [ + 5.31, + 11.19 + ], + [ + 1.93, + 11.19 + ], + [ + 1.93, + 13.57 + ], + [ + 3.24, + 13.57 + ], + [ + 3.24, + 13.69 + ], + [ + 1.89, + 13.69 + ], + [ + 0.1, + 14.87 + ], + [ + 0.1, + 17.78 + ], + [ + 1.87, + 18.96 + ], + [ + 8.8, + 18.96 + ], + [ + 8.8, + 13.69 + ], + [ + 8.31, + 13.69 + ], + [ + 8.31, + 13.57 + ], + [ + 8.81, + 13.57 + ], + [ + 8.81, + 10.86 + ], + [ + 8.93, + 10.86 + ], + [ + 8.93, + 11.07 + ], + [ + 11.62, + 11.07 + ], + [ + 11.62, + 11.19 + ], + [ + 8.93, + 11.19 + ], + [ + 8.93, + 15.26 + ], + [ + 12.71, + 15.26 + ], + [ + 12.71, + 11.19 + ], + [ + 12.46, + 11.19 + ], + [ + 12.46, + 11.07 + ], + [ + 13.05, + 11.07 + ], + [ + 13.05, + 11.19 + ], + [ + 12.83, + 11.19 + ], + [ + 12.83, + 15.26 + ], + [ + 15.81, + 15.26 + ], + [ + 15.81, + 11.19 + ], + [ + 13.89, + 11.19 + ], + [ + 13.89, + 11.07 + ], + [ + 16.0, + 11.07 + ], + [ + 16.0, + 11.19 + ], + [ + 15.93, + 11.19 + ], + [ + 15.93, + 15.26 + ], + [ + 18.9, + 15.26 + ], + [ + 18.9, + 11.19 + ], + [ + 16.84, + 11.19 + ], + [ + 16.84, + 11.07 + ], + [ + 17.02, + 11.07 + ], + [ + 17.02, + 11.02 + ], + [ + 17.13, + 11.02 + ], + [ + 17.13, + 11.07 + ], + [ + 18.9, + 11.07 + ], + [ + 18.9, + 9.69 + ], + [ + 17.13, + 9.69 + ], + [ + 17.13, + 10.18 + ], + [ + 17.01, + 10.18 + ], + [ + 17.01, + 9.69 + ], + [ + 16.84, + 9.69 + ], + [ + 16.84, + 9.57 + ], + [ + 18.9, + 9.57 + ], + [ + 18.9, + 5.79 + ], + [ + 15.73, + 5.79 + ], + [ + 15.73, + 9.57 + ], + [ + 16.0, + 9.57 + ], + [ + 16.0, + 9.69 + ], + [ + 15.37, + 9.69 + ], + [ + 15.37, + 9.57 + ], + [ + 15.61, + 9.57 + ], + [ + 15.61, + 7.49 + ], + [ + 12.84, + 7.49 + ], + [ + 12.85, + 7.39 + ], + [ + 15.61, + 7.38 + ], + [ + 15.61, + 5.79 + ], + [ + 14.16, + 5.79 + ], + [ + 14.16, + 5.67 + ], + [ + 18.9, + 5.67 + ], + [ + 18.9, + 0.1 + ], + [ + 13.23, + 0.1 + ] ], - [ - 13.24, - 0.1 - ] - ], "clutter_vert_groups": [ [ [ From 1e44687e2f8dbee5ae30499d9a64f67addbdf26b Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Mon, 28 Oct 2024 22:29:26 -0400 Subject: [PATCH 13/86] Add map discretization into rooms and doorways --- probcomp-localization-tutorial.jl | 
29 ++ world_coarse.json | 466 ++++++++++++++++++++++++++++++ 2 files changed, 495 insertions(+) create mode 100644 world_coarse.json diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 0a5db2b..794728d 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2797,3 +2797,32 @@ frame_from_traces(world, "Cluttered space", path_cluttered, "path to be fit", tr # 4. Complete ignorance, embodied as a uniform distribution over the bounding box. # # + +# %% [markdown] +# ## Goal inference +# +# + +# %% [markdown] +# ### Discretization + +# %% +function load_discretization(file_name) + data = parsefile(file_name) + rooms = Dict{String, Vector{Vector{Float64}}}(data["rooms"]) + doorways = [(Tuple(sort(doorway["rooms"])), Vector{Float64}(doorway["p"])) for doorway in data["doorways"]] + return rooms, doorways +end; + +# %% +rooms, doorways = load_discretization("world_coarse.json"); + +# %% +the_plot = plot_world(world, "Discretization: rooms and doorways") +for (i, (name, ps)) in enumerate(rooms) + plot!(Shape(first.(ps), last.(ps)); color=(i+1), label=nothing, markersize=3, markerstrokewidth=1) + midpoint = sum(ps)/length(ps) + annotate!(midpoint[1], midpoint[2], ("$name", :black)) +end +plot!(first.(last.(doorways)), last.(last.(doorways)); seriestype=:scatter, color=:red, label=nothing, markersize=5, markerstrokewidth=1) +the_plot diff --git a/world_coarse.json b/world_coarse.json new file mode 100644 index 0000000..4893fca --- /dev/null +++ b/world_coarse.json @@ -0,0 +1,466 @@ +{ + "rooms": { + "1": [ + [ + 13.23, + 0.1 + ], + [ + 13.23, + 5.67 + ], + [ + 18.9, + 5.67 + ], + [ + 18.9, + 0.1 + ] + ], + "2": [ + [ + 12.23, + 5.79 + ], + [ + 12.23, + 7.38 + ], + [ + 15.61, + 7.38 + ], + [ + 15.61, + 5.79 + ] + ], + "3": [ + [ + 12.23, + 7.49 + ], + [ + 12.23, + 9.57 + ], + [ + 15.61, + 9.57 + ], + [ + 15.61, + 7.49 + ] + ], + "4": [ + [ + 8.93, + 9.69 + ], + [ + 8.93, + 11.07 + ], + [ + 17.02, + 11.07 + ], + [ + 17.01, + 9.69 + ] + ], + "5": [ + [ + 8.81, + 9.69 + ], + [ + 5.62, + 9.69 + ], + [ + 5.42, + 11.07 + ], + [ + 5.42, + 13.57 + ], + [ + 8.81, + 13.57 + ] + ], + "6": [ + [ + 5.62, + 9.69 + ], + [ + 4.53, + 9.69 + ], + [ + 4.53, + 11.07 + ], + [ + 5.42, + 11.07 + ] + ], + "7": [ + [ + 5.87, + 9.57 + ], + [ + 5.87, + 5.79 + ], + [ + 4.53, + 9.07 + ], + [ + 4.53, + 9.57 + ] + ], + "8": [ + [ + 5.87, + 5.79 + ], + [ + 1.93, + 5.79 + ], + [ + 1.93, + 9.07 + ], + [ + 4.53, + 9.07 + ] + ], + "9": [ + [ + 4.41, + 9.19 + ], + [ + 1.93, + 9.19 + ], + [ + 1.93, + 11.07 + ], + [ + 4.41, + 11.07 + ] + ], + "10": [ + [ + 1.89, + 13.69 + ], + [ + 0.1, + 14.87 + ], + [ + 0.1, + 17.78 + ], + [ + 1.87, + 18.96 + ], + [ + 8.8, + 18.96 + ], + [ + 8.8, + 13.69 + ] + ], + "11": [ + [ + 5.31, + 13.57 + ], + [ + 5.31, + 11.19 + ], + [ + 1.93, + 11.19 + ], + [ + 1.93, + 13.57 + ] + ], + "12": [ + [ + 8.93, + 11.19 + ], + [ + 8.93, + 15.26 + ], + [ + 12.71, + 15.26 + ], + [ + 12.71, + 11.19 + ] + ], + "13": [ + [ + 12.83, + 11.19 + ], + [ + 12.83, + 15.26 + ], + [ + 15.81, + 15.26 + ], + [ + 15.81, + 11.19 + ] + ], + "14": [ + [ + 15.93, + 11.19 + ], + [ + 15.93, + 15.26 + ], + [ + 18.9, + 15.26 + ], + [ + 18.9, + 11.19 + ] + ], + "15": [ + [ + 17.13, + 11.07 + ], + [ + 18.9, + 11.07 + ], + [ + 18.9, + 9.69 + ], + [ + 17.13, + 9.69 + ] + ], + "16": [ + [ + 18.9, + 9.57 + ], + [ + 18.9, + 5.79 + ], + [ + 15.73, + 5.79 + ], + [ + 15.73, + 9.57 + ] + ] + }, + "doorways": [ + { + "rooms": [ + "1", + "2" + ], + "p": [ + 13.74, + 5.73 + ] + }, + { + 
"rooms": [ + "2", + "3" + ], + "p": [ + 12.5375, + 7.44 + ] + }, + { + "rooms": [ + "3", + "4" + ], + "p": [ + 14.95, + 9.63 + ] + }, + { + "rooms": [ + "4", + "5" + ], + "p": [ + 8.87, + 10.36 + ] + }, + { + "rooms": [ + "5", + "6" + ], + "p": [ + 5.52, + 10.38 + ] + }, + { + "rooms": [ + "6", + "7" + ], + "p": [ + 5.2, + 9.63 + ] + }, + { + "rooms": [ + "7", + "8" + ], + "p": [ + 5.2, + 7.43 + ] + }, + { + "rooms": [ + "6", + "9" + ], + "p": [ + 4.47, + 10.57 + ] + }, + { + "rooms": [ + "5", + "10" + ], + "p": [ + 7.39, + 13.63 + ] + }, + { + "rooms": [ + "10", + "11" + ], + "p": [ + 3.71, + 13.63 + ] + }, + { + "rooms": [ + "4", + "12" + ], + "p": [ + 12.04, + 11.13 + ] + }, + { + "rooms": [ + "4", + "13" + ], + "p": [ + 13.47, + 11.13 + ] + }, + { + "rooms": [ + "4", + "14" + ], + "p": [ + 16.42, + 11.13 + ] + }, + { + "rooms": [ + "4", + "15" + ], + "p": [ + 17.0725, + 10.6 + ] + }, + { + "rooms": [ + "4", + "16" + ], + "p": [ + 16.42, + 9.63 + ] + } + ], + "tasks": [ + { + "name": "home", + "p": [ + 1.0, + 1.0 + ], + "r": 1.0 + } + ] +} \ No newline at end of file From ca23df8aff4c33205bdde83cf41c51ddf3a0a873 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Tue, 29 Oct 2024 16:59:10 -0400 Subject: [PATCH 14/86] Nit --- probcomp-localization-tutorial.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 794728d..1394201 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2820,7 +2820,7 @@ rooms, doorways = load_discretization("world_coarse.json"); # %% the_plot = plot_world(world, "Discretization: rooms and doorways") for (i, (name, ps)) in enumerate(rooms) - plot!(Shape(first.(ps), last.(ps)); color=(i+1), label=nothing, markersize=3, markerstrokewidth=1) + plot!(first.(ps), last.(ps); seriestype=:shape, color=(i+1), label=nothing, markersize=3, markerstrokewidth=1) midpoint = sum(ps)/length(ps) annotate!(midpoint[1], midpoint[2], ("$name", :black)) end From 8e7fc386e4f30dacb60bf24e6a26d45df65ea176 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Tue, 29 Oct 2024 16:59:25 -0400 Subject: [PATCH 15/86] Add goals and tasks --- goals.json | 46 +++++++++++++++++++++++++++++++ probcomp-localization-tutorial.jl | 27 ++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 goals.json diff --git a/goals.json b/goals.json new file mode 100644 index 0000000..fe77799 --- /dev/null +++ b/goals.json @@ -0,0 +1,46 @@ +{ + "tasks": { + "task1": { + "p": [ + 15.0, + 12.0 + ], + "r": 0.5 + }, + "task2": { + "p": [ + 3.0, + 7.0 + ], + "r": 0.5 + }, + "task3": { + "p": [ + 10.0, + 12.0 + ], + "r": 0.5 + }, + "task4": { + "p": [ + 18.0, + 2.0 + ], + "r": 0.5 + } + }, + "goals": { + "goal1": { + "task1": [] + }, + "goal2": { + "task2": [], + "task3": [ + "task2" + ] + }, + "goal3": { + "task4": [] + } + } +} \ No newline at end of file diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 1394201..74bb86d 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2803,6 +2803,33 @@ frame_from_traces(world, "Cluttered space", path_cluttered, "path to be fit", tr # # +# %% [markdown] +# ### Goals + +# %% +function load_goals(file_name) + data = parsefile(file_name) + tasks = Dict( + task => (Vector{Float64}(info["p"]), Float64(info["r"])) + for (task, info) in data["tasks"]) + goals = Dict( + goal => Dict( + task => Set{String}(dependencies) + for (task, dependencies) in tasks) + for (goal, 
tasks) in data["goals"]) + return tasks, goals +end; + +# %% +tasks, goals = load_goals("goals.json"); + +# %% +the_plot = plot_world(world, "Tasks") +for (task, (p, r)) in tasks + plot!(make_circle(p, r); label=task, seriestype=:shape) +end +the_plot + # %% [markdown] # ### Discretization From c8b9ac2271168ca75d44cefb1b89a7043a6b51a5 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Tue, 29 Oct 2024 17:08:28 -0400 Subject: [PATCH 16/86] Add goals prior --- probcomp-localization-tutorial.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 74bb86d..86a6c20 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2830,6 +2830,14 @@ for (task, (p, r)) in tasks end the_plot +# %% +@dist labeled_uniform(labels) = labels[uniform_discrete(1, length(labels))] +@dist labeled_categorical(labels, probs) = labels[categorical(probs)] +normalize_logprobs(lps) = lps .- logsumexp(lps); + +@gen goals_prior(goals) = {:goal} ~ labeled_uniform(goals) +get_choices(simulate(goals_prior, (collect(keys(goals)),))) + # %% [markdown] # ### Discretization From 92199f8b500d3a09cca97cac7ab57dd7fc9342ca Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 30 Oct 2024 09:35:25 -0400 Subject: [PATCH 17/86] Wrap around room point lists --- world_coarse.json | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/world_coarse.json b/world_coarse.json index 4893fca..5115d15 100644 --- a/world_coarse.json +++ b/world_coarse.json @@ -16,6 +16,10 @@ [ 18.9, 0.1 + ], + [ + 13.23, + 0.1 ] ], "2": [ @@ -34,6 +38,10 @@ [ 15.61, 5.79 + ], + [ + 12.23, + 5.79 ] ], "3": [ @@ -52,6 +60,10 @@ [ 15.61, 7.49 + ], + [ + 12.23, + 7.49 ] ], "4": [ @@ -70,6 +82,10 @@ [ 17.01, 9.69 + ], + [ + 8.93, + 9.69 ] ], "5": [ @@ -92,6 +108,10 @@ [ 8.81, 13.57 + ], + [ + 8.81, + 9.69 ] ], "6": [ @@ -110,6 +130,10 @@ [ 5.42, 11.07 + ], + [ + 5.62, + 9.69 ] ], "7": [ @@ -128,6 +152,10 @@ [ 4.53, 9.57 + ], + [ + 5.87, + 9.57 ] ], "8": [ @@ -146,6 +174,10 @@ [ 4.53, 9.07 + ], + [ + 5.87, + 5.79 ] ], "9": [ @@ -164,6 +196,10 @@ [ 4.41, 11.07 + ], + [ + 4.41, + 9.19 ] ], "10": [ @@ -190,6 +226,10 @@ [ 8.8, 13.69 + ], + [ + 1.89, + 13.69 ] ], "11": [ @@ -208,6 +248,10 @@ [ 1.93, 13.57 + ], + [ + 5.31, + 13.57 ] ], "12": [ @@ -227,6 +271,10 @@ 12.71, 11.19 ] + ,[ + 8.93, + 11.19 + ] ], "13": [ [ @@ -244,6 +292,10 @@ [ 15.81, 11.19 + ], + [ + 12.83, + 11.19 ] ], "14": [ @@ -262,6 +314,10 @@ [ 18.9, 11.19 + ], + [ + 15.93, + 11.19 ] ], "15": [ @@ -280,6 +336,10 @@ [ 17.13, 9.69 + ], + [ + 17.13, + 11.07 ] ], "16": [ @@ -298,6 +358,10 @@ [ 15.73, 9.57 + ], + [ + 18.9, + 9.57 ] ] }, From a3320b6bf500bee61f7b81aa7c283cf6622b9e9d Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 30 Oct 2024 12:36:09 -0400 Subject: [PATCH 18/86] Manually simplify sensor distance --- probcomp-localization-tutorial.jl | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 86a6c20..4cc7747 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -172,25 +172,19 @@ plot!(Pose([2., 3.], pi/2.); color=:green4, label="another pose") det2(u, v) = u[1] * v[2] - u[2] * v[1] -# Return unique s, t such that p + s*u == q + t*v. 
-function solve_lines(p, u, q, v; PARALLEL_TOL=1.0e-10) - det = det2(u, v) - if abs(det) < PARALLEL_TOL - return nothing, nothing - else - pq = (p[1] - q[1], p[2] - q[2]) - s = det2(v, pq) / det - t = det2(u, pq) / det - return s, t - end -end +function distance(p :: Pose, seg :: Segment; PARALLEL_TOL=1.0e-10) + # Check if pose is parallel to segment. + det = det2(p.dp, seg.dp) + if abs(det) < PARALLEL_TOL; return Inf end + + # Return unique s, t such that p.p + s * p.dp == seg.p1 + t * seg.dp. + pq = (p.p[1] - seg.p1[1], p.p[2] - seg.p1[2]) + s = det2(seg.dp, pq) / det + t = det2(p.dp, pq) / det -function distance(p :: Pose, seg :: Segment) - s, t = solve_lines(p.p, p.dp, seg.p1, seg.dp) - # Solving failed (including, by fiat, if pose is parallel to segment) iff isnothing(s). - # Pose is oriented away from segment iff s < 0. + # Pose is oriented towards from segment iff s >= 0. # Point of intersection lies on segment (as opposed to the infinite line) iff 0 <= t <= 1. - return (isnothing(s) || s < 0. || !(0. <= t <= 1.)) ? Inf : s + return (s >= 0. && 0. <= t <= 1.) ? s : Inf end; # %% From 20fefd1924197dd029aaced2c5f4034375210bfc Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 30 Oct 2024 12:36:45 -0400 Subject: [PATCH 19/86] Project point onto map discretization --- probcomp-localization-tutorial.jl | 56 +++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 4cc7747..330f60d 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2855,3 +2855,59 @@ for (i, (name, ps)) in enumerate(rooms) end plot!(first.(last.(doorways)), last.(last.(doorways)); seriestype=:scatter, color=:red, label=nothing, markersize=5, markerstrokewidth=1) the_plot + +# %% +# Assumes the polygon is simple (no self-crossings). +# Assumes `polygon` is an ordered list of points with first point repeated at the end. +function point_in_polygon(point, polygon; PARALLEL_TOL=1.0e-10) + # Cast a ray from `point` in x-direction, and return whether ray intersects polygon exactly once. + crossed = 0 + for (s1, s2) in zip(polygon[1:end-1], polygon[2:end]) + det = s2[2] - s1[2] + s1p = (s1[1] - point[1], s1[2] - point[2]) + s2p = (s2[1] - point[1], s2[2] - point[2]) + if abs(det) < PARALLEL_TOL + # If segment is parallel to x-direction, check whether point lies on segment. + if (s1p[1] * s2p[1] <= 0) && (s1p[2] * s2p[2] <= 0); return true end + else + # Otherwise, test whether ray meets segment, + # and increment/decrement count according to crossing orientation. + s = det2(s1p, s2p) / det + t = -s1p[2] / det + if s >= 0. && (0 <= t <= 1.); crossed += (det > 0.) ? 1 : -1 end + end + end + return crossed == 1 || crossed == -1 +end + +function locate(p, rooms, doorways; DOORWAY_RADIUS=1.0) + for (name, ps) in rooms + if point_in_polygon(p, ps); return name end + end + nearest_doorway = argmin(((_, d),) -> sum((p - d).^2), doorways) + return norm(p - nearest_doorway[2]) < DOORWAY_RADIUS ? 
nearest_doorway : nothing +end; + +# %% +DOORWAY_RADIUS=1.0 + +# some_poses = [Pose([uniform(world.bounding_box[1], world.bounding_box[2]), +# uniform(world.bounding_box[3], world.bounding_box[4])], +# uniform(-pi,pi)) +# for _ in 1:20] + +ani = Animation() +for (i, pose) in enumerate(some_poses) + frame_plot = plot_world(world, "Location discretization") + location = locate(pose.p, rooms, doorways; DOORWAY_RADIUS=DOORWAY_RADIUS) + if isnothing(location) + annotate!(world.center_point..., ("Not located", :red)) + elseif isa(location, String) + plot!(first.(rooms[location]), last.(rooms[location]); seriestype=:shape, color=:green3, label="room $location", markersize=3, markerstrokewidth=1, alpha=0.25) + else + plot!(make_circle(location[2], DOORWAY_RADIUS); label="doorway between rooms $(location[1][1]) and $(location[1][2])", seriestype=:shape, alpha=0.25) + end + plot!(pose; label="pose $i", color=:green) + frame(ani, frame_plot) +end +gif(ani, "imgs/discretization.gif", fps=0.5) From 8eb4c6d40e658609c51e69a0e4bcad8b75aee863 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Wed, 30 Oct 2024 12:39:14 -0400 Subject: [PATCH 20/86] Nits --- probcomp-localization-tutorial.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 330f60d..3c5ca7c 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2891,13 +2891,8 @@ end; # %% DOORWAY_RADIUS=1.0 -# some_poses = [Pose([uniform(world.bounding_box[1], world.bounding_box[2]), -# uniform(world.bounding_box[3], world.bounding_box[4])], -# uniform(-pi,pi)) -# for _ in 1:20] - ani = Animation() -for (i, pose) in enumerate(some_poses) +for pose in some_poses frame_plot = plot_world(world, "Location discretization") location = locate(pose.p, rooms, doorways; DOORWAY_RADIUS=DOORWAY_RADIUS) if isnothing(location) @@ -2907,7 +2902,7 @@ for (i, pose) in enumerate(some_poses) else plot!(make_circle(location[2], DOORWAY_RADIUS); label="doorway between rooms $(location[1][1]) and $(location[1][2])", seriestype=:shape, alpha=0.25) end - plot!(pose; label="pose $i", color=:green) + plot!(pose; label="pose", color=:green) frame(ani, frame_plot) end gif(ani, "imgs/discretization.gif", fps=0.5) From c1ca1eabbcfcc8e8f5955562472d5ea03ec1cd9d Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 14:32:29 -0400 Subject: [PATCH 21/86] Nit --- probcomp-localization-tutorial.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 3c5ca7c..76fbf62 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2884,8 +2884,8 @@ function locate(p, rooms, doorways; DOORWAY_RADIUS=1.0) for (name, ps) in rooms if point_in_polygon(p, ps); return name end end - nearest_doorway = argmin(((_, d),) -> sum((p - d).^2), doorways) - return norm(p - nearest_doorway[2]) < DOORWAY_RADIUS ? nearest_doorway : nothing + distance, i = findmin(((_, d),) -> norm(p - d), doorways) + return distance < DOORWAY_RADIUS ? 
doorways[i] : nothing end; # %% From 6d94d5e50963d26aff8a5255f136cfaa2d0752f2 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:24:44 -0400 Subject: [PATCH 22/86] Move up random sample poses --- probcomp-localization-tutorial.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 76fbf62..afee5bc 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -153,9 +153,15 @@ Plots.plot!(p :: Pose; r=0.5, args...) = plot!(Segment(p.p, p.p + r * p.dp); arr Plots.plot!(ps :: Vector{Pose}; args...) = plot_list!(ps; args...); # %% +some_poses = [Pose([uniform(world.bounding_box[1], world.bounding_box[2]), + uniform(world.bounding_box[3], world.bounding_box[4])], + uniform(-pi,pi)) + for _ in 1:20] + plot_world(world, "Given data") -plot!(Pose([1., 1.], 0.); color=:green3, label="a pose") +plot!(Pose([1., 2.], 0.); color=:green3, label="a pose") plot!(Pose([2., 3.], pi/2.); color=:green4, label="another pose") +plot!(some_poses; color=:brown, label="some poses") # %% [markdown] # POSSIBLE VIZ GOAL: user can manipulate a pose. (Unconstrained vs. map for now.) @@ -227,11 +233,6 @@ end; # %% sensor_settings = (fov = 2π*(2/3), num_angles = 41, box_size = world.box_size) -some_poses = [Pose([uniform(world.bounding_box[1], world.bounding_box[2]), - uniform(world.bounding_box[3], world.bounding_box[4])], - uniform(-pi,pi)) - for _ in 1:20] - ani = Animation() for pose in some_poses frame_plot = frame_from_sensors( From 1c2435145e322cded29552d8b43eecf12d988574 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:24:55 -0400 Subject: [PATCH 23/86] Move up `make_circle` --- probcomp-localization-tutorial.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index afee5bc..0154abe 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -94,6 +94,10 @@ world = load_world("world.json"); # It is crucial to picture what we are doing at all times, so we develop plotting code early and often. # %% +unit_circle_xs = [cos(t) for t in LinRange(0., 2pi, 500)] +unit_circle_ys = [sin(t) for t in LinRange(0., 2pi, 500)] +make_circle(p, r) = (p[1] .+ r * unit_circle_xs, p[2] .+ r * unit_circle_ys) + function plot_list!(list; label=nothing, args...) if !isempty(list) plt = plot!(list[1]; label=label, args...) @@ -524,11 +528,6 @@ end; # %% [markdown] # Returning to the code, we can call a GF like a normal function and it will just run stochastically: -# %% -unit_circle_xs = [cos(t) for t in LinRange(0., 2pi, 500)] -unit_circle_ys = [sin(t) for t in LinRange(0., 2pi, 500)] -make_circle(p, r) = (p[1] .+ r * unit_circle_xs, p[2] .+ r * unit_circle_ys); - # %% motion_settings = (p_noise = 0.5, hd_noise = 2π / 360) From 8c02b2894e1c0e0b6f77ccbf0c83af54c5ffb7d3 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:25:03 -0400 Subject: [PATCH 24/86] Vestigial bug --- probcomp-localization-tutorial.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 0154abe..2738578 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -892,11 +892,10 @@ get_selected(get_choices(trace), selection) # By this point, visualization is essential. 
# %% -function frame_from_sensors_trace(world, title, poses, poses_color, poses_label, pose, trace; show_clutters=false) +function frame_from_sensors_trace(world, title, poses, poses_color, poses_label, pose, trace; show=()) readings = [trace[j => :distance] for j in 1:sensor_settings.num_angles] return frame_from_sensors(world, title, poses, poses_color, poses_label, pose, - readings, "trace sensors", get_args(trace)[3]; - show_clutters=show_clutters) + readings, "trace sensors", get_args(trace)[3]; show=show) end function frames_from_full_trace(world, title, trace; show=()) From e1959a4e3f6c9a523644eff40d5f66ae61dcd2ad Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:25:10 -0400 Subject: [PATCH 25/86] Nit --- probcomp-localization-tutorial.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 2738578..aae5b73 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2818,8 +2818,8 @@ tasks, goals = load_goals("goals.json"); # %% the_plot = plot_world(world, "Tasks") -for (task, (p, r)) in tasks - plot!(make_circle(p, r); label=task, seriestype=:shape) +for (task, geom) in tasks + plot!(make_circle(geom...); label=task, seriestype=:shape) end the_plot From 92de90704f5d0b9eae1451f4ac132a4b62d52491 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:25:26 -0400 Subject: [PATCH 26/86] Nit --- probcomp-localization-tutorial.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index aae5b73..4be474e 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2876,7 +2876,7 @@ function point_in_polygon(point, polygon; PARALLEL_TOL=1.0e-10) if s >= 0. && (0 <= t <= 1.); crossed += (det > 0.) ? 
1 : -1 end end end - return crossed == 1 || crossed == -1 + return crossed != 0 end function locate(p, rooms, doorways; DOORWAY_RADIUS=1.0) From 194e47cdefa98661e4790f4bfc2d68b62bddabf9 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:25:53 -0400 Subject: [PATCH 27/86] Rename `locate` to `locate_discrete` --- probcomp-localization-tutorial.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 4be474e..8aa9c48 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2879,7 +2879,7 @@ function point_in_polygon(point, polygon; PARALLEL_TOL=1.0e-10) return crossed != 0 end -function locate(p, rooms, doorways; DOORWAY_RADIUS=1.0) +function locate_discrete(p, rooms, doorways; DOORWAY_RADIUS=1.0) for (name, ps) in rooms if point_in_polygon(p, ps); return name end end @@ -2893,7 +2893,7 @@ DOORWAY_RADIUS=1.0 ani = Animation() for pose in some_poses frame_plot = plot_world(world, "Location discretization") - location = locate(pose.p, rooms, doorways; DOORWAY_RADIUS=DOORWAY_RADIUS) + location = locate_discrete(pose.p, rooms, doorways; DOORWAY_RADIUS=DOORWAY_RADIUS) if isnothing(location) annotate!(world.center_point..., ("Not located", :red)) elseif isa(location, String) From 003124bbf8823fc4fd1e06d9161459740648e062 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:26:01 -0400 Subject: [PATCH 28/86] Nit --- probcomp-localization-tutorial.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 8aa9c48..555cc62 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2883,8 +2883,8 @@ function locate_discrete(p, rooms, doorways; DOORWAY_RADIUS=1.0) for (name, ps) in rooms if point_in_polygon(p, ps); return name end end - distance, i = findmin(((_, d),) -> norm(p - d), doorways) - return distance < DOORWAY_RADIUS ? doorways[i] : nothing + distance, door = findmin(v -> norm(p - v), doorways) + return distance < DOORWAY_RADIUS ? door : nothing end; # %% From 5293d91a2c7c9091865419b78457d62483bd9ac6 Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:26:22 -0400 Subject: [PATCH 29/86] Refactor `doorways` into Dict --- probcomp-localization-tutorial.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 555cc62..c5cc259 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2835,10 +2835,12 @@ get_choices(simulate(goals_prior, (collect(keys(goals)),))) # ### Discretization # %% +sort2(a, b) = (a <= b) ? (a, b) : (b, a) + function load_discretization(file_name) data = parsefile(file_name) rooms = Dict{String, Vector{Vector{Float64}}}(data["rooms"]) - doorways = [(Tuple(sort(doorway["rooms"])), Vector{Float64}(doorway["p"])) for doorway in data["doorways"]] + doorways = Dict{Tuple{String, String}, Vector{Float64}}(sort2(doorway["rooms"]...) 
=> Vector{Float64}(doorway["p"]) for doorway in data["doorways"]) return rooms, doorways end; @@ -2852,7 +2854,7 @@ for (i, (name, ps)) in enumerate(rooms) midpoint = sum(ps)/length(ps) annotate!(midpoint[1], midpoint[2], ("$name", :black)) end -plot!(first.(last.(doorways)), last.(last.(doorways)); seriestype=:scatter, color=:red, label=nothing, markersize=5, markerstrokewidth=1) +plot!(first.(values(doorways)), last.(values(doorways)); seriestype=:scatter, color=:red, label=nothing, markersize=5, markerstrokewidth=1) the_plot # %% @@ -2899,7 +2901,7 @@ for pose in some_poses elseif isa(location, String) plot!(first.(rooms[location]), last.(rooms[location]); seriestype=:shape, color=:green3, label="room $location", markersize=3, markerstrokewidth=1, alpha=0.25) else - plot!(make_circle(location[2], DOORWAY_RADIUS); label="doorway between rooms $(location[1][1]) and $(location[1][2])", seriestype=:shape, alpha=0.25) + plot!(make_circle(doorways[location], DOORWAY_RADIUS); label="doorway between rooms $(location[1]) and $(location[2])", seriestype=:shape, alpha=0.25) end plot!(pose; label="pose", color=:green) frame(ani, frame_plot) From d8689b3470d42a284edf239fa17d3355fe0f38bd Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:26:39 -0400 Subject: [PATCH 30/86] Change framerate --- probcomp-localization-tutorial.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index c5cc259..def2031 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2906,4 +2906,4 @@ for pose in some_poses plot!(pose; label="pose", color=:green) frame(ani, frame_plot) end -gif(ani, "imgs/discretization.gif", fps=0.5) +gif(ani, "imgs/discretization.gif", fps=1) From 35df1f82c0cb52907ef47cc17efbd8358216edef Mon Sep 17 00:00:00 2001 From: Jay Pottharst Date: Fri, 1 Nov 2024 17:27:36 -0400 Subject: [PATCH 31/86] Add coarse planning --- probcomp-localization-tutorial.jl | 77 +++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index def2031..3feaf1d 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -2907,3 +2907,80 @@ for pose in some_poses frame(ani, frame_plot) end gif(ani, "imgs/discretization.gif", fps=1) + +# %% [markdown] +# ### Coarse planning + +# %% +task_locations = Dict( + task => locate_discrete(p, rooms, doorways) + for (task, (p, _)) in tasks) + +# %% +function location_to_room(location, destination_room, rooms, doorways) + if isnothing(location) + return nothing + elseif isa(location, String) + # Breadth-first search for path + paths = [[location]] + visited = Set((location,)) + while !isempty(paths) + found = findfirst(path -> path[end] == destination_room, paths) + if !isnothing(found); return paths[found] end + new_paths = [] + for new_room in filter(!in(visited), keys(rooms)) + for path in filter(path -> sort2(path[end], new_room) in keys(doorways), paths) + push!(visited, new_room) + push!(new_paths, vcat(path, new_room)) + end + end + paths = new_paths + end + return nothing + else + paths = [location_to_room(branch, destination_room, rooms, doorways) for branch in location] + return isnothing(paths[1]) && isnothing(paths[2]) ? nothing : + isnothing(paths[1]) ? vcat(location, paths[2]) : + isnothing(paths[2]) ? vcat(location, paths[1]) : + length(paths[1]) <= length(paths[2]) ? 
+ vcat(location, paths[1]) : vcat(location, paths[2]) + end +end; + +# %% +location_to_room(task_locations["task1"], task_locations["task2"], rooms, doorways) + +# %% +ani = Animation() +for pose in some_poses + location = locate_discrete(pose.p, rooms, doorways) + if isnothing(location) + frame_plot = plot_world(world, "Location discretization") + annotate!(world.center_point..., ("Not located", :red)) + plot!(pose; label="pose", color=:green) + frame(ani, frame_plot) + else + for task in keys(tasks) + frame_plot = plot_world(world, "Location discretization") + path = location_to_room(location, task_locations[task], rooms, doorways) + if isnothing(path) + annotate!(world.center_point..., ("Routing fail", :red)) + else + # annotate!(world.center_point..., ("$path", :blue)) + for node in path + if isa(node, String) + plot!(first.(rooms[node]), last.(rooms[node]); seriestype=:shape, color=:green3, label="room $node", markersize=3, markerstrokewidth=1, alpha=0.25) + else + plot!([doorways[node][1]], [doorways[node][2]]; seriestype=:scatter, color=:red, label="doorway", markersize=5, markerstrokewidth=1) + end + end + end + plot!(make_circle(tasks[task]...); label=task, seriestype=:shape) + plot!(pose; label="pose", color=:green) + frame(ani, frame_plot) + end + end +end +gif(ani, "imgs/discretization.gif", fps=1) + +# %% From 484f2dbaa893496baae1fc32f7683691102e4e08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 21:33:35 +0000 Subject: [PATCH 32/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- goals.json | 2 +- probcomp-localization-tutorial.jl | 2 +- world.json | 2 +- world_coarse.json | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/goals.json b/goals.json index fe77799..5ef2ee0 100644 --- a/goals.json +++ b/goals.json @@ -43,4 +43,4 @@ "task4": [] } } -} \ No newline at end of file +} diff --git a/probcomp-localization-tutorial.jl b/probcomp-localization-tutorial.jl index 3feaf1d..b96a903 100644 --- a/probcomp-localization-tutorial.jl +++ b/probcomp-localization-tutorial.jl @@ -42,7 +42,7 @@ mkpath("imgs"); # # ### The map # -# The tutorial will revolve around modeling the activity of a robot within some space. A large simplifying assumption, which could be lifted with more effort, is that we have been given a *map* of the space, to which the robot will have access. +# The tutorial will revolve around modeling the activity of a robot within some space. A large simplifying assumption, which could be lifted with more effort, is that we have been given a *map* of the space, to which the robot will have access. # # The code below loads such a map, along with other data for later use. Generally speaking, we keep general code and specific examples in separate cells, as signposted here. 
diff --git a/world.json b/world.json index 4b19733..a069224 100644 --- a/world.json +++ b/world.json @@ -931,4 +931,4 @@ ] ] ] -} \ No newline at end of file +} diff --git a/world_coarse.json b/world_coarse.json index 5115d15..d23b6a2 100644 --- a/world_coarse.json +++ b/world_coarse.json @@ -527,4 +527,4 @@ "r": 1.0 } ] -} \ No newline at end of file +} From a7fc36f9cc9e3e8711678b11c3402a151ec7540f Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Sun, 26 Jan 2025 21:18:18 -0800 Subject: [PATCH 33/86] Sequential Importance Sampling (SIS) for localization (#21) --- .../probcomp-localization-tutorial.ipynb | 2735 +++++++++++++++++ .../probcomp-localization-tutorial.py | 735 +++-- poetry.lock | 2022 ++++++------ pyproject.toml | 4 +- 4 files changed, 4255 insertions(+), 1241 deletions(-) create mode 100644 genjax-localization-tutorial/probcomp-localization-tutorial.ipynb diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb new file mode 100644 index 0000000..476fc5f --- /dev/null +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb @@ -0,0 +1,2735 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# pyright: reportUnusedExpression=false" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# import sys\n", + "\n", + "# if \"google.colab\" in sys.modules:\n", + "# from google.colab import auth # pyright: ignore [reportMissingImports]\n", + "\n", + "# auth.authenticate_user()\n", + "# %pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa\n", + "# %pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "# ProbComp Localization Tutorial\n", + "\n", + "This notebook provides an introduction to probabilistic computation (ProbComp). This term refers to a way of expressing probabilistic constructs in a computational paradigm, made precise by a probabilistic programming language (PPL). The programmer can encode their probabilistic intuition for solving a problem into an algorithm. Back-end language work automates the routine but error-prone derivations.\n", + "\n", + "Dependencies are specified in pyproject.toml." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "# Global setup code\n", + "\n", + "import json\n", + "import genstudio.plot as Plot\n", + "import itertools\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import genjax\n", + "from urllib.request import urlopen\n", + "from genjax import SelectionBuilder as S\n", + "from genjax import ChoiceMapBuilder as C\n", + "from genjax.typing import FloatArray, PRNGKey\n", + "from penzai import pz\n", + "from typing import Any, Iterable\n", + "\n", + "import os\n", + "\n", + "html = Plot.Hiccup\n", + "Plot.configure({\"display_as\": \"html\", \"dev\": False})\n", + "\n", + "# Ensure a location for image generation.\n", + "os.makedirs(\"imgs\", exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## The \"real world\"\n", + "\n", + "We assume given\n", + "* a map of a space, together with\n", + "* some clutters that sometimes unexpectedly exist in that space.\n", + "\n", + "We also assume given a description of a robot's behavior via\n", + "* an estimated initial pose (position + heading), and\n", + "* a program of controls (advance distance, followed by rotate heading).\n", + "\n", + "*In addition to the uncertainty in the initial pose, we are uncertain about the true execution of the motion of the robot.*\n", + "\n", + "Below, we will also introduce sensors." + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "### Load map and robot data\n", + "\n", + "Generally speaking, we keep general code and specific examples in separate cells, as signposted here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "# General code here\n", + "\n", + "\n", + "@pz.pytree_dataclass\n", + "class Pose(genjax.PythonicPytree):\n", + " p: FloatArray\n", + " hd: FloatArray\n", + "\n", + " def __repr__(self):\n", + " return f\"Pose(p={self.p}, hd={self.hd})\"\n", + "\n", + " def dp(self):\n", + " return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)])\n", + "\n", + " def step_along(self, s: float) -> \"Pose\":\n", + " \"\"\"\n", + " Moves along the direction of the pose by a scalar and returns a new Pose.\n", + "\n", + " Args:\n", + " s (float): The scalar distance to move along the pose's direction.\n", + "\n", + " Returns:\n", + " Pose: A new Pose object representing the moved position.\n", + " \"\"\"\n", + " dp = self.dp()\n", + " new_p = self.p + s * dp\n", + " return Pose(new_p, hd=self.hd)\n", + "\n", + " def apply_control(self, control):\n", + " return Pose(self.p + control.ds * self.dp(), self.hd + control.dhd)\n", + "\n", + " def rotate(self, a: float) -> \"Pose\":\n", + " \"\"\"\n", + " Rotates the pose by angle 'a' (in radians) and returns a new Pose.\n", + "\n", + " Args:\n", + " a (float): The angle in radians to rotate the pose.\n", + "\n", + " Returns:\n", + " Pose: A new Pose object representing the rotated pose.\n", + " \"\"\"\n", + " new_hd = self.hd + a\n", + " return Pose(self.p, hd=new_hd)\n", + "\n", + "\n", + "@pz.pytree_dataclass\n", + "class Control(genjax.PythonicPytree):\n", + " ds: FloatArray\n", + " dhd: FloatArray\n", + "\n", + "\n", + "def create_segments(points):\n", + " \"\"\"\n", + " Given an array of points of shape (N, 2), return an array of\n", + " pairs of points. [p_1, p_2, p_3, ...] 
-> [[p_1, p_2], [p_2, p_3], ...]\n", + " where each p_i is [x_i, y_i]\n", + " \"\"\"\n", + " return jnp.stack([points, jnp.roll(points, shift=-1, axis=0)], axis=1)\n", + "\n", + "\n", + "def make_world(wall_verts, clutters_vec, start, controls):\n", + " \"\"\"\n", + " Constructs the world by creating segments for walls and clutters, calculates the bounding box, and prepares the simulation parameters.\n", + "\n", + " Args:\n", + " - wall_verts (list of list of float): A list of 2D points representing the vertices of walls.\n", + " - clutters_vec (list of list of list of float): A list where each element is a list of 2D points representing the vertices of a clutter.\n", + " - start (Pose): The starting pose of the robot.\n", + " - controls (list of Control): Control actions for the robot.\n", + "\n", + " Returns:\n", + " - tuple: A tuple containing the world configuration, the initial state, and the total number of control steps.\n", + " \"\"\"\n", + " # Create segments for walls and clutters\n", + " walls = create_segments(wall_verts)\n", + " clutters = jax.vmap(create_segments)(clutters_vec)\n", + "\n", + " # Combine all points for bounding box calculation\n", + " all_points = jnp.vstack(\n", + " (jnp.array(wall_verts), jnp.concatenate(clutters_vec), jnp.array([start.p]))\n", + " )\n", + " x_min, y_min = jnp.min(all_points, axis=0)\n", + " x_max, y_max = jnp.max(all_points, axis=0)\n", + "\n", + " # Calculate bounding box, box size, and center point\n", + " bounding_box = (x_min, x_max, y_min, y_max)\n", + " box_size = max(x_max - x_min, y_max - y_min)\n", + " center_point = Pose(\n", + " jnp.array([(x_min + x_max) / 2.0, (y_min + y_max) / 2.0]), jnp.array(0.0)\n", + " )\n", + "\n", + " # How bouncy the walls are in this world.\n", + " bounce = 0.1\n", + "\n", + " # We prepend a zero-effect control step to the control array. 
This allows\n", + " # numerous simplifications in what follows: we can consider the initial\n", + " # pose uncertainty as well as each subsequent step to be the same function\n", + " # of current position and control step.\n", + " noop_control = Control(jnp.array(0.0), jnp.array(0.0))\n", + " controls = controls.prepend(noop_control)\n", + "\n", + " # Determine the total number of control steps\n", + " T = len(controls.ds)\n", + "\n", + " return (\n", + " {\n", + " \"walls\": walls,\n", + " \"wall_verts\": wall_verts,\n", + " \"clutters\": clutters,\n", + " \"bounding_box\": bounding_box,\n", + " \"box_size\": box_size,\n", + " \"center_point\": center_point,\n", + " \"bounce\": bounce,\n", + " },\n", + " {\"start\": start, \"controls\": controls},\n", + " T,\n", + " )\n", + "\n", + "\n", + "def load_world(file_name):\n", + " \"\"\"\n", + " Loads the world configuration from a specified file and constructs the world.\n", + "\n", + " Args:\n", + " - file_name (str): The name of the file containing the world configuration.\n", + "\n", + " Returns:\n", + " - tuple: A tuple containing the world configuration, the initial state, and the total number of control steps.\n", + " \"\"\"\n", + " with urlopen(\n", + " \"https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json\"\n", + " ) as url:\n", + " data = json.load(url)\n", + "\n", + " walls_vec = jnp.array(data[\"wall_verts\"])\n", + " clutters_vec = jnp.array(data[\"clutter_vert_groups\"])\n", + " start = Pose(\n", + " jnp.array(data[\"start_pose\"][\"p\"], dtype=float),\n", + " jnp.array(data[\"start_pose\"][\"hd\"], dtype=float),\n", + " )\n", + "\n", + " cs = jnp.array([[c[\"ds\"], c[\"dhd\"]] for c in data[\"program_controls\"]])\n", + " controls = Control(cs[:, 0], cs[:, 1])\n", + "\n", + " return make_world(walls_vec, clutters_vec, start, controls)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "# Specific example code here\n", + "\n", + "world, robot_inputs, T = load_world(\"../example_20_program.json\")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "### Integrate a path from a starting pose and controls\n", + "\n", + "If the motion of the robot is determined in an ideal manner by the controls, then we may simply integrate to determine the resulting path. Naïvely, this results in the following." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "def integrate_controls_unphysical(robot_inputs):\n", + " \"\"\"\n", + " Integrates the controls to generate a path from the starting pose.\n", + "\n", + " This function takes the initial pose and a series of control steps (ds for distance, dhd for heading change)\n", + " and computes the resulting path by applying each control step sequentially.\n", + "\n", + " Args:\n", + " - robot_inputs (dict): A dictionary containing the starting pose and control steps.\n", + "\n", + " Returns:\n", + " - list: A list of Pose instances representing the path taken by applying the controls.\n", + " \"\"\"\n", + " return jax.lax.scan(\n", + " lambda pose, control: (\n", + " pose.apply_control(control),\n", + " pose.apply_control(control),\n", + " ),\n", + " robot_inputs[\"start\"],\n", + " robot_inputs[\"controls\"],\n", + " )[1]" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "This code has the problem that it is **unphysical**: the walls in no way constrain the robot motion.\n", + "\n", + "We employ the following simple physics: when the robot's forward step through a control comes into contact with a wall, that step is interrupted and the robot instead \"bounces\" a fixed distance from the point of contact in the normal direction to the wall." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [ + "def solve_lines(p, u, q, v, PARALLEL_TOL=1.0e-10):\n", + " \"\"\"\n", + " Solves for the intersection of two lines defined by points and direction vectors.\n", + "\n", + " Args:\n", + " - p, u: Point and direction vector defining the first line.\n", + " - q, v: Point and direction vector defining the second line.\n", + " - PARALLEL_TOL: Tolerance for determining if lines are parallel.\n", + "\n", + " Returns:\n", + " - s, t: Parameters for the line equations at the intersection point.\n", + " Returns [-inf, -inf] if lines are parallel.\n", + " \"\"\"\n", + " det = u[0] * v[1] - u[1] * v[0]\n", + " return jnp.where(\n", + " jnp.abs(det) < PARALLEL_TOL,\n", + " jnp.array([-jnp.inf, -jnp.inf]),\n", + " jnp.array(\n", + " [\n", + " (v[0] * (p[1] - q[1]) - v[1] * (p[0] - q[0])) / det,\n", + " (u[1] * (q[0] - p[0]) - u[0] * (q[1] - p[1])) / det,\n", + " ]\n", + " ),\n", + " )\n", + "\n", + "\n", + "def distance(p, seg):\n", + " \"\"\"\n", + " Computes the distance from a pose to a segment, considering the pose's direction.\n", + "\n", + " Args:\n", + " - p: The Pose object.\n", + " - seg: The Segment object.\n", + "\n", + " Returns:\n", + " - float: The distance to the segment. 
Returns infinity if no valid intersection is found.\n", + " \"\"\"\n", + " a = solve_lines(p.p, p.dp(), seg[0], seg[1] - seg[0])\n", + " return jnp.where(\n", + " (a[0] >= 0.0) & (a[1] >= 0.0) & (a[1] <= 1.0),\n", + " a[0],\n", + " jnp.inf,\n", + " )\n", + "\n", + "\n", + "def compute_wall_normal(wall_direction) -> FloatArray:\n", + " normalized_wall_direction = wall_direction / jnp.linalg.norm(wall_direction)\n", + " return jnp.array([-normalized_wall_direction[1], normalized_wall_direction[0]])\n", + "\n", + "\n", + "@jax.jit\n", + "def physical_step(p1: FloatArray, p2: FloatArray, hd):\n", + " \"\"\"\n", + " Computes a physical step considering wall collisions and bounces.\n", + "\n", + " Args:\n", + " - p1, p2: Start and end points of the step.\n", + " - hd: Heading direction.\n", + "\n", + " Returns:\n", + " - Pose: The new pose after taking the step, considering potential wall collisions.\n", + " \"\"\"\n", + " # Calculate step direction and length\n", + " step_direction = p2 - p1\n", + " step_length = jnp.linalg.norm(step_direction)\n", + " step_pose = Pose(p1, jnp.arctan2(step_direction[1], step_direction[0]))\n", + "\n", + " # Calculate distances to all walls\n", + " distances = jax.vmap(distance, in_axes=(None, 0))(step_pose, world[\"walls\"])\n", + "\n", + " # Find the closest wall\n", + " closest_wall_index = jnp.argmin(distances)\n", + " closest_wall_distance = distances[closest_wall_index]\n", + " closest_wall = world[\"walls\"][closest_wall_index]\n", + "\n", + " # Calculate wall normal and collision point\n", + " wall_direction = closest_wall[1] - closest_wall[0]\n", + " wall_normal = compute_wall_normal(wall_direction)\n", + " collision_point = p1 + closest_wall_distance * step_pose.dp()\n", + "\n", + " # Ensure wall_normal points away from the robot's direction\n", + " wall_normal = jnp.where(\n", + " jnp.dot(step_pose.dp(), wall_normal) > 0, -wall_normal, wall_normal\n", + " )\n", + "\n", + " # Calculate bounce off point\n", + " bounce_off_point: FloatArray = collision_point + world[\"bounce\"] * wall_normal\n", + "\n", + " # Determine final position based on whether a collision occurred\n", + " final_position = jnp.where(\n", + " closest_wall_distance >= step_length, p2, bounce_off_point\n", + " )\n", + "\n", + " return Pose(final_position, hd)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "def integrate_controls_physical(robot_inputs):\n", + " \"\"\"\n", + " Integrates controls to generate a path, taking into account physical interactions with walls.\n", + "\n", + " Args:\n", + " - robot_inputs: Dictionary containing the starting pose and control steps.\n", + "\n", + " Returns:\n", + " - Pose: A Pose object representing the path taken by applying the controls.\n", + " \"\"\"\n", + " return jax.lax.scan(\n", + " lambda pose, control: (\n", + " new_pose := physical_step(\n", + " pose.p, pose.p + control.ds * pose.dp(), pose.hd + control.dhd\n", + " ),\n", + " new_pose,\n", + " ),\n", + " robot_inputs[\"start\"],\n", + " robot_inputs[\"controls\"],\n", + " )[1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "path_integrated = integrate_controls_physical(robot_inputs)" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "### Plot such data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": 
[], + "source": [ + "def pose_plot(p, fill: str | Any = \"black\", **opts):\n", + " z = opts.get(\"zoom\", 1.0)\n", + " r = z * 0.15\n", + " wing_opacity = opts.get(\"opacity\", 0.3)\n", + " WING_ANGLE, WING_LENGTH = jnp.pi / 12, z * opts.get(\"wing_length\", 0.6)\n", + " center = p.p\n", + " angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1])\n", + "\n", + " # Calculate wing endpoints\n", + " wing_ends = [\n", + " center - WING_LENGTH * jnp.array([jnp.cos(angle + a), jnp.sin(angle + a)])\n", + " for a in [WING_ANGLE, -WING_ANGLE]\n", + " ]\n", + "\n", + " # Draw wings\n", + " wings = Plot.line(\n", + " [wing_ends[0], center, wing_ends[1]],\n", + " strokeWidth=opts.get(\"strokeWidth\", 2),\n", + " stroke=fill,\n", + " opacity=wing_opacity,\n", + " )\n", + "\n", + " # Draw center dot\n", + " dot = Plot.ellipse([center], fill=fill, **({\"r\": r} | opts))\n", + "\n", + " return wings + dot\n", + "\n", + "\n", + "walls_plot = Plot.new(\n", + " Plot.line(\n", + " world[\"wall_verts\"],\n", + " strokeWidth=2,\n", + " stroke=\"#ccc\",\n", + " ),\n", + " {\"margin\": 0, \"inset\": 50, \"width\": 500, \"axis\": None, \"aspectRatio\": 1},\n", + " Plot.domain([0, 20]),\n", + ")\n", + "# Plot the world with walls only\n", + "world_plot = Plot.new(\n", + " walls_plot, Plot.frame(strokeWidth=4, stroke=\"#ddd\"), Plot.color_legend()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot of the starting pose of the robot\n", + "starting_pose_plot = pose_plot(\n", + " robot_inputs[\"start\"],\n", + " fill=Plot.constantly(\"given start pose\"),\n", + ") + Plot.color_map({\"given start pose\": \"blue\"})\n", + "\n", + "# Plot of the path from integrating controls\n", + "controls_path_plot = Plot.dot(\n", + " [pose.p for pose in path_integrated],\n", + " fill=Plot.constantly(\"path from integrating controls\"),\n", + ") + Plot.color_map({\"path from integrating controls\": \"#0c0\"})\n", + "\n", + "# Plot of the clutters\n", + "clutters_plot = (\n", + " [Plot.line(c[:, 0], fill=Plot.constantly(\"clutters\")) for c in world[\"clutters\"]],\n", + " Plot.color_map({\"clutters\": \"magenta\"}),\n", + ")\n", + "\n", + "(\n", + " world_plot\n", + " + controls_path_plot\n", + " + starting_pose_plot\n", + " + clutters_plot\n", + " + {\"title\": \"Given Data\"}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "\n", + "TODO(jay): Include code visualization" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "## Gen basics\n", + "\n", + "As said initially, we are uncertain about the true initial position and subsequent motion of the robot. In order to reason about these, we now specify a model using `Gen`.\n", + "\n", + "Each piece of the model is declared as a *generative function* (GF). The `GenJAX` library provides a DSL for constructing GFs signalled by the use of the `@genjax.gen` decorator on an ordinary Python function. As we shall see, in order for the functions we write to be compilable for a GPU, there are certain constraints we must follow in the use of control flow, which we will discuss soon.\n", + "\n", + "\n", + "The library offers two basic constructs for use within the DSL: primitive *distributions* such as \"bernoulli\" and \"normal\", and the *sampling operator* `@`. Recursively, GFs may sample from other GFs using `@`." 
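As a quick, self-contained illustration of these constructs (a hypothetical toy model added for clarity, not part of the tutorial's robot code), a GF samples primitive distributions with `@`, and later choices may depend on earlier ones:

```python
@genjax.gen
def toy_model(mu):
    x = genjax.normal(mu, 1.0) @ "x"   # a primitive distribution, sampled via `@`
    y = genjax.normal(x, 0.1) @ "y"    # a later choice depending on an earlier one
    return y
```

Defining `toy_model` is all we do here; running such a stochastic function requires a source of randomness, as explained shortly below.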
+ ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "### Components of the motion model\n", + "\n", + "We start with the two building blocks: the starting pose and individual steps of motion." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO(colin,jay): Originally, we passed motion_settings['p_noise'] ** 2 to\n", + "# mv_normal_diag, but I think this squares the scale twice. TFP documenentation\n", + "# - https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalDiag\n", + "# states that: scale = diag(scale_diag); covariance = scale @ scale.T. The second\n", + "# equation will have the effect of squaring the individual diagonal scales.\n", + "\n", + "\n", + "@genjax.gen\n", + "def step_proposal(motion_settings, start, control):\n", + " p = (\n", + " genjax.mv_normal_diag(\n", + " start.p + control.ds * start.dp(), motion_settings[\"p_noise\"] * jnp.ones(2)\n", + " )\n", + " @ \"p\"\n", + " )\n", + " hd = genjax.normal(start.hd + control.dhd, motion_settings[\"hd_noise\"]) @ \"hd\"\n", + " return physical_step(start.p, p, hd)\n", + "\n", + "\n", + "# Set the motion settings\n", + "default_motion_settings = {\"p_noise\": 0.5, \"hd_noise\": 2 * jnp.pi / 36.0}" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "Returning to the code: we find that our function cannot be called directly--it is now a stochastic function!--so we must supply a source of randomness, in the form of a *key*, followed by a tuple of the function's expected arguments, illustrated here:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(0)\n", + "step_proposal.simulate(\n", + " key, (default_motion_settings, robot_inputs[\"start\"], robot_inputs[\"controls\"][0])\n", + ").get_retval()" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "\n", + "We called `get_retval()` on the result, which is a *trace*, a data structure with which we become much more familiar before we are done." 
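An aside on the TODO comment in the `step_proposal` cell above: the scale-versus-covariance claim there is easy to check numerically. The sketch below assumes the TensorFlow Probability JAX substrate (which GenJAX's distribution primitives wrap) is importable in this environment:

```python
import tensorflow_probability.substrates.jax as tfp

d = tfp.distributions.MultivariateNormalDiag(
    loc=jnp.zeros(2), scale_diag=0.5 * jnp.ones(2)
)
# The covariance is diag(0.25, 0.25): `scale_diag` is squared once internally,
# so passing `p_noise ** 2` as the scale would square the noise twice.
d.covariance()
```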
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate points on the unit circle\n", + "theta = jnp.linspace(0, 2 * jnp.pi, 500)\n", + "unit_circle_xs = jnp.cos(theta)\n", + "unit_circle_ys = jnp.sin(theta)\n", + "\n", + "\n", + "# Function to create a circle with center p and radius r\n", + "def make_circle(p, r):\n", + " return (p[0] + r * unit_circle_xs, p[1] + r * unit_circle_ys)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate N_samples of starting poses from the prior\n", + "N_samples = 50\n", + "key, sub_key = jax.random.split(key)\n", + "pose_samples = jax.vmap(step_proposal.simulate, in_axes=(0, None))(\n", + " jax.random.split(sub_key, N_samples),\n", + " (default_motion_settings, robot_inputs[\"start\"], robot_inputs[\"controls\"][0]),\n", + ")\n", + "\n", + "\n", + "def poses_to_plots(poses: Iterable[Pose], **plot_opts):\n", + " return [pose_plot(pose, **plot_opts) for pose in poses]\n", + "\n", + "\n", + "# Plot the world, starting pose samples, and 95% confidence region\n", + "# Calculate the radius of the 95% confidence region\n", + "def confidence_circle(pose: Pose, p_noise: float):\n", + " return Plot.scaled_circle(\n", + " *pose.p,\n", + " fill=Plot.constantly(\"95% confidence region\"),\n", + " r=2.5 * p_noise,\n", + " ) + Plot.color_map({\"95% confidence region\": \"rgba(255,0,0,0.25)\"})\n", + "\n", + "\n", + "(\n", + " world_plot\n", + " + poses_to_plots([robot_inputs[\"start\"]], fill=Plot.constantly(\"step from here\"))\n", + " + confidence_circle(\n", + " robot_inputs[\"start\"].apply_control(robot_inputs[\"controls\"][0]),\n", + " default_motion_settings[\"p_noise\"],\n", + " )\n", + " + poses_to_plots(pose_samples.get_retval(), fill=Plot.constantly(\"step samples\"))\n", + " + Plot.color_map({\"step from here\": \"#000\", \"step samples\": \"red\"})\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "26", + "metadata": {}, + "source": [ + "### Traces: choice maps\n", + "\n", + "The actual return value of `step_model.simulate` is a *trace*, which records certain information obtained during execution of the function.\n", + "\n", + "The foremost information stored in the trace is the *choice map*, which is tree of labels mapping to the corresponding stochastic choices, i.e. occurrences of the `@` operator, that were encountered. It is accessed by `get_choices`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# `simulate` takes the GF plus a tuple of args to pass to it.\n", + "key, sub_key = jax.random.split(key)\n", + "trace = step_proposal.simulate(\n", + " sub_key,\n", + " (default_motion_settings, robot_inputs[\"start\"], robot_inputs[\"controls\"][0]),\n", + ")\n", + "trace.get_choices()" + ] + }, + { + "cell_type": "markdown", + "id": "28", + "metadata": {}, + "source": [ + "The choice map being the point of focus of the trace in most discussions, we often abusively just speak of a *trace* when we really mean its *choice map*." 
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "29",
+ "metadata": {},
+ "source": [
+ "### GenJAX API for traces\n",
+ "\n",
+ "One can access the primitive choices in a trace using the method `get_choices`.\n",
+ "One can access from a trace the GF that produced it using `trace.get_gen_fn()`, along with the arguments that were supplied using `trace.get_args()`, and the return value sample of the GF using the method `get_retval()`. See below the fold for examples of all these."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "30",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "pose_choices = trace.get_choices()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "31",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "pose_choices[\"hd\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "32",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "pose_choices[\"p\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "33",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "trace.get_gen_fn()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "34",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "trace.get_args()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "35",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "trace.get_retval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "36",
+ "metadata": {},
+ "source": [
+ "### Traces: scores/weights/densities\n",
+ "\n",
+ "Traced execution of a generative function also produces a particular kind of score/weight/density. It is very important to be clear about which score/weight/density value is to be expected, and why. Consider the following generative function\n",
+ "```\n",
+ "p = 0.25\n",
+ "@genjax.gen\n",
+ "def g(x,y):\n",
+ " flip = genjax.flip(p) @ 'flip'\n",
+ " return jax.lax.select(flip, x, y)\n",
+ "```\n",
+ "that, given two inputs `x` and `y`, flips a coin with weight `p`, and accordingly returns `x` or `y`. When `x` and `y` are unequal, a sensible reporting of the score/weight/density in the sampling process would produce `p` or `1.0-p` accordingly. If the user supplied equal values `x == y`, then which score/weight/density should be returned?\n",
+ "\n",
+ "One tempting view identifies a GF with a *distribution over its return values*. In this view, the correct score/weight/density of `g` above would be $1$.\n",
+ "\n",
+ "The mathematical picture would be as follows. Given a stochastic function $g$ from $X$ to $X'$, the results of calling $g$ on the input $x$ are described by a probability distribution $k_{g;x}$ on $X'$. A family of probability distributions of this form is called a *probability kernel* and is indicated by the dashed arrow $k_g \\colon X \\dashrightarrow X'$. And for some $x,x'$ we would be seeking the density $k_{g;x}(x')$ with which the sample $x' \\sim k_{g;x}$ occurs. Pursuing this approach requires knowledge of all execution histories that $g$ might have followed from $x$ to $x'$, and then performing a sum or integral over them.
For some small finite situations this may be fine, but this general problem of computing marginalizations is computationally impossible.\n", + "\n", + "The marginalization question is especially forced upon us when trying to compose stochastic functions. Given a second stochastic function $g'$ from $X'$ to $X''$, corresponding to a probability kernel $k_{g'} \\colon X' \\dashrightarrow X''$, the composite $g' \\circ g$ from $X$ to $X''$ should correspond to the following probability kernel $k_{g' \\circ g} \\colon X \\dashrightarrow X''$. To sample $x'' \\sim k_{g' \\circ g;x}$ means \"sample $x' \\sim k_{g;x}$, then sample $x'' \\sim k_{g';x'}$, then return $x''$\". However, computing the density $k_{g' \\circ g;x}(x'')$, even if one can compute $k_{g;x}(x')$ and $k_{g';x'}(x'')$ for any given $x,x',x''$, would require summing or integrating over all possible intermediate values $x'$ (which manifests an \"execution history\" of $g' \\circ g$) that could have intervened in producing $x''$ given $x$.\n", + "\n", + "Therefore, distribution-over-return-values is ***not the viewpoint of Gen***, and the score/weight/density being introduced here is a ***different number***.\n", + "\n", + "The only thing a program can reasonably be expected to know is the score/weight/density of its arriving at its return value *via the particular stochastic computation path* that got it there, and the approach of Gen is to report this number. The corresponding mathematical picture imagines GFs as factored into *distributions over choice maps*, whose score/weight/density is computable, together with *deterministic functions on these data* that produce the return value from them. In mathematical language, a GF $g$ from $X$ to $X'$ corresponds to the data of an auxiliary space $U_g$ containing all of the choice map information, a probability kernel $k_g \\colon X \\dashrightarrow U_g$ (with computable density) embodying the stochastic execution history, and a deterministic function that we will (somewhat abusively) denote $g \\colon X \\times U_g \\to X'$ embodying extraction of the return value from the particular stochastic execution choices.\n", + "\n", + "In the toy example `g` above, choice map consists of `flip` so the space $U_g$ is binary; the deterministic computation $g$ amounts to the `return` statement; and the score/weight/density is `p` or `1.0-p`, regardless of whether the inputs are equal.\n", + "\n", + "Tractable compositionality holds in this formulation; let's spell it out. If another GF $g'$ from $X'$ to $X''$ has data $U_{g'}$, $k_{g'}$, and $g' \\colon X' \\times U_{g'} \\to X''$, then the composite GF $g' \\circ g$ from $X$ to $X''$ has the following data.\n", + "* The auxiliary space is $U_{g' \\circ g} := U_g \\times U_{g'}$.\n", + "* The kernel $k_{g' \\circ g}$ is defined by \"sample $u \\sim k_{g;x}$, then compute $x' = \\text{ret}_g(x,u)$, then sample $u' \\sim k_{g';x'}$, then return $(u,u')$\", and\n", + "* its density is computed via $k_{g' \\circ g; x}(u,u') := k_{g;x}(u) \\cdot k_{g';g(x,u)}(u')$.\n", + "* The return value function is $(g' \\circ g)(x,(u,u')) := g'(g(x,u),u')$.\n", + "\n", + "As one composes more GFs, the auxiliary space accumulates more factors $U$, reflecting how the \"execution history\" consists of longer and longer records.\n", + "\n", + "In this picture, one may still be concerned with the distribution on return values as in the straw man viewpoint. 
This information is still embodied in the aggregate of the stochastic executions that lead to any return value, together with their weights. (Consider that this is true even in the toy example: when `x == y`, the return value is certain, yet the two execution histories that produce it still carry weights `p` and `1.0-p`.) In a sense, when we kick the can of marginalization down the road, we can proceed without difficulty.\n",
+ "\n",
+ "A final caveat: The common practice of confusing traces with their choice maps continues here, and we speak of a GF inducing a \"distribution over traces\"."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "37",
+ "metadata": {},
+ "source": [
+ "Let's have a look at the score/weight/densities in our running example.\n",
+ "\n",
+ "A pose consists of a pair $z = (z_\\text p, z_\\text{hd})$ where $z_\\text p$ is a position vector and $z_\\text{hd}$ is an angle. A control consists of a pair $(s, \\eta)$ where $s$ is a distance of displacement and $\\eta$ is a change in angle. Write $u(\\theta) = (\\cos\\theta, \\sin\\theta)$ for the unit vector in the direction $\\theta$. We are given a \"world\" $w$ and \"motion settings\" parameters $\\nu = (\\nu_\\text p, \\nu_\\text{hd})$.\n",
+ "\n",
+ "The starting-pose model and `step_proposal` correspond to distributions over their traces, respectively written $\\text{start}$ and $\\text{step}$. In both cases these traces consist of the choices at addresses `p` and `hd`, so they may be identified with poses $z$ as above. The distributions are defined as follows, when $y$ is a pose:\n",
+ "* $z \\sim \\text{start}(y, \\nu)$ means that $z_\\text p \\sim \\text{mvnormal}(y_\\text p, \\nu_\\text p^2 I)$ and $z_\\text{hd} \\sim \\text{normal}(y_\\text{hd}, \\nu_\\text{hd})$ independently.\n",
+ "* $z \\sim \\text{step}(y, (s, \\eta), w, \\nu)$ means that $z_\\text p \\sim \\text{mvnormal}(y_\\text p + s\\,u(y_\\text{hd}), \\nu_\\text p^2 I)$ and $z_\\text{hd} \\sim \\text{normal}(y_\\text{hd} + \\eta, \\nu_\\text{hd})$ independently.\n",
+ "\n",
+ "The return values $\\text{retval}(z)$ of these models are obtained from traces $z$ by reducing $z_\\text{hd}$ modulo $2\\pi$, and in the second case applying collision physics (relative to $w$) to the path from $y_\\text p$ to $z_\\text p$. (We invite the reader to imagine if ProbComp required us to compute the marginal density of the return value here!) We have the following closed form for the density functions:\n",
+ "$$\\begin{align*}\n",
+ "P_\\text{start}(z; y, \\nu)\n",
+ "&= P_\\text{mvnormal}(z_\\text p; y_\\text p, \\nu_\\text p^2 I)\n",
+ "\\cdot P_\\text{normal}(z_\\text{hd}; y_\\text{hd}, \\nu_\\text{hd}), \\\\\n",
+ "P_\\text{step}(z; y, (s, \\eta), w, \\nu)\n",
+ "&= P_\\text{mvnormal}(z_\\text p; y_\\text p + s\\,u(y_\\text{hd}), \\nu_\\text p^2 I)\n",
+ "\\cdot P_\\text{normal}(z_\\text{hd}; y_\\text{hd} + \\eta, \\nu_\\text{hd}).\n",
+ "\\end{align*}$$\n",
+ "\n",
+ "In general, the density of any trace factors as the product of the densities of the individual primitive choices that appear in it. Since the primitive distributions of the language are equipped with efficient probability density functions, this overall computation is tractable.
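As a concrete check of this factorization, the sketch below rebuilds the log-density of the `step_proposal` trace from its two primitive choices by hand (a sketch only; it assumes the TensorFlow Probability JAX substrate that GenJAX's `normal` and `mv_normal_diag` wrap):

```python
import tensorflow_probability.substrates.jax as tfp

settings, start, control = trace.get_args()
choices = trace.get_choices()
p_mean = start.p + control.ds * start.dp()
hd_mean = start.hd + control.dhd
# Log-density of the position choice under mvnormal(p_mean, p_noise^2 I) ...
log_density_p = tfp.distributions.MultivariateNormalDiag(
    loc=p_mean, scale_diag=settings["p_noise"] * jnp.ones(2)
).log_prob(choices["p"])
# ... plus the log-density of the heading choice under normal(hd_mean, hd_noise).
log_density_hd = tfp.distributions.Normal(hd_mean, settings["hd_noise"]).log_prob(
    choices["hd"]
)
log_density_p + log_density_hd
```

The sum of these two terms is the trace's overall log-density.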
It is represented by `Gen.get_score`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "trace.get_score()" + ] + }, + { + "cell_type": "markdown", + "id": "39", + "metadata": {}, + "source": [ + "#### Subscores/subweights/subdensities\n", + "\n", + "Instead of (the log of) the product of all the primitive choices made in a trace, one can take the product over just a subset using `Gen.project`. See below the fold for examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": { + "lines_to_next_cell": 0, + "title": "[hide-input]" + }, + "outputs": [], + "source": [ + "# jax.random.split(jax.random.PRNGKey(3333), N_samples).shape\n", + "\n", + "ps0 = jax.tree.map(lambda v: v[0], pose_samples)\n", + "(\n", + " ps0.project(jax.random.PRNGKey(2), S[()]),\n", + " ps0.project(jax.random.PRNGKey(2), S[\"p\"]),\n", + " ps0.project(jax.random.PRNGKey(2), S[\"p\"] | S[\"hd\"]),\n", + ")\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "trace.project(key, S[()])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "trace.project(key, S[(\"p\")])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "trace.project(key, S[\"p\"] | S[\"hd\"])" + ] + }, + { + "cell_type": "markdown", + "id": "43", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "If in fact all of those projections resulted in the same number, you have encountered the issue [GEN-316](https://linear.app/chi-fro/issue/GEN-316/project-is-broken-in-genjax).\n", + "\n", + "### Modeling a full path\n", + "\n", + "The model contains all information in its trace, rendering its return value redundant. The noisy path integration will just be a wrapper around its functionality, extracting what it needs from the trace.\n", + "\n", + "(It is worth acknowledging two strange things in the code below: the use of the suffix `.accumulate()` in path_model and the use of that auxiliary function itself." 
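The code in the next cell builds `path_model` by composing the step model with `map` and `scan`. As a rough analogy for that pattern in plain JAX (a sketch only; GenJAX's `scan` combinator is its own construct, not a call to `jax.lax.scan`), the kernel returns its new state twice, so that it serves both as the carry threaded into the next step and as the per-step output that scanning records:

```python
def demo_step(carry, control):
    new_carry = carry + control   # stand-in for applying one control to a pose
    return new_carry, new_carry   # (carry, output), mirroring `lambda r: (r, r)`

final_state, per_step_states = jax.lax.scan(demo_step, 0.0, jnp.arange(5.0))
```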
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [], + "source": [ + "path_model = (\n", + " step_proposal.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_path_trace(key: PRNGKey) -> genjax.Trace:\n", + " return path_model.simulate(key, (robot_inputs[\"start\"], robot_inputs[\"controls\"]))\n", + "\n", + "\n", + "def path_from_trace(tr: genjax.Trace) -> Pose:\n", + " return tr.get_retval()[1]\n", + "\n", + "\n", + "def generate_path(key: PRNGKey) -> Pose:\n", + " return path_from_trace(generate_path_trace(key))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "generate_path_trace(sub_key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "N_samples = 12\n", + "key, sub_key = jax.random.split(key)\n", + "sample_paths_v = jax.vmap(generate_path)(jax.random.split(sub_key, N_samples))\n", + "\n", + "Plot.Grid(*[walls_plot + poses_to_plots(path) for path in sample_paths_v])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48", + "metadata": {}, + "outputs": [], + "source": [ + "# Animation showing a single path with confidence circles\n", + "\n", + "\n", + "def plot_path_with_confidence(path: Pose, step: int, p_noise: float):\n", + " plot = (\n", + " world_plot\n", + " + [pose_plot(path[i]) for i in range(step + 1)]\n", + " + Plot.color_map({\"next pose\": \"red\"})\n", + " )\n", + " if step < len(path) - 1:\n", + " plot += [\n", + " confidence_circle(\n", + " path[step].apply_control(robot_inputs[\"controls\"][step]),\n", + " p_noise,\n", + " ),\n", + " pose_plot(path[step + 1], fill=Plot.constantly(\"next pose\")),\n", + " ]\n", + " return plot\n", + "\n", + "\n", + "def animate_path_with_confidence(path: Pose, motion_settings: dict):\n", + " frames = [\n", + " plot_path_with_confidence(path, step, motion_settings[\"p_noise\"])\n", + " for step in range(len(path.p))\n", + " ]\n", + "\n", + " return Plot.Frames(frames, fps=2)\n", + "\n", + "\n", + "# Generate a single path\n", + "key, sample_key = jax.random.split(key)\n", + "path = generate_path(sample_key)\n", + "Plot.Frames(\n", + " [\n", + " plot_path_with_confidence(path, step, default_motion_settings[\"p_noise\"])\n", + " + Plot.title(\"Motion model (samples)\")\n", + " for step in range(len(path))\n", + " ],\n", + " fps=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "49", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "### Modifying traces\n", + "\n", + "The metaprogramming approach of Gen affords the opportunity to explore alternate stochastic execution histories. Namely, `trace.update` takes as inputs a source of randomness, together with modifications to its arguments and primitive choice values, and returns an accordingly modified trace. It also returns (the log of) the ratio of the updated trace's density to the original trace's density, together with a precise record of the resulting modifications that played out.\n", + "\n", + "One could, for instance, consider just the placement of the first step, and replace its stochastic choice of heading with an updated value. 
The original trace was typical under the pose prior model, whereas the modified one may be rather less likely. This plot is annotated with log of how much unlikelier, the score ratio:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50", + "metadata": {}, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "trace = step_proposal.simulate(\n", + " sub_key,\n", + " (default_motion_settings, robot_inputs[\"start\"], robot_inputs[\"controls\"][0]),\n", + ")\n", + "key, sub_key = jax.random.split(key)\n", + "rotated_trace, rotated_trace_weight_diff, _, _ = trace.update(\n", + " sub_key, C[\"hd\"].set(jnp.pi / 2.0)\n", + ")\n", + "\n", + "# TODO(huebert): try using a slider to choose the heading we set (initial value is 0.0)\n", + "\n", + "(\n", + " Plot.new(\n", + " world_plot\n", + " + pose_plot(trace.get_retval(), fill=Plot.constantly(\"some pose\"))\n", + " + pose_plot(\n", + " rotated_trace.get_retval(), fill=Plot.constantly(\"with heading modified\")\n", + " )\n", + " + Plot.color_map({\"some pose\": \"green\", \"with heading modified\": \"red\"})\n", + " + Plot.title(\"Modifying a heading\")\n", + " )\n", + " | html(\"span.tc\", f\"score ratio: {rotated_trace_weight_diff}\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "51", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "It is worth carefully thinking through a trickier instance of this. Suppose instead, within the full path, we replaced the first step's stochastic choice of heading with some specific value." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52", + "metadata": {}, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "trace = generate_path_trace(sub_key)\n", + "key, sub_key = jax.random.split(key)\n", + "\n", + "rotated_first_step, rotated_first_step_weight_diff, _, _ = trace.update(\n", + " sub_key, C[0, \"steps\", \"pose\", \"hd\"].set(jnp.pi / 2.0)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " world_plot\n", + " + [\n", + " pose_plot(pose, fill=Plot.constantly(\"with heading modified\"))\n", + " for pose in path_from_trace(rotated_first_step)\n", + " ]\n", + " + [\n", + " pose_plot(pose, fill=Plot.constantly(\"some path\"))\n", + " for pose in path_from_trace(trace)\n", + " ]\n", + " + Plot.color_map({\"some path\": \"green\", \"with heading modified\": \"red\"})\n", + ") | html(\"span.tc\", f\"score ratio: {rotated_first_step_weight_diff}\")" + ] + }, + { + "cell_type": "markdown", + "id": "54", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "### Ideal sensors\n", + "\n", + "We now, additionally, assume the robot is equipped with sensors that cast\n", + "rays upon the environment at certain angles relative to the given pose,\n", + "and return the distance to a hit.\n", + "\n", + "We first describe the ideal case, where the sensors return the true\n", + "distances to the walls." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "sensor_settings = {\n", + " \"fov\": 2 * jnp.pi * (2 / 3),\n", + " \"num_angles\": 41,\n", + " \"box_size\": world[\"box_size\"],\n", + "}\n", + "\n", + "\n", + "def sensor_distance(pose, walls, box_size):\n", + " distances = jax.vmap(distance, in_axes=(None, 0))(pose, walls)\n", + " d = jnp.min(distances)\n", + " # Capping to a finite value avoids issues below.\n", + " return jnp.where(jnp.isinf(d), 2.0 * box_size, d)\n", + "\n", + "\n", + "# This represents a \"fan\" of sensor angles, with given field of vision, centered at angle 0.\n", + "\n", + "\n", + "def make_sensor_angles(sensor_settings):\n", + " na = sensor_settings[\"num_angles\"]\n", + " return sensor_settings[\"fov\"] * (jnp.arange(na) - jnp.floor(na / 2)) / (na - 1)\n", + "\n", + "\n", + "sensor_angles = make_sensor_angles(sensor_settings)\n", + "\n", + "\n", + "def ideal_sensor(pose: Pose):\n", + " walls = world[\"walls\"]\n", + " box_size = sensor_settings[\"box_size\"]\n", + "\n", + " def reading(angle):\n", + " return sensor_distance(pose.rotate(angle), walls, box_size)\n", + "\n", + " return jax.vmap(reading)(sensor_angles)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot sensor data.\n", + "\n", + "\n", + "def plot_sensors(pose: Pose, readings):\n", + " projections = [\n", + " pose.rotate(angle).step_along(s) for angle, s in zip(sensor_angles, readings)\n", + " ]\n", + "\n", + " return [\n", + " Plot.line(\n", + " [(x, y, i) for i, p in enumerate(projections) for x, y in [pose.p, p.p]],\n", + " stroke=Plot.constantly(\"sensor rays\"),\n", + " ),\n", + " [\n", + " Plot.dot(\n", + " [pose.p for pose in projections],\n", + " r=2.75,\n", + " fill=Plot.constantly(\"sensor readings\"),\n", + " )\n", + " ],\n", + " Plot.color_map({\"sensor rays\": \"rgba(0,0,0,0.1)\", \"sensor readings\": \"#f80\"}),\n", + " ]\n", + "\n", + "\n", + "world_plot + plot_sensors(robot_inputs[\"start\"], ideal_sensor(robot_inputs[\"start\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57", + "metadata": {}, + "outputs": [], + "source": [ + "def animate_path_with_sensor(path, readings):\n", + " frames = [\n", + " (\n", + " world_plot\n", + " + [pose_plot(pose) for pose in path[:step]]\n", + " + plot_sensors(pose, readings[step])\n", + " + [pose_plot(pose, fill=\"red\")]\n", + " )\n", + " for step, pose in enumerate(path)\n", + " ]\n", + " return Plot.Frames(frames, fps=2)\n", + "\n", + "\n", + "key, sample_key = jax.random.split(key)\n", + "path = generate_path(sample_key)\n", + "readings = jax.vmap(ideal_sensor)(path)\n", + "animate_path_with_sensor(generate_path(sample_key), readings)" + ] + }, + { + "cell_type": "markdown", + "id": "58", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "### Noisy sensors\n", + "\n", + "We assume that the sensor readings are themselves uncertain, say, the distances only knowable\n", + "up to some noise. We model this as follows." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59", + "metadata": {}, + "outputs": [], + "source": [ + "@genjax.gen\n", + "def sensor_model_one(pose, angle):\n", + " sensor_pose = pose.rotate(angle)\n", + " return (\n", + " genjax.normal(\n", + " sensor_distance(sensor_pose, world[\"walls\"], sensor_settings[\"box_size\"]),\n", + " sensor_settings[\"s_noise\"],\n", + " )\n", + " @ \"distance\"\n", + " )\n", + "\n", + "\n", + "sensor_model = sensor_model_one.vmap(in_axes=(None, 0))\n", + "\n", + "\n", + "def noisy_sensor(pose):\n", + " trace = sensor_model.simulate(key, (pose, sensor_angles))\n", + " return trace" + ] + }, + { + "cell_type": "markdown", + "id": "60", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "The trace contains many choices corresponding to directions of sensor reading from the input pose. To explore the trace values, open the \"folders\" within the trace by clicking on the small triangles until you see the 41-element array of sensor values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61", + "metadata": {}, + "outputs": [], + "source": [ + "sensor_settings[\"s_noise\"] = 0.10\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "trace = sensor_model.simulate(sub_key, (robot_inputs[\"start\"], sensor_angles))" + ] + }, + { + "cell_type": "markdown", + "id": "62", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "The mathematical picture is as follows. Given the parameters of a pose $y$, walls $w$,\n", + "and settings $\\nu$, one gets a distribution $\\text{sensor}(y, w, \\nu)$ over the traces\n", + "of `sensor_model`, and when $z$ is a motion model trace we set\n", + "$\\text{sensor}(z, w, \\nu) := \\text{sensor}(\\text{retval}(z), w, \\nu)$.\n", + "Its samples are identified with vectors $o = (o^{(1)}, o^{(2)}, \\ldots, o^{(J)})$, where\n", + "$J := \\nu_\\text{num\\_angles}$, each $o^{(j)}$ independently following a certain normal\n", + "distribution (depending, notably, on the distance from the pose to the nearest wall).\n", + "Thus the density of $o$ factors into a product of the form\n", + "$$\n", + "P_\\text{sensor}(o) = \\prod\\nolimits_{j=1}^J P_\\text{normal}(o^{(j)})\n", + "$$\n", + "where we begin a habit of omitting the parameters to distributions that are implied by the code.\n", + "\n", + "Visualizing the traces of the model is probably more useful for orientation, so we do this now." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63", + "metadata": {}, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "path = generate_path(sub_key)\n", + "readings = jax.vmap(noisy_sensor)(path).get_retval()\n", + "animate_path_with_sensor(path, readings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# TODO: annotate plot with title \"Sensor model (samples)\"" + ] + }, + { + "cell_type": "markdown", + "id": "65", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "### Full model\n", + "\n", + "We fold the sensor model into the motion model to form a \"full model\", whose traces describe simulations of the entire robot situation as we have described it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "@genjax.gen\n", + "def full_model_kernel(motion_settings, state, control):\n", + " pose = step_proposal(motion_settings, state, control) @ \"pose\"\n", + " sensor_model(pose, sensor_angles) @ \"sensor\"\n", + " return pose, pose\n", + "\n", + "\n", + "@genjax.gen\n", + "def full_model(motion_settings):\n", + " return (\n", + " full_model_kernel.partial_apply(motion_settings).scan()(\n", + " robot_inputs[\"start\"], robot_inputs[\"controls\"]\n", + " )\n", + " @ \"steps\"\n", + " )\n", + "\n", + "\n", + "def get_path(trace):\n", + " ps = trace.get_retval()[1]\n", + " return ps\n", + "\n", + "\n", + "def get_sensors(trace):\n", + " ch = trace.get_choices()\n", + " return ch[\"steps\", :, \"sensor\", :, \"distance\"]\n", + "\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "tr = full_model.simulate(sub_key, (default_motion_settings,))\n", + "\n", + "pz.ts.display(tr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "68", + "metadata": {}, + "source": [ + "Again, the trace of the full model contains many choices, so we have used the Penzai visualization library to render the result. Click on the various nesting arrows and see if you can find the path within. For our purposes, we will supply a function `get_path` which will extract the list of Poses that form the path." + ] + }, + { + "cell_type": "markdown", + "id": "69", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "In the math picture, `full_model` corresponds to a distribution $\text{full}$ over its traces. Such a trace is identified with a pair $(z_{0:T}, o_{0:T})$ where $z_{0:T} \sim \text{path}(\ldots)$ and $o_t \sim \text{sensor}(z_t, \ldots)$ for $t=0,\ldots,T$. The density of this trace is then\n", + "$$\begin{align*}\n", + "P_\text{full}(z_{0:T}, o_{0:T})\n", + "&= P_\text{path}(z_{0:T}) \cdot \prod\nolimits_{t=0}^T P_\text{sensor}(o_t) \\\n", + "&= \big(P_\text{start}(z_0)\ P_\text{sensor}(o_0)\big)\n", + " \cdot \prod\nolimits_{t=1}^T \big(P_\text{step}(z_t)\ P_\text{sensor}(o_t)\big).\n", + "\end{align*}$$\n", + "\n", + "By this point, visualization is essential."
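+ , + "\n", + "\n", + "As a quick sanity check on this factorization before visualizing (a sketch, not part of the main flow: it assumes the helpers `ideal_sensor`, `get_path`, and `get_sensors` and the trace `tr` defined above, and uses `jax.scipy.stats.norm` for the Normal log-density), the sensor factor $P_\text{sensor}(o_t)$ at a single time step can be recomputed by hand:\n", + "\n", + "```python\n", + "from jax.scipy.stats import norm\n", + "\n", + "t = 0\n", + "pose_t = get_path(tr)[t] # pose at time t\n", + "obs_t = get_sensors(tr)[t] # observed sensor distances at time t\n", + "ideal_t = ideal_sensor(pose_t) # noise-free distances at that pose\n", + "log_P_sensor_t = jnp.sum(norm.logpdf(obs_t, ideal_t, sensor_settings[\"s_noise\"]))\n", + "log_P_sensor_t\n", + "```"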
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "tr = full_model.simulate(sub_key, (default_motion_settings,))\n", + "\n", + "\n", + "def animate_path_and_sensors(path, readings, motion_settings, frame_key=None):\n", + " frames = [\n", + " plot_path_with_confidence(path, step, motion_settings[\"p_noise\"])\n", + " + plot_sensors(pose, readings[step])\n", + " for step, pose in enumerate(path)\n", + " ]\n", + "\n", + " return Plot.Frames(frames, fps=2, key=frame_key)\n", + "\n", + "\n", + "def animate_full_trace(trace, frame_key=None):\n", + " path = get_path(trace)\n", + " readings = get_sensors(trace)\n", + " motion_settings = trace.get_args()[0]\n", + " return animate_path_and_sensors(\n", + " path, readings, motion_settings, frame_key=frame_key\n", + " )\n", + "\n", + "\n", + "animate_full_trace(tr)" + ] + }, + { + "cell_type": "markdown", + "id": "71", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "## The data\n", + "\n", + "Let us generate some fixed synthetic motion data that, for pedagogical purposes, we will work with as if it were the actual path of the robot. We will generate two versions, one each with low or high motion deviation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "motion_settings_low_deviation = {\n", + " \"p_noise\": 0.05,\n", + " \"hd_noise\": (1 / 10.0) * 2 * jnp.pi / 360,\n", + "}\n", + "key, k_low, k_high = jax.random.split(key, 3)\n", + "\n", + "trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,))\n", + "motion_settings_high_deviation = {\"p_noise\": 0.25, \"hd_noise\": 2 * jnp.pi / 360}\n", + "trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,))\n", + "\n", + "animate_full_trace(trace_low_deviation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# TODO: next task is to create a side-by-side animation of the low and high deviation paths.\n", + "\n", + "animate_full_trace(trace_high_deviation)" + ] + }, + { + "cell_type": "markdown", + "id": "74", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Since we imagine these data as having been recorded from the real world, keep only their extracted data, *discarding* the traces that produced them." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75", + "metadata": {}, + "outputs": [], + "source": [ + "# These are what we hope to recover...\n", + "path_low_deviation = get_path(trace_low_deviation)\n", + "path_high_deviation = get_path(trace_high_deviation)\n", + "\n", + "# ...using these data.\n", + "observations_low_deviation = get_sensors(trace_low_deviation)\n", + "observations_high_deviation = get_sensors(trace_high_deviation)\n", + "\n", + "# Encode sensor readings into choice map.\n", + "\n", + "\n", + "def constraint_from_sensors(readings):\n", + " angle_indices = jnp.arange(len(sensor_angles))\n", + " return jax.vmap(\n", + " lambda ix, v: C[\"steps\", ix, \"sensor\", angle_indices, \"distance\"].set(v)\n", + " )(jnp.arange(T), readings) + C[\"initial\", \"sensor\", angle_indices, \"distance\"].set(\n", + " readings[0]\n", + " )\n", + "\n", + "\n", + "constraints_low_deviation = constraint_from_sensors(observations_low_deviation)\n", + "constraints_high_deviation = constraint_from_sensors(observations_high_deviation)" + ] + }, + { + "cell_type": "markdown", + "id": "76", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "We summarize the information available to the robot to determine its location. On the one hand, one has to produce a guess of the start pose plus some controls, which one might integrate to produce an idealized guess of path. On the other hand, one has the sensor data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "def animate_bare_sensors(path, plot_base=[]):\n", + " def frame(pose, readings1, readings2):\n", + " def plt(readings):\n", + " return Plot.new(\n", + " plot_base or Plot.domain([0, 20]),\n", + " plot_sensors(pose, readings),\n", + " {\"width\": 400, \"height\": 400},\n", + " )\n", + "\n", + " return plt(readings1) & plt(readings2)\n", + "\n", + " frames = [\n", + " frame(*scene)\n", + " for scene in zip(path, observations_low_deviation, observations_high_deviation)\n", + " ]\n", + " return Plot.Frames(frames, fps=2)\n", + "\n", + "\n", + "animate_bare_sensors(itertools.repeat(world[\"center_point\"]))" + ] + }, + { + "cell_type": "markdown", + "id": "78", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "## Inference\n", + "### Why we need inference: in a picture\n", + "\n", + "The path obtained by integrating the controls serves as a proposal for the true path, but it is unsatisfactory, especially in the high motion deviation case. The picture gives an intuitive sense of the fit:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "animate_bare_sensors(path_integrated, world_plot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "world_plot + plot_sensors(robot_inputs[\"start\"], observations_low_deviation[0])" + ] + }, + { + "cell_type": "markdown", + "id": "81", + "metadata": {}, + "source": [ + "It would seem that the fit is reasonable in low motion deviation, but really breaks down in high motion deviation.\n", + "\n", + "We are not limited to visual judgments here: the model can quantitatively assess how good a fit the integrated path is for the data. 
In order to do this, we detour to explain how to produce samples from our model that agree with the fixed observation data." + ] + }, + { + "cell_type": "markdown", + "id": "82", + "metadata": {}, + "source": [ + "### Producing samples with constraints\n", + "\n", + "We have seen how `simulate` performs traced execution of a generative function: as the program runs, it draws stochastic choices from all required primitive distributions, and records them in a choice map.\n", + "\n", + "Given a choice map of *constraints* that declare fixed values of some of the primitive choices, the operation `importance` proposes traces of the generative function that are consistent with these constraints." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "model_importance = jax.jit(full_model.importance)\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "sample, log_weight = model_importance(\n", + " sub_key, constraints_low_deviation, (motion_settings_low_deviation,)\n", + ")\n", + "animate_full_trace(sample) | html(\"span.tc\", f\"log_weight: {log_weight}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "sample, log_weight = model_importance(\n", + " sub_key, constraints_high_deviation, (motion_settings_high_deviation,)\n", + ")\n", + "animate_full_trace(sample) | html(\"span.tc\", f\"log_weight: {log_weight}\")" + ] + }, + { + "cell_type": "markdown", + "id": "85", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "A trace resulting from a call to `importance` is structurally indistinguishable from one drawn from `simulate`. But there is a key situational difference: while `get_score` always returns the frequency with which `simulate` stochastically produces the trace, this value is **no longer equal to** the frequency with which the trace is stochastically produced by `importance`. This is both true in an obvious and less relevant sense, as well as true in a more subtle and extremely germane sense.\n", + "\n", + "On the superficial level, since all traces produced by `importance` are consistent with the constraints, those traces that are inconsistent with the constraints do not occur at all, and in aggregate the traces that are consistent with the constraints are more common.\n", + "\n", + "More deeply and importantly, the stochastic choice of the *constraints* under a run of `simulate` might have any density, perhaps very low. This constraints density contributes as always to the `get_score`, whereas it does not influence the frequency of producing this trace under `importance`.\n", + "\n", + "The ratio of the `get_score` of a trace to the probability density that `importance` would produce it with the given constraints, is called the *importance weight*. For convenience, (the log of) this quantity is returned by `importance` along with the trace.\n", + "\n", + "We stress the basic invariant:\n", + "$$\n", + "\\text{get\\_score}(\\text{trace})\n", + "=\n", + "(\\text{weight from importance})\n", + "\\cdot\n", + "(\\text{frequency simulate creates this trace}).\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "86", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "The preceding comments apply to generative functions in wide generality. 
We can say even more about our present examples, because further assumptions hold.\n", + "1. There is no untraced randomness. Given a full choice map for constraints, everything else is deterministic. In particular, the importance weight is the `get_score`.\n", + "2. The generative function was constructed using GenJAX's DSL and primitive distributions. Ancestral sampling; `importance` with empty constraints reduces to `simulate` with importance weight $1$.\n", + "3. Combined, the importance weight is directly computed as the `project` of the trace upon the choice map addresses that were constrained in the call to `importance`.\n", + "\n", + " In our running example, the projection in question is $\\prod_{t=0}^T P_\\text{sensor}(o_t)$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: this calculation doesn't work in GenJAX currently\n", + "# log_weight - project(trace, select([prefix_address(i, :sensor) for i in 1:(T+1)]...))" + ] + }, + { + "cell_type": "markdown", + "id": "88", + "metadata": {}, + "source": [ + "### Why we need inference: in numbers\n", + "\n", + "We return to how the model offers a numerical benchmark for how good a fit the integrated path is.\n", + "\n", + "In words, the data are incongruously unlikely for the integrated path. The (log) density of the measurement data, given the integrated path..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "def constraint_from_path(path):\n", + " c_ps = jax.vmap(lambda ix, p: C[\"steps\", ix, \"pose\", \"p\"].set(p))(\n", + " jnp.arange(T), path.p\n", + " )\n", + "\n", + " c_hds = jax.vmap(lambda ix, hd: C[\"steps\", ix, \"pose\", \"hd\"].set(hd))(\n", + " jnp.arange(T), path.hd\n", + " )\n", + " return c_ps + c_hds # + c_p + c_hd\n", + "\n", + "\n", + "constraints_path_integrated = constraint_from_path(path_integrated)\n", + "constraints_path_integrated_observations_low_deviation = (\n", + " constraints_path_integrated ^ constraints_low_deviation\n", + ")\n", + "constraints_path_integrated_observations_high_deviation = (\n", + " constraints_path_integrated ^ constraints_high_deviation\n", + ")\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "trace_path_integrated_observations_low_deviation, w_low = model_importance(\n", + " sub_key,\n", + " constraints_path_integrated_observations_low_deviation,\n", + " (motion_settings_low_deviation,),\n", + ")\n", + "key, sub_key = jax.random.split(key)\n", + "trace_path_integrated_observations_high_deviation, w_high = model_importance(\n", + " sub_key,\n", + " constraints_path_integrated_observations_high_deviation,\n", + " (motion_settings_high_deviation,),\n", + ")\n", + "\n", + "w_low, w_high\n", + "# TODO: Jay then does two projections to compare the log-weights of these two things,\n", + "# in order to show that we can be quantitative about the quality of the paths generated\n", + "# by the two models. 
Unfortunately we can't, and so we should raise the priority of the\n", + "# blocking bug" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90", + "metadata": {}, + "outputs": [], + "source": [ + "Plot.Row(\n", + " *[\n", + " (\n", + " html(\"div.f3.b.tc\", title)\n", + " | animate_full_trace(trace, frame_key=\"frame\")\n", + " | html(\"span.tc\", f\"score: {score:,.2f}\")\n", + " )\n", + " for (title, trace, motion_settings, score) in [\n", + " [\n", + " \"Low deviation\",\n", + " trace_path_integrated_observations_low_deviation,\n", + " motion_settings_low_deviation,\n", + " w_low,\n", + " ],\n", + " [\n", + " \"High deviation\",\n", + " trace_path_integrated_observations_high_deviation,\n", + " motion_settings_high_deviation,\n", + " w_high,\n", + " ],\n", + " ]\n", + " ]\n", + ") | Plot.Slider(\"frame\", 0, T, fps=2)" + ] + }, + { + "cell_type": "markdown", + "id": "91", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92", + "metadata": {}, + "outputs": [], + "source": [ + "N_samples = 200\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "\n", + "traces_generated_low_deviation, low_weights = jax.vmap(\n", + " model_importance, in_axes=(0, None, None)\n", + ")(\n", + " jax.random.split(sub_key, N_samples),\n", + " constraints_low_deviation,\n", + " (motion_settings_low_deviation,),\n", + ")\n", + "\n", + "traces_generated_high_deviation, high_weights = jax.vmap(\n", + " model_importance, in_axes=(0, None, None)\n", + ")(\n", + " jax.random.split(sub_key, N_samples),\n", + " constraints_high_deviation,\n", + " (motion_settings_high_deviation,),\n", + ")\n", + "\n", + "# low_weights, high_weights\n", + "# two histograms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93", + "metadata": {}, + "outputs": [], + "source": [ + "low_deviation_paths = jax.vmap(get_path)(traces_generated_low_deviation)\n", + "high_deviation_paths = jax.vmap(get_path)(traces_generated_high_deviation)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "Plot.new(\n", + " world_plot,\n", + " [\n", + " poses_to_plots(pose, fill=\"blue\", opacity=0.1)\n", + " for pose in high_deviation_paths[:20]\n", + " ],\n", + " [\n", + " poses_to_plots(pose, fill=\"green\", opacity=0.1)\n", + " for pose in low_deviation_paths[:20]\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "95", + "metadata": {}, + "source": [ + "## Generic strategies for inference\n", + "\n", + "We now spell out some generic strategies for conditioning the outputs of a model towards some observed data. The word \"generic\" indicates that they make no special intelligent use of the model structure, and their convergence is guaranteed by theorems of a similar nature. In terms to be defined shortly, they simply take a pair $(Q,f)$ of a proposal and a weight function that implement importance sampling with target $P$.\n", + "\n", + "There is no free lunch in this game: generic inference recipes are inefficient, for example, converging very slowly or needing vast counts of particles, especially in high-dimensional settings.
One of the root problems is that proposals $Q$ may provide arbitrarily bad samples relative to our target $P$; if $Q$ still supports all samples of $P$ with microscopic but nonzero density, then the generic algorithm will converge in the limit, however astronomically slowly.\n", + "\n", + "Rather, efficiency will become possible when we do the *opposite* of generic: exploit what we actually know about the problem in our design of the inference strategy to propose better traces towards our target. Gen's aim is to provide the right entry points to enact this exploitation." + ] + }, + { + "cell_type": "markdown", + "id": "96", + "metadata": {}, + "source": [ + "### The posterior distribution and importance sampling\n", + "\n", + "Mathematically, the passage from the prior to the posterior is the operation of conditioning distributions.\n", + "\n", + "Intuitively, the conditional distribution $\\text{full}(\\cdot | o_{0:T})$ is just the restriction of the joint distribution $\\text{full}(z_{0:T}, o_{0:T})$ to where the parameter $o_{0:T}$ is constant, letting $z_{0:T}$ continue to vary. This restriction no longer has total density equal to $1$, so we must renormalize it. The normalizing constant must be\n", + "$$\n", + "P_\\text{marginal}(o_{0:T})\n", + ":= \\int P_\\text{full}(Z_{0:T}, o_{0:T}) \\, dZ_{0:T}\n", + " = \\mathbf{E}_{Z_{0:T} \\sim \\text{path}}\\big[P_\\text{full}(Z_{0:T}, o_{0:T})\\big].\n", + "$$\n", + "By Fubini's Theorem, this function of $o_{0:T}$ is the density of a probability distribution over observations $o_{0:T}$, called the *marginal distribution*; but we will often have $o_{0:T}$ fixed, and consider it a constant. Then, finally, the *conditional distribution* $\\text{full}(\\cdot | o_{0:T})$ is defined to have the normalized density\n", + "$$\n", + "P_\\text{full}(z_{0:T} | o_{0:T}) := \\frac{P_\\text{full}(z_{0:T}, o_{0:T})}{P_\\text{marginal}(o_{0:T})}.\n", + "$$\n", + "\n", + "The goal of inference is to produce samples $\\text{trace}_{0:T}$ distributed (approximately) according to $\\text{full}(\\cdot | o_{0:T})$. The most immediately evident problem with doing inference is that the quantity $P_\\text{marginal}(o_{0:T})$ is intractable!" + ] + }, + { + "cell_type": "markdown", + "id": "97", + "metadata": {}, + "source": [ + "Define the function $\\hat f(z_{0:T})$ of sample values $z_{0:T}$ to be the ratio of probability densities between the posterior distribution $\\text{full}(\\cdot | o_{0:T})$ that we wish to sample from, and the prior distribution $\\text{path}$ that we are presently able to produce samples from. 
Manipulating it à la Bayes's Rule gives:\n", + "$$\n", + "\\hat f(z_{0:T})\n", + ":=\n", + "\\frac{P_\\text{full}(z_{0:T} | o_{0:T})}{P_\\text{path}(z_{0:T})}\n", + "=\n", + "\\frac{P_\\text{full}(z_{0:T}, o_{0:T})}{P_\\text{marginal}(o_{0:T}) \\cdot P_\\text{path}(z_{0:T})}\n", + "=\n", + "\\frac{\\prod_{t=0}^T P_\\text{sensor}(o_t)}{P_\\text{marginal}(o_{0:T})}.\n", + "$$\n", + "Noting that the intractable quantity\n", + "$$\n", + "Z := P_\\text{marginal}(o_{0:T})\n", + "$$\n", + "is constant in $z_{0:T}$, we define the explicitly computable quantity\n", + "$$\n", + "f(z_{0:T}) := Z \\cdot \\hat f(z_{0:T}) = \\prod\\nolimits_{t=0}^T P_\\text{sensor}(o_t).\n", + "$$\n", + "The right hand side has been written sloppily, but we remind the reader that $P_\\text{sensor}(o_t)$ is a product of densities of normal distributions that *does depend* on $z_t$ as well as \"sensor\" and \"world\" parameters.\n", + "\n", + "Compare to our previous description of calling `importance` on `full_model` with the observations $o_{0:T}$ as constraints: it produces a trace of the form $(z_{0:T}, o_{0:T})$ where $z_{0:T} \\sim \\text{path}$ has been drawn from $\\text{path}$, together with the weight equal to none other than this $f(z_{0:T})$." + ] + }, + { + "cell_type": "markdown", + "id": "98", + "metadata": {}, + "source": [ + "This reasoning involving `importance` is indicative of the general scenario with conditioning, and fits into the following shape.\n", + "\n", + "We have on hand two distributions, a *target* $P$ from which we would like to (approximately) generate samples, and a *proposal* $Q$ from which we are presently able to generate samples. We must assume that the proposal is a suitable substitute for the target, in the sense that every possible event under $P$ occurs under $Q$ (mathematically, $P$ is absolutely continuous with respect to $Q$).\n", + "\n", + "Under these hypotheses, there is a well-defined density ratio function $\\hat f$ between $P$ and $Q$ (mathematically, the Radon–Nikodym derivative). If $z$ is a sample drawn from $Q$, then $\\hat w = \\hat f(z)$ is how much more or less likely $z$ would have been drawn from $P$. We only require that we are able to compute the *unnormalized* density ratio, that is, some function of the form $f = Z \\cdot \\hat f$ where $Z > 0$ is constant.\n", + "\n", + "The pair $(Q,f)$ is said to implement *importance sampling* for $P$, and the values of $f$ are called *importance weights*. Generic inference attempts to use knowledge of $f$ to correct for the difference in behavior between $P$ and $Q$, and thereby use $Q$ to produce samples from (approximately) $P$.\n", + "\n", + "So in our running example, the target $P$ is the posterior distribution on paths $\\text{full}(\\cdot | o_{0:T})$, the proposal $Q$ is the path prior $\\text{path}$, and the importance weight $f$ is the product of the sensor model densities. We seek a computational model of the first; the second and third are computationally modeled by calling `importance` on `full_model` constrained by the observations $o_{0:T}$. (The computation of the second, on its own, simplifies to `path_prior`.)\n" + ] + }, + { + "cell_type": "markdown", + "id": "99", + "metadata": {}, + "source": [ + "\n", + "### TODO: TBD: rejection sampling. We proceed directly to SIR." 
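+ , + "\n", + "In the meantime, here is a tiny self-contained illustration of the $(Q, f)$ setup from the previous section (an aside with made-up one-dimensional distributions, independent of the robot model; it uses only `jax` as imported above): draw from a proposal $Q$, weight by an unnormalized density ratio $f$, and use the weights to estimate an expectation under the target $P$.\n", + "\n", + "```python\n", + "from jax.scipy.stats import norm\n", + "\n", + "k = jax.random.PRNGKey(0)\n", + "zs = jax.random.normal(k, (10_000,)) # samples from the proposal Q = Normal(0, 1)\n", + "# unnormalized weights f(z), proportional to the density ratio P/Q with P = Normal(1, 0.5)\n", + "f = jnp.exp(norm.logpdf(zs, 1.0, 0.5) - norm.logpdf(zs, 0.0, 1.0))\n", + "(f * zs).sum() / f.sum() # self-normalized estimate of E_P[z]; should be close to 1\n", + "```"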
+ ] + }, + { + "cell_type": "markdown", + "id": "100", + "metadata": {}, + "source": [ + "### Sampling / importance resampling\n", + "\n", + "We turn to inference strategies that require only our proposal $Q$ and unnormalized weight function $f$ for the target $P$, *without* forcing us to wrangle any intractable integrals or upper bounds.\n", + "\n", + "Suppose we are given a list of nonnegative numbers, not all zero: $w^1, w^2, \ldots, w^N$. To *normalize* the numbers means computing $\hat w^i := w^i / \sum_{j=1}^N w^j$. The normalized list $\hat w^1, \hat w^2, \ldots, \hat w^N$ determines a *categorical distribution* on the indices $1, \ldots, N$, wherein the index $i$ occurs with probability $\hat w^i$.\n", + "Note that for any constant $Z > 0$, the scaled list $Zw^1, Zw^2, \ldots, Zw^N$ leads to the same normalized $\hat w^i$ as well as the same categorical distribution.\n", + "\n", + "When some list of data $z^1, z^2, \ldots, z^N$ have been associated with these respective numbers $w^1, w^2, \ldots, w^N$, then to *importance **re**sample* $M$ values from these data according to these weights means to independently sample indices $a^1, a^2, \ldots, a^M \sim \text{categorical}([\hat w^1, \hat w^2, \ldots, \hat w^N])$ and return the new list of data $z^{a^1}, z^{a^2}, \ldots, z^{a^M}$. Compare to the function `resample` implemented in the code box below.\n", + "\n", + "The *sampling / importance resampling* (SIR) strategy for inference runs as follows. Let counts $N > 0$ and $M > 0$ be given.\n", + "1. Importance sample: Independently sample $N$ data $z^1, z^2, \ldots, z^N$ from the proposal $Q$, called *particles*. Compute also their *importance weights* $w^i := f(z^i)$ for $i = 1, \ldots, N$.\n", + "2. Importance resample: Independently sample $M$ indices $a^1, a^2, \ldots, a^M \sim \text{categorical}([\hat w^1, \hat w^2, \ldots, \hat w^N])$, where $\hat w^i = w^i / \sum_{j=1}^N w^j$, and return $z^{a^1}, z^{a^2}, \ldots, z^{a^M}$. These sampled particles all inherit the *average weight* $\sum_{j=1}^N w^j / N$.\n", + "\n", + "As $N \to \infty$ with $M$ fixed, the samples produced by this algorithm converge to $M$ independent samples drawn from the target $P$. This strategy is computationally an improvement over rejection sampling: instead of indefinitely constructing and rejecting samples, we can guarantee to use at least some of them after a fixed time, and we are using the best guesses among these."
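+ , + "\n", + "To make the resampling step concrete before reading `resample` below, here is a tiny standalone illustration with made-up log-weights (an aside; it only assumes `jax` and `jnp` as imported above, and uses the fact that `jax.random.categorical` accepts unnormalized log-weights):\n", + "\n", + "```python\n", + "log_ws = jnp.array([-1.0, -3.0, -0.5, -2.0]) # unnormalized log-weights w^i\n", + "ws_hat = jnp.exp(log_ws - jax.scipy.special.logsumexp(log_ws)) # normalized weights\n", + "a = jax.random.categorical(jax.random.PRNGKey(0), log_ws, shape=(3,)) # M = 3 resampled indices\n", + "ws_hat, a\n", + "```"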
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "101", + "metadata": {}, + "outputs": [], + "source": [ + "def resample(\n", + " key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int\n", + "):\n", + " key1, key2 = jax.random.split(key)\n", + " samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))(\n", + " jax.random.split(key1, N * K), constraints, (motion_settings,)\n", + " )\n", + " winners = jax.vmap(genjax.categorical.sampler)(\n", + " jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))\n", + " )\n", + " # indices returned are relative to the start of the K-segment from which they were drawn.\n", + " # globalize the indices by adding back the index of the start of each segment.\n", + " winners += jnp.arange(0, N * K, N)\n", + " selected = jax.tree.map(lambda x: x[winners], samples)\n", + " return selected\n", + "\n", + "\n", + "jit_resample = jax.jit(resample, static_argnums=(3, 4))\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "low_posterior = jit_resample(\n", + " sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20\n", + ")\n", + "key, sub_key = jax.random.split(key)\n", + "high_posterior = jit_resample(\n", + " sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "102", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "def animate_path_as_line(path, **options):\n", + " x_coords = path.p[:, 0]\n", + " y_coords = path.p[:, 1]\n", + " return Plot.line({\"x\": x_coords, \"y\": y_coords}, {\"curve\": \"linear\", **options})\n", + "\n", + "\n", + "#\n", + "(\n", + " world_plot\n", + " + [\n", + " animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke=\"green\")\n", + " for path in jax.vmap(get_path)(low_posterior)\n", + " ]\n", + " + [\n", + " animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke=\"blue\")\n", + " for path in jax.vmap(get_path)(high_posterior)\n", + " ]\n", + " + poses_to_plots(\n", + " path_low_deviation, fill=Plot.constantly(\"low deviation path\"), opacity=0.2\n", + " )\n", + " + poses_to_plots(\n", + " path_high_deviation, fill=Plot.constantly(\"high deviation path\"), opacity=0.2\n", + " )\n", + " + poses_to_plots(\n", + " path_integrated, fill=Plot.constantly(\"integrated path\"), opacity=0.2\n", + " )\n", + " + Plot.color_map(\n", + " {\n", + " \"low deviation path\": \"green\",\n", + " \"high deviation path\": \"blue\",\n", + " \"integrated path\": \"black\",\n", + " }\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "103", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "Let's pause a moment to examine this chart. If the robot had no sensors, it would have no alternative but to estimate its position by integrating the control inputs to produce the integrated path in gray. In the low deviation setting, Gen has helped the robot to see that about halfway through its journey, noise in the control-effector relationship has caused the robot to deviate to the south slightly, and *the sensor data combined with importance sampling is enough* to give accurate results in the low deviation setting.\n", + "But in the high deviation setting, the loose nature of the paths in the blue posterior indicate that the robot has not discovered its true position by using importance sampling with the noisy sensor data. 
In the high deviation setting, a more refined inference technique will be required.\n", + "\n", + "Let's approach the problem step by step instead of trying to infer the whole path.\n", + "To get started we'll work with the initial point, and then improve it. Once that's done,\n", + "we can chain together such improved moves to hopefully get a better inference of the\n", + "actual path.\n", + "\n", + "One thing we'll need is a path to improve. We can select one of the importance samples we generated\n", + "earlier." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "104", + "metadata": {}, + "outputs": [], + "source": [ + "def select_by_weight(key: PRNGKey, weights: FloatArray, things):\n", + " \"\"\"Makes a categorical selection from the vector object `things`\n", + " weighted by `weights`. The selected object is returned (with its\n", + " outermost axis removed) with its weight.\"\"\"\n", + " chosen = jax.random.categorical(key, weights)\n", + " return jax.tree.map(lambda v: v[chosen], things), weights[chosen]" + ] + }, + { + "cell_type": "markdown", + "id": "105", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Select an importance sample by weight in both the low and high deviation settings. It will be handy\n", + "to have one path to work with to test our improvements." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "106", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "key, k1, k2 = jax.random.split(key, 3)\n", + "low_deviation_path, _ = select_by_weight(k1, low_weights, low_deviation_paths)\n", + "high_deviation_path, _ = select_by_weight(k2, high_weights, high_deviation_paths)" + ] + }, + { + "cell_type": "markdown", + "id": "107", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Create a choicemap that will enforce the given sensor observation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "108", + "metadata": {}, + "outputs": [], + "source": [ + "def observation_to_choicemap(observation, pose=None):\n", + " sensor_cm = C[\"sensor\", :, \"distance\"].set(observation)\n", + " pose_cm = (\n", + " C[\"pose\", \"p\"].set(pose.p) + C[\"pose\", \"hd\"].set(pose.hd)\n", + " if pose is not None\n", + " else C.n()\n", + " )\n", + " return sensor_cm + pose_cm" + ] + }, + { + "cell_type": "markdown", + "id": "109", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Let's visualize a cloud of possible poses by coloring the elements in proportion to their\n", + "plausibility under the sensor readings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "110", + "metadata": {}, + "outputs": [], + "source": [ + "def step_sample(key: PRNGKey, N: int, gf, observation):\n", + " tr, ws = jax.vmap(gf.importance, in_axes=(0, None, None))(\n", + " jax.random.split(key, N), observation_to_choicemap(observation), ()\n", + " )\n", + " return tr.get_retval()[0], ws\n", + "\n", + "\n", + "def weighted_small_pose_plot(proposal, truth, weights, poses, zoom=1):\n", + " max_logw = jnp.max(weights)\n", + " lse_ws = max_logw + jnp.log(jnp.sum(jnp.exp(weights - max_logw)))\n", + " scaled_ws = jnp.exp(weights - lse_ws)\n", + " max_scaled_w: FloatArray = jnp.max(scaled_ws)\n", + " scaled_ws /= max_scaled_w\n", + " # the following hack \"boosts\" lower scores a bit, to give us more visibility into\n", + " # the density of the nearby cloud.
Aesthetically, I found too many points were\n", + " # invisible without some adjustment, since the score distribution is concentrated\n", + " # closely around 1.0\n", + " scaled_ws = scaled_ws**0.3\n", + " z = 0.03 * zoom\n", + " return Plot.new(\n", + " [pose_plot(p, fill=w, zoom=z) for p, w in zip(poses, scaled_ws)]\n", + " + pose_plot(proposal, fill=\"red\", zoom=z)\n", + " + pose_plot(truth, fill=\"green\", zoom=z)\n", + " ) + {\n", + " \"color\": {\"type\": \"linear\", \"scheme\": \"OrRd\"},\n", + " \"height\": 400,\n", + " \"width\": 400,\n", + " \"aspectRatio\": 1,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "111", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "step_poses, step_scores = step_sample(\n", + " sub_key,\n", + " 1000,\n", + " full_model_kernel(\n", + " motion_settings_low_deviation,\n", + " robot_inputs[\"start\"],\n", + " robot_inputs[\"controls\"][0],\n", + " ),\n", + " observations_low_deviation[0],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "112", + "metadata": {}, + "outputs": [], + "source": [ + "weighted_small_pose_plot(\n", + " path_low_deviation[0], robot_inputs[\"start\"], step_scores, step_poses\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "113", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Develop a function which will produce a grid of evenly spaced nearby poses given\n", + "an initial pose. $n$ is the number of steps to take in each cardinal direction\n", + "(up/down, left/right and changes in heading). For example, if you say $n = 2$, there\n", + "will be a $5\\times 5$ grid of positions with the original pose in the center, and 5 layers\n", + "of this type, each with different heading deltas (including zero), for a total of\n", + "$125 = 5^3$ alternate poses." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "114", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "def grid_of_nearby_poses(p, n, motion_settings):\n", + " indices = jnp.arange(-n, n + 1)\n", + " n_indices = len(indices)\n", + " point_deltas = indices * 2 * motion_settings[\"p_noise\"] / n\n", + " hd_deltas = indices * 2 * motion_settings[\"hd_noise\"] / n\n", + " xs = jnp.repeat(point_deltas, n_indices)\n", + " ys = jnp.tile(point_deltas, n_indices)\n", + " points = jnp.repeat(jnp.column_stack((xs, ys)), n_indices, axis=0)\n", + " headings = jnp.tile(hd_deltas, n_indices * n_indices)\n", + " return Pose(p.p + points, p.hd + headings)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "115", + "metadata": {}, + "outputs": [], + "source": [ + "def grid_sample(gf, pose_grid, observation):\n", + " scores, _retvals = jax.vmap(\n", + " lambda pose: gf.assess(observation_to_choicemap(observation, pose), ())\n", + " )(pose_grid)\n", + " return scores" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "116", + "metadata": {}, + "outputs": [], + "source": [ + "# Our grid of nearby poses is actually a cube when we take into consideration the\n", + "# heading deltas. In order to get a 2d density to visualize, we flatten the cube by\n", + "# taking the \"best\" of the headings by score at each point. 
(Note: for the inference\n", + "# that follows, we will work with the full cube).\n", + "def flatten_pose_cube(pose_grid, cube_step_size, scores):\n", + " n_indices = 2 * cube_step_size + 1\n", + " best_heading_indices = jnp.argmax(\n", + " scores.reshape(n_indices * n_indices, n_indices), axis=1\n", + " )\n", + " # those were block relative; linearize them by adding back block indices\n", + " bs = best_heading_indices + jnp.arange(0, n_indices**3, n_indices)\n", + " return Pose(pose_grid.p[bs], pose_grid.hd[bs]), scores[bs]" + ] + }, + { + "cell_type": "markdown", + "id": "117", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "Prepare a plot showing the density of nearby improvements available using the grid\n", + "search and importance sampling techniques." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "118", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Test our code for visualizing the Boltzmann and grid searches at the initial pose.\n", + "def first_step_chart(key):\n", + " cube_step_size = 6\n", + " pose_grid = grid_of_nearby_poses(\n", + " path_low_deviation[0], cube_step_size, motion_settings_low_deviation\n", + " )\n", + " gf = full_model_kernel(\n", + " motion_settings_low_deviation,\n", + " robot_inputs[\"start\"],\n", + " robot_inputs[\"controls\"][0],\n", + " )\n", + " score_grid = grid_sample(\n", + " gf,\n", + " pose_grid,\n", + " observations_low_deviation[0],\n", + " )\n", + " step_poses, step_scores = step_sample(\n", + " key,\n", + " 1000,\n", + " gf,\n", + " observations_low_deviation[0],\n", + " )\n", + " pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid)\n", + " return weighted_small_pose_plot(\n", + " path_low_deviation[0], robot_inputs[\"start\"], score_plane, pose_plane\n", + " ) & weighted_small_pose_plot(\n", + " path_low_deviation[0], robot_inputs[\"start\"], step_scores, step_poses\n", + " )\n", + "\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "first_step_chart(sub_key)" + ] + }, + { + "cell_type": "markdown", + "id": "119", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "Now let's try doing the whole path. We want to produce something that is ultimately\n", + "scan-compatible, so it should have the form state -> update -> new_state. The state\n", + "is obviously the pose; the update will include the sensor readings at the current\n", + "position and the control input for the next step.\n", + "\n", + "Step 1. retire assess_model and use full_model_kernel in both bz and grid improvers.\n", + "Step 2. add the [pose,weight] of `pose` to the vector sampled by select_by_weight in the bz case\n", + "Step 3. 
How is the weight computed for `pose` ?\n", + " what we have now + correction term\n", + " pose.weight = full_model_kernel.assess(p, (cm,))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "120", + "metadata": {}, + "outputs": [], + "source": [ + "def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray):\n", + " cube_step_size = 8\n", + "\n", + " def grid_search_step(k: PRNGKey, gf, center_pose, observation):\n", + " pose_grid = grid_of_nearby_poses(center_pose, cube_step_size, motion_settings)\n", + " nearby_weights = grid_sample(gf, pose_grid, observation)\n", + " return nearby_weights, pose_grid\n", + "\n", + " def improved_step(state, update):\n", + " observation, control, key = update\n", + " gf = full_model_kernel(motion_settings, state, control)\n", + " # Run a sample and pick an element by weight.\n", + " k1, k2, k3 = jax.random.split(key, 3)\n", + " poses, scores = step_sample(k1, 1000, gf, observation)\n", + " new_pose, new_weight = select_by_weight(k2, scores, poses)\n", + " weights2, poses2 = grid_search_step(k2, gf, new_pose, observation)\n", + " # Note that `new_pose` will be among the poses considered by grid_search_step,\n", + " # so the possibility exists to remain stationary, as Bayesian inference requires\n", + " chosen_pose, _ = select_by_weight(k3, weights2, poses2)\n", + " flat_poses, flat_scores = flatten_pose_cube(poses2, cube_step_size, weights2)\n", + " return chosen_pose, (new_pose, chosen_pose, flat_scores, flat_poses, new_weight)\n", + "\n", + " sub_keys = jax.random.split(key, T + 1)\n", + " return jax.lax.scan(\n", + " improved_step,\n", + " robot_inputs[\"start\"],\n", + " (\n", + " observations, # observation at time t\n", + " robot_inputs[\"controls\"], # guides step from t to t+1\n", + " sub_keys[1:],\n", + " ),\n", + " )\n", + "\n", + "\n", + "jit_improved_path = jax.jit(improved_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "121", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "key, sub_key = jax.random.split(key)\n", + "_, improved_low = jit_improved_path(\n", + " sub_key, motion_settings_low_deviation, observations_low_deviation\n", + ")\n", + "key, sub_key = jax.random.split(key)\n", + "_, improved_high = jit_improved_path(\n", + " sub_key, motion_settings_high_deviation, observations_high_deviation\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "122", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "def path_comparison_plot(*plots):\n", + " types = [\"improved\", \"integrated\", \"importance\", \"true\"]\n", + " plot = world_plot\n", + " plot += [\n", + " animate_path_as_line(p, strokeWidth=2, stroke=Plot.constantly(t))\n", + " for p, t in zip(plots, types)\n", + " ]\n", + " plot += [poses_to_plots(p, fill=Plot.constantly(t)) for p, t in zip(plots, types)]\n", + " return plot + Plot.color_map(\n", + " {\n", + " \"integrated\": \"green\",\n", + " \"improved\": \"blue\",\n", + " \"true\": \"black\",\n", + " \"importance\": \"red\",\n", + " }\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "123", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "path_comparison_plot(\n", + " improved_low[0], path_integrated, low_deviation_path, path_low_deviation\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "124", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + 
"path_comparison_plot(\n", + " improved_high[0], path_integrated, high_deviation_path, path_high_deviation\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "125", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "To see how the grid search improves poses, we play back the grid-search path\n", + "next to an importance sample path. You can see the grid search has a better fit\n", + "of sensor data to wall position at a variety of time steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "126", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "Plot.Row(\n", + " animate_path_and_sensors(\n", + " improved_high[0],\n", + " observations_high_deviation,\n", + " motion_settings_high_deviation,\n", + " frame_key=\"frame\",\n", + " ),\n", + " animate_path_and_sensors(\n", + " high_deviation_path,\n", + " observations_high_deviation,\n", + " motion_settings_high_deviation,\n", + " frame_key=\"frame\",\n", + " ),\n", + ") | Plot.Slider(\"frame\", 0, T - 1, fps=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "127", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Finishing touch: weave together the improved plot and the improvement steps\n", + "# into a slider animation\n", + "# Plot.Frames(\n", + "# [weighted_small_pose_plot(improved_high[0][k], path_high_deviation[k], improved_high[2][k], improved_high[1][k]) for k in range(T)],\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "128", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "def wsp_frame(k):\n", + " return path_comparison_plot(\n", + " improved_high[0][: k + 1],\n", + " path_integrated[: k + 1],\n", + " high_deviation_path[: k + 1],\n", + " path_high_deviation[: k + 1],\n", + " ) & weighted_small_pose_plot(\n", + " improved_high[1][k],\n", + " path_high_deviation[k],\n", + " improved_high[2][k],\n", + " improved_high[3][k],\n", + " zoom=4,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "129", + "metadata": {}, + "outputs": [], + "source": [ + "Plot.Frames([wsp_frame(k) for k in range(1, 6)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "130", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "title,-all", + "custom_cell_magics": "kql", + "encoding": "# -*- coding: utf-8 -*-" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 8c7e957..1ce4d86 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -16,14 +16,14 @@ # --- # pyright: reportUnusedExpression=false # %% -import sys +# import sys -if "google.colab" in sys.modules: - from google.colab import auth # pyright: ignore [reportMissingImports] +# if "google.colab" in sys.modules: +# from google.colab import auth # pyright: ignore [reportMissingImports] - auth.authenticate_user() - %pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa - %pip install --quiet genjax==0.5.1 genstudio==2024.7.30.1617 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa 
+# auth.authenticate_user() +# %pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa +# %pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa # %% [markdown] # # ProbComp Localization Tutorial # @@ -35,7 +35,6 @@ import json import genstudio.plot as Plot - import itertools import jax import jax.numpy as jnp @@ -43,9 +42,10 @@ from urllib.request import urlopen from genjax import SelectionBuilder as S from genjax import ChoiceMapBuilder as C -from genjax.typing import FloatArray, PRNGKey +from genjax.typing import Array, FloatArray, PRNGKey, IntArray from penzai import pz -from typing import Any, Iterable +from typing import Any, Iterable, TypeVar, Generic, Callable + import os @@ -78,6 +78,7 @@ # %% # General code here + @pz.pytree_dataclass class Pose(genjax.PythonicPytree): p: FloatArray @@ -169,6 +170,13 @@ def make_world(wall_verts, clutters_vec, start, controls): # How bouncy the walls are in this world. bounce = 0.1 + # We prepend a zero-effect control step to the control array. This allows + # numerous simplifications in what follows: we can consider the initial + # pose uncertainty as well as each subsequent step to be the same function + # of current position and control step. + noop_control = Control(jnp.array(0.0), jnp.array(0.0)) + controls = controls.prepend(noop_control) + # Determine the total number of control steps T = len(controls.ds) @@ -197,7 +205,9 @@ def load_world(file_name): Returns: - tuple: A tuple containing the world configuration, the initial state, and the total number of control steps. """ - with urlopen("https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json") as url: + with urlopen( + "https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json" + ) as url: data = json.load(url) walls_vec = jnp.array(data["wall_verts"]) @@ -223,11 +233,8 @@ def load_world(file_name): # # If the motion of the robot is determined in an ideal manner by the controls, then we may simply integrate to determine the resulting path. Naïvely, this results in the following. -# %% - -noop_control = Control(jnp.array([0.0]), jnp.array([0.0])) - +# %% def integrate_controls_unphysical(robot_inputs): """ Integrates the controls to generate a path from the starting pose. 
@@ -247,8 +254,7 @@ def integrate_controls_unphysical(robot_inputs): pose.apply_control(control), ), robot_inputs["start"], - # Prepend a no-op control to include the first pose in the result - noop_control + robot_inputs["controls"], + robot_inputs["controls"], )[1] @@ -374,7 +380,7 @@ def integrate_controls_physical(robot_inputs): new_pose, ), robot_inputs["start"], - noop_control + robot_inputs["controls"], + robot_inputs["controls"], )[1] @@ -382,11 +388,15 @@ def integrate_controls_physical(robot_inputs): path_integrated = integrate_controls_physical(robot_inputs) + # %% [markdown] # ### Plot such data # %% -def pose_plot(p, r=0.5, fill: str | Any = "black", **opts): - WING_ANGLE, WING_LENGTH = jnp.pi/12, 0.6 +def pose_plot(p, fill: str | Any = "black", **opts): + z = opts.get("zoom", 1.0) + r = z * 0.15 + wing_opacity = opts.get("opacity", 0.3) + WING_ANGLE, WING_LENGTH = jnp.pi / 12, z * opts.get("wing_length", 0.6) center = p.p angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1]) @@ -399,22 +409,23 @@ def pose_plot(p, r=0.5, fill: str | Any = "black", **opts): # Draw wings wings = Plot.line( [wing_ends[0], center, wing_ends[1]], - strokeWidth=2, + strokeWidth=opts.get("strokeWidth", 2), stroke=fill, - opacity=0.3 + opacity=wing_opacity, ) # Draw center dot - dot = Plot.ellipse([center], r=0.14, fill=fill, **opts) + dot = Plot.ellipse([center], fill=fill, **({"r": r} | opts)) return wings + dot + walls_plot = Plot.new( Plot.line( - Plot.cache(world["wall_verts"]), - strokeWidth=2, - stroke="#ccc", - ), + world["wall_verts"], + strokeWidth=2, + stroke="#ccc", + ), {"margin": 0, "inset": 50, "width": 500, "axis": None, "aspectRatio": 1}, Plot.domain([0, 20]), ) @@ -468,21 +479,23 @@ def pose_plot(p, r=0.5, fill: str | Any = "black", **opts): # # We start with the two building blocks: the starting pose and individual steps of motion. # %% -@genjax.gen -def start_pose_prior(start, motion_settings): - p = genjax.mv_normal(start.p, motion_settings["p_noise"] ** 2.0 * jnp.eye(2)) @ "p" - hd = genjax.normal(start.hd, motion_settings["hd_noise"]) @ "hd" - return Pose(p, hd) + +# TODO(colin,jay): Originally, we passed motion_settings['p_noise'] ** 2 to +# mv_normal_diag, but I think this squares the scale twice. TFP documenentation +# - https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalDiag +# states that: scale = diag(scale_diag); covariance = scale @ scale.T. The second +# equation will have the effect of squaring the individual diagonal scales. 
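+# For instance (an illustration of the above, not a change in behavior): with p_noise = 0.5, +# scale_diag = p_noise * jnp.ones(2) gives covariance diag(0.5, 0.5) @ diag(0.5, 0.5).T = +# 0.25 * jnp.eye(2), i.e. variance p_noise**2 as intended; passing p_noise**2 as scale_diag +# would instead produce variance p_noise**4.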
+ @genjax.gen -def step_model(start, c, motion_settings): +def step_model(motion_settings, start, control): p = ( - genjax.mv_normal( - start.p + c.ds * start.dp(), motion_settings["p_noise"] ** 2.0 * jnp.eye(2) + genjax.mv_normal_diag( + start.p + control.ds * start.dp(), motion_settings["p_noise"] * jnp.ones(2) ) @ "p" ) - hd = genjax.normal(start.hd + c.dhd, motion_settings["hd_noise"]) @ "hd" + hd = genjax.normal(start.hd + control.dhd, motion_settings["hd_noise"]) @ "hd" return physical_step(start.p, p, hd) @@ -494,8 +507,8 @@ def step_model(start, c, motion_settings): # %% key = jax.random.PRNGKey(0) -start_pose_prior.simulate( - key, (robot_inputs["start"], default_motion_settings) +step_model.simulate( + key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]) ).get_retval() # %% [markdown] @@ -519,27 +532,23 @@ def make_circle(p, r): # Generate N_samples of starting poses from the prior N_samples = 50 key, sub_key = jax.random.split(key) -# pose_samples = [start_pose_prior.simulate(k, (robot_inputs['start'], motion_settings)) for k in sub_keys] pose_samples = jax.vmap(step_model.simulate, in_axes=(0, None))( jax.random.split(sub_key, N_samples), - (robot_inputs["start"], robot_inputs["controls"][0], default_motion_settings), + (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) -def poses_to_plots(poses: Iterable[Pose], **plot_opts): - return [ - pose_plot(pose, **plot_opts) - for pose in poses - ] +def pose_list_to_plural_pose(pl: list[Pose]) -> Pose: + return Pose(jnp.array([pose.p for pose in pl]), [pose.hd for pose in pl]) -# Plot the world, starting pose samples, and 95% confidence region +def poses_to_plots(poses: Iterable[Pose], **plot_opts): + return [pose_plot(pose, **plot_opts) for pose in poses] +# Plot the world, starting pose samples, and 95% confidence region # Calculate the radius of the 95% confidence region def confidence_circle(pose: Pose, p_noise: float): - # TODO - # should this also take into account the hd_noise? return Plot.scaled_circle( *pose.p, fill=Plot.constantly("95% confidence region"), @@ -552,7 +561,7 @@ def confidence_circle(pose: Pose, p_noise: float): + poses_to_plots([robot_inputs["start"]], fill=Plot.constantly("step from here")) + confidence_circle( robot_inputs["start"].apply_control(robot_inputs["controls"][0]), - default_motion_settings['p_noise'], + default_motion_settings["p_noise"], ) + poses_to_plots(pose_samples.get_retval(), fill=Plot.constantly("step samples")) + Plot.color_map({"step from here": "#000", "step samples": "red"}) @@ -568,8 +577,9 @@ def confidence_circle(pose: Pose, p_noise: float): # %% # `simulate` takes the GF plus a tuple of args to pass to it. key, sub_key = jax.random.split(key) -trace = start_pose_prior.simulate( - sub_key, (robot_inputs["start"], default_motion_settings) +trace = step_model.simulate( + sub_key, + (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) trace.get_choices() @@ -638,7 +648,7 @@ def confidence_circle(pose: Pose, p_noise: float): # # A pose consists of a pair $z = (z_\text p, z_\text{hd})$ where $z_\text p$ is a position vector and $z_\text{hd}$ is an angle. A control consists of a pair $(s, \eta)$ where $s$ is a distance of displacement and $\eta$ is a change in angle. Write $u(\theta) = (\cos\theta, \sin\theta)$ for the unit vector in the direction $\theta$. We are given a "world" $w$ and "motion settings" parameters $\nu = (\nu_\text p, \nu_\text{hd})$. 
# -# The models `start_pose_prior` and `step_model` correspond to distributions over their traces, respectively written $\text{start}$ and $\text{step}$. In both cases these traces consist of the choices at addresses `:p` and `:hd`, so they may be identified with poses $z$ as above. The distributions are defined as follows, when $y$ is a pose: +# The models and `step_proposal` correspond to distributions over their traces, respectively written $\text{start}$ and $\text{step}$. In both cases these traces consist of the choices at addresses `:p` and `:hd`, so they may be identified with poses $z$ as above. The distributions are defined as follows, when $y$ is a pose: # * $z \sim \text{start}(y, \nu)$ means that $z_\text p \sim \text{mvnormal}(y_\text p, \nu_\text p^2 I)$ and $z_\text{hd} \sim \text{normal}(y_\text{hd}, \nu_\text{hd})$ independently. # * $z \sim \text{step}(y, (s, \eta), w, \nu)$ means that $z_\text p \sim \text{mvnormal}(y_\text p + s\,u(y_\text{hd}), \nu_\text p^2 I)$ and $z_\text{hd} \sim \text{normal}(y_\text{hd} + \eta, \nu_\text {hd})$ independently. # @@ -693,60 +703,19 @@ def confidence_circle(pose: Pose, p_noise: float): # (It is worth acknowledging two strange things in the code below: the use of the suffix `.accumulate()` in path_model and the use of that auxiliary function itself. # %% -@genjax.gen -def path_model_start(robot_inputs, motion_settings): - return start_pose_prior(robot_inputs["start"], motion_settings) @ ( - "initial", - "pose", - ) - -@genjax.gen -def path_model_step(motion_settings, previous_pose, control): - return step_model(previous_pose, control, motion_settings) @ ( - "steps", - "pose", - ) - - -def gen_partial(gen_fn, closed_over): - @genjax.gen - def inner(*args): - return gen_fn.inline(closed_over, *args) - return inner - -path_model = gen_partial(path_model_step, default_motion_settings).accumulate() - -# TODO(colin,huebert): talk about accumulate, what it does, and _why_ from the point of view of acceleration. This is the flow control modification we were hinting at above, and it constrains the step function to have the two-argument signature that it does, which is why we reached for `partial` in the first place. Emphasize that this small bit of preparation allows massively parallel execution on a GPU and so it's worth the hassle. 
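# A minimal sketch of the flow-control shape that `scan` requires, written with
# plain `jax.lax.scan` (illustrative only; the GenJAX combinators used for
# `path_model` below follow the same convention). The step function maps
# (carry, input) to (new_carry, output); `.map(lambda r: (r, r))` emits the new
# pose as both the carry and the per-step output.
import jax
import jax.numpy as jnp

def step(pose, control):
    new_pose = pose + control      # stand-in for one noisy physical step
    return new_pose, new_pose      # (carry, output)

start = jnp.zeros(2)
controls = 0.1 * jnp.ones((10, 2))
final_pose, path = jax.lax.scan(step, start, controls)
# `path` has shape (10, 2): one pose per control, analogous to the second
# component of `path_model`'s return value.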
- -key, sub_key1, sub_key2 = jax.random.split(key, 3) -initial_pose = path_model_start.simulate( - sub_key1, (robot_inputs, default_motion_settings) -) -step_model.simulate( - sub_key2, - ( - initial_pose.get_retval(), - robot_inputs["controls"][0], - default_motion_settings, - ), +path_model = ( + step_model.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan() ) -# %% - +# result[0] ~~ robot_inputs['start'] + control_step[0] (which is zero) + noise +# %% def generate_path_trace(key: PRNGKey) -> genjax.Trace: - key, start_key = jax.random.split(key) - initial_pose = path_model_start.simulate( - start_key, (robot_inputs, default_motion_settings) - ) - key, step_key = jax.random.split(key) - return path_model.simulate( - step_key, (initial_pose.get_retval(), robot_inputs["controls"]) - ) + return path_model.simulate(key, (robot_inputs["start"], robot_inputs["controls"])) def path_from_trace(tr: genjax.Trace) -> Pose: - return tr.get_retval() + return tr.get_retval()[1] def generate_path(key: PRNGKey) -> Pose: @@ -755,18 +724,20 @@ def generate_path(key: PRNGKey) -> Pose: # %% key, sub_key = jax.random.split(key) -generate_path_trace(sub_key) +pt = generate_path_trace(sub_key) +pt # %% N_samples = 12 key, sub_key = jax.random.split(key) sample_paths_v = jax.vmap(generate_path)(jax.random.split(sub_key, N_samples)) -Plot.Grid([walls_plot + poses_to_plots(path) for path in sample_paths_v]) - +Plot.Grid(*[walls_plot + poses_to_plots(path) for path in sample_paths_v]) # %% # Animation showing a single path with confidence circles +# TODO: is there an off-by-one here possibly as a result of the zero initial step? +# TODO: how about plot the control vector? def plot_path_with_confidence(path: Pose, step: int, p_noise: float): plot = ( world_plot @@ -776,7 +747,8 @@ def plot_path_with_confidence(path: Pose, step: int, p_noise: float): if step < len(path) - 1: plot += [ confidence_circle( - path[step].apply_control(robot_inputs["controls"][step]), + # for a given index, step[index] is current pose, controls[index] is what was applied to prev pose + path[step].apply_control(robot_inputs["controls"][step + 1]), p_noise, ), pose_plot(path[step + 1], fill=Plot.constantly("next pose")), @@ -786,7 +758,7 @@ def plot_path_with_confidence(path: Pose, step: int, p_noise: float): def animate_path_with_confidence(path: Pose, motion_settings: dict): frames = [ - plot_path_with_confidence(path, step, motion_settings['p_noise']) + plot_path_with_confidence(path, step, motion_settings["p_noise"]) for step in range(len(path.p)) ] @@ -798,7 +770,7 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): path = generate_path(sample_key) Plot.Frames( [ - plot_path_with_confidence(path, step, default_motion_settings['p_noise']) + plot_path_with_confidence(path, step, default_motion_settings["p_noise"]) + Plot.title("Motion model (samples)") for step in range(len(path)) ], @@ -814,8 +786,9 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): # %% key, sub_key = jax.random.split(key) -trace = start_pose_prior.simulate( - sub_key, (robot_inputs["start"], default_motion_settings) +trace = step_model.simulate( + sub_key, + (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) key, sub_key = jax.random.split(key) rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( @@ -846,7 +819,7 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): key, sub_key = jax.random.split(key) rotated_first_step, rotated_first_step_weight_diff, 
_, _ = trace.update( - sub_key, C[0, "steps", "pose", "hd"].set(jnp.pi / 2.0) + sub_key, C[0, "hd"].set(jnp.pi / 2.0) ) # %% @@ -918,10 +891,10 @@ def plot_sensors(pose: Pose, readings): pose.rotate(angle).step_along(s) for angle, s in zip(sensor_angles, readings) ] - return ( + return [ Plot.line( [(x, y, i) for i, p in enumerate(projections) for x, y in [pose.p, p.p]], - stroke=Plot.constantly("sensor rays") + stroke=Plot.constantly("sensor rays"), ), [ Plot.dot( @@ -931,12 +904,10 @@ def plot_sensors(pose: Pose, readings): ) ], Plot.color_map({"sensor rays": "rgba(0,0,0,0.1)", "sensor readings": "#f80"}), - ) + ] -world_plot + plot_sensors( - initial_pose.get_retval(), ideal_sensor(initial_pose.get_retval()) -) +world_plot + plot_sensors(robot_inputs["start"], ideal_sensor(robot_inputs["start"])) # %% @@ -966,6 +937,7 @@ def animate_path_with_sensor(path, readings): # up to some noise. We model this as follows. # %% + @genjax.gen def sensor_model_one(pose, angle): sensor_pose = pose.rotate(angle) @@ -1025,42 +997,36 @@ def noisy_sensor(pose): # We fold the sensor model into the motion model to form a "full model", whose traces describe simulations of the entire robot situation as we have described it. # %% -def make_full_model(motion_settings): - @genjax.gen - def full_model_initial(): - pose = start_pose_prior(robot_inputs["start"], motion_settings) @ "pose" - sensor_model(pose, sensor_angles) @ "sensor" - return pose - @genjax.gen - def full_model_kernel(state, control): - pose = step_model(state, control, motion_settings) @ "pose" - sensor_model(pose, sensor_angles) @ "sensor" - return pose, pose +@genjax.gen +def full_model_kernel(motion_settings, state, control): + pose = step_model(motion_settings, state, control) @ "pose" + sensor_model(pose, sensor_angles) @ "sensor" + return pose - @genjax.gen - def full_model(): - initial = full_model_initial() @ "initial" - return full_model_kernel.scan(n=T)(initial, robot_inputs["controls"]) @ "steps" - return full_model +@genjax.gen +def full_model(motion_settings): + return ( + full_model_kernel.partial_apply(motion_settings) + .map(lambda r: (r, r)) + .scan()(robot_inputs["start"], robot_inputs["controls"]) + @ "steps" + ) + def get_path(trace): - p = trace.get_subtrace(("initial",)).get_retval() ps = trace.get_retval()[1] - return ps.prepend(p) + return ps def get_sensors(trace): ch = trace.get_choices() - return jnp.concatenate(( - ch["initial", "sensor", ..., "distance"][jnp.newaxis], - ch["steps", ..., "sensor", ..., "distance"] - )) + return ch["steps", :, "sensor", :, "distance"] + -default_full_model = make_full_model(default_motion_settings) key, sub_key = jax.random.split(key) -tr = default_full_model.simulate(sub_key, ()) +tr = full_model.simulate(sub_key, (default_motion_settings,)) pz.ts.display(tr) # %% @@ -1080,25 +1046,28 @@ def get_sensors(trace): # %% key, sub_key = jax.random.split(key) -tr = default_full_model.simulate(sub_key, ()) +tr = full_model.simulate(sub_key, (default_motion_settings,)) -def animate_full_trace(trace, frame_key=None): - path = get_path(trace) - readings = get_sensors(trace) - # since we use make_full_model to curry motion_settings around the scan combinator, - # that object will not be in the outer trace's argument list; but we can be a little - # crafty and find it at a lower level. 
- motion_settings = trace.get_subtrace(('initial',)).get_subtrace(('pose',)).get_args()[1] - +def animate_path_and_sensors(path, readings, motion_settings, frame_key=None): frames = [ - plot_path_with_confidence(path, step, motion_settings['p_noise']) + plot_path_with_confidence(path, step, motion_settings["p_noise"]) + plot_sensors(pose, readings[step]) for step, pose in enumerate(path) ] return Plot.Frames(frames, fps=2, key=frame_key) + +def animate_full_trace(trace, frame_key=None): + path = get_path(trace) + readings = get_sensors(trace) + motion_settings = trace.get_args()[0] + return animate_path_and_sensors( + path, readings, motion_settings, frame_key=frame_key + ) + + animate_full_trace(tr) # %% [markdown] # ## The data @@ -1112,12 +1081,9 @@ def animate_full_trace(trace, frame_key=None): } key, k_low, k_high = jax.random.split(key, 3) -low_deviation_model = make_full_model(motion_settings_low_deviation) -trace_low_deviation = low_deviation_model.simulate(k_low, ()) - +trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,)) motion_settings_high_deviation = {"p_noise": 0.25, "hd_noise": 2 * jnp.pi / 360} -high_deviation_model = make_full_model(motion_settings_high_deviation) -trace_high_deviation = high_deviation_model.simulate(k_high, ()) +trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,)) animate_full_trace(trace_low_deviation) # %% @@ -1129,7 +1095,7 @@ def animate_full_trace(trace, frame_key=None): # Since we imagine these data as having been recorded from the real world, keep only their extracted data, *discarding* the traces that produced them. # %% -# These are are what we hope to recover... +# These are what we hope to recover... path_low_deviation = get_path(trace_low_deviation) path_high_deviation = get_path(trace_high_deviation) @@ -1140,13 +1106,11 @@ def animate_full_trace(trace, frame_key=None): # Encode sensor readings into choice map. -def constraint_from_sensors(readings): - angle_indices = jnp.arange(len(sensor_angles)) - return jax.vmap( - lambda ix, v: C["steps", ix, "sensor", angle_indices, "distance"].set(v) - )( - jnp.arange(T), readings[1:] - ) + C['initial', 'sensor', angle_indices, 'distance'].set(readings[0]) +def constraint_from_sensors(readings, t: int = T): + return C["steps", jnp.arange(t + 1), "sensor", :, "distance"].set(readings[: t + 1]) + # return jax.vmap( + # lambda v: C["steps", :, "sensor", :, "distance"].set(v) + # )(readings[:t]) constraints_low_deviation = constraint_from_sensors(observations_low_deviation) @@ -1155,14 +1119,14 @@ def constraint_from_sensors(readings): # %% [markdown] # We summarize the information available to the robot to determine its location. On the one hand, one has to produce a guess of the start pose plus some controls, which one might integrate to produce an idealized guess of path. On the other hand, one has the sensor data. + # %% def animate_bare_sensors(path, plot_base=[]): def frame(pose, readings1, readings2): def plt(readings): return Plot.new( - plot_base, + plot_base or Plot.domain([0, 20]), plot_sensors(pose, readings), - Plot.domain([0, 20]), {"width": 400, "height": 400}, ) @@ -1183,7 +1147,9 @@ def plt(readings): # The path obtained by integrating the controls serves as a proposal for the true path, but it is unsatisfactory, especially in the high motion deviation case. 
The picture gives an intuitive sense of the fit: # %% -animate_bare_sensors(path_integrated, walls_plot) +animate_bare_sensors(path_integrated, world_plot) +# %% +world_plot + plot_sensors(robot_inputs["start"], observations_low_deviation[0]) # %% [markdown] # It would seem that the fit is reasonable in low motion deviation, but really breaks down in high motion deviation. # @@ -1197,16 +1163,18 @@ def plt(readings): # Given a choice map of *constraints* that declare fixed values of some of the primitive choices, the operation `importance` proposes traces of the generative function that are consistent with these constraints. # %% -low_deviation_importance = jax.jit(low_deviation_model.importance) -high_deviation_importance = jax.jit(high_deviation_model.importance) +model_importance = jax.jit(full_model.importance) key, sub_key = jax.random.split(key) -sample, log_weight = low_deviation_importance(sub_key, constraints_low_deviation, ()) - +sample, log_weight = model_importance( + sub_key, constraints_low_deviation, (motion_settings_low_deviation,) +) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% key, sub_key = jax.random.split(key) -sample, log_weight = high_deviation_importance(sub_key, constraints_high_deviation, ()) +sample, log_weight = model_importance( + sub_key, constraints_high_deviation, (motion_settings_high_deviation,) +) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% [markdown] # A trace resulting from a call to `importance` is structurally indistinguishable from one drawn from `simulate`. But there is a key situational difference: while `get_score` always returns the frequency with which `simulate` stochastically produces the trace, this value is **no longer equal to** the frequency with which the trace is stochastically produced by `importance`. This is both true in an obvious and less relevant sense, as well as true in a more subtle and extremely germane sense. @@ -1244,33 +1212,39 @@ def plt(readings): # In words, the data are incongruously unlikely for the integrated path. The (log) density of the measurement data, given the integrated path... 
# %% -path_integrated -sample.get_choices() -constraints_path_integrated = C[""] -def constraint_from_path(path): - - c_ps = jax.vmap( - lambda ix, p: C["steps", ix, "pose", "p"].set(p) - )(jnp.arange(T), path.p[1:]) - c_hds = jax.vmap( - lambda ix, hd: C["steps", ix, "pose", "hd"].set(hd) - )(jnp.arange(T), path.hd[1:]) +def constraint_from_path(path): + c_ps = jax.vmap(lambda ix, p: C["steps", ix, "pose", "p"].set(p))( + jnp.arange(T), path.p + ) - c_p = C["initial", "pose", "p"].set(path.p[0]) - c_hd = C["initial", "pose", "hd"].set(path.hd[0]) + c_hds = jax.vmap(lambda ix, hd: C["steps", ix, "pose", "hd"].set(hd))( + jnp.arange(T), path.hd + ) + return c_ps + c_hds # + c_p + c_hd - return c_ps + c_hds + c_p + c_hd constraints_path_integrated = constraint_from_path(path_integrated) -constraints_path_integrated_observations_low_deviation = constraints_path_integrated ^ constraints_low_deviation -constraints_path_integrated_observations_high_deviation = constraints_path_integrated ^ constraints_high_deviation +constraints_path_integrated_observations_low_deviation = ( + constraints_path_integrated ^ constraints_low_deviation +) +constraints_path_integrated_observations_high_deviation = ( + constraints_path_integrated ^ constraints_high_deviation +) key, sub_key = jax.random.split(key) -trace_path_integrated_observations_low_deviation, w_low = low_deviation_importance(sub_key, constraints_path_integrated_observations_low_deviation, ()) +trace_path_integrated_observations_low_deviation, w_low = model_importance( + sub_key, + constraints_path_integrated_observations_low_deviation, + (motion_settings_low_deviation,), +) key, sub_key = jax.random.split(key) -trace_path_integrated_observations_high_deviation, w_high = high_deviation_importance(sub_key, constraints_path_integrated_observations_high_deviation, ()) +trace_path_integrated_observations_high_deviation, w_high = model_importance( + sub_key, + constraints_path_integrated_observations_high_deviation, + (motion_settings_high_deviation,), +) w_low, w_high # TODO: Jay then does two projections to compare the log-weights of these two things, @@ -1280,18 +1254,28 @@ def constraint_from_path(path): # %% Plot.Row( - *[(html("div.f3.b.tc", title) - | animate_full_trace(trace, frame_key="frame") - | html("span.tc", f"score: {score:,.2f}")) - for (title, trace, motion_settings, score) in - [["Low deviation", - trace_path_integrated_observations_low_deviation, - motion_settings_low_deviation, - w_low], - ["High deviation", - trace_path_integrated_observations_high_deviation, - motion_settings_high_deviation, - w_high]]]) | Plot.Slider("frame", T, fps=2) + *[ + ( + html("div.f3.b.tc", title) + | animate_full_trace(trace, frame_key="frame") + | html("span.tc", f"score: {score:,.2f}") + ) + for (title, trace, motion_settings, score) in [ + [ + "Low deviation", + trace_path_integrated_observations_low_deviation, + motion_settings_low_deviation, + w_low, + ], + [ + "High deviation", + trace_path_integrated_observations_high_deviation, + motion_settings_high_deviation, + w_high, + ], + ] + ] +) | Plot.Slider("frame", 0, T, fps=2) # %% [markdown] # ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model... 
@@ -1302,9 +1286,21 @@ def constraint_from_path(path): key, sub_key = jax.random.split(key) -traces_generated_low_deviation, low_weights = jax.vmap(low_deviation_importance, in_axes=(0, None, None))(jax.random.split(sub_key, N_samples), constraints_low_deviation, ()) +traces_generated_low_deviation, low_weights = jax.vmap( + model_importance, in_axes=(0, None, None) +)( + jax.random.split(sub_key, N_samples), + constraints_low_deviation, + (motion_settings_low_deviation,), +) -traces_generated_high_deviation, high_weights = jax.vmap(high_deviation_importance, in_axes=(0, None, None))(jax.random.split(sub_key, N_samples), constraints_high_deviation, ()) +traces_generated_high_deviation, high_weights = jax.vmap( + model_importance, in_axes=(0, None, None) +)( + jax.random.split(sub_key, N_samples), + constraints_high_deviation, + (motion_settings_high_deviation,), +) # low_weights, high_weights # two histograms @@ -1314,10 +1310,17 @@ def constraint_from_path(path): high_deviation_paths = jax.vmap(get_path)(traces_generated_high_deviation) # %% -Plot.new(world_plot, - [poses_to_plots(pose, fill="blue", opacity=0.1) for pose in high_deviation_paths[:20]], - [poses_to_plots(pose, fill="green", opacity=0.1) for pose in low_deviation_paths[:20]] - ) +Plot.new( + world_plot, + [ + poses_to_plots(pose, fill="blue", opacity=0.1) + for pose in high_deviation_paths[:20] + ], + [ + poses_to_plots(pose, fill="green", opacity=0.1) + for pose in low_deviation_paths[:20] + ], +) # %% [markdown] # ## Generic strategies for inference # @@ -1402,118 +1405,274 @@ def constraint_from_path(path): # ment over rejection sampling: intead of indefinitely constructing and rejecting samples, we can guarantee to use at least some of them after a fixed time, and we are using the best guesses among these. # %% -categorical_sampler = jax.jit(genjax.categorical.sampler) -def resample(key: PRNGKey, constraints: genjax.ChoiceMap, importance_model, N: int, K: int): + +def importance_sample( + key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int +): + """Produce N importance samples of depth K from the model. That is, N times, we + generate K importance samples conditioned by the constraints, and categorically + select one of them.""" key1, key2 = jax.random.split(key) - samples, log_weights = jax.vmap(importance_model, in_axes=(0, None, None))(jax.random.split(key1, N*K), constraints, ()) - winners = jax.vmap(categorical_sampler)(jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))) + samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))( + jax.random.split(key1, N * K), constraints, (motion_settings,) + ) + winners = jax.vmap(genjax.categorical.sampler)( + jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) + ) # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. 
- winners += jnp.arange(0, N*K, N) + winners += jnp.arange(0, N * K, N) selected = jax.tree.map(lambda x: x[winners], samples) return selected + +jit_resample = jax.jit(importance_sample, static_argnums=(3, 4)) + key, sub_key = jax.random.split(key) -low_posterior = resample(sub_key, constraints_low_deviation, low_deviation_importance, 2000, 20) +low_posterior = jit_resample( + sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20 +) key, sub_key = jax.random.split(key) -high_posterior = resample(sub_key, constraints_high_deviation, high_deviation_importance, 2000, 20) +high_posterior = jit_resample( + sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20 +) # %% -def animate_path_as_line(path, **options): - x_coords = path.p[:, 0] - y_coords = path.p[:, 1] - return Plot.line({"x": x_coords, "y": y_coords}, - {"curve": "cardinal-open", - **options}) -# -(world_plot - + [animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="green") for path in jax.vmap(get_path)(low_posterior)] - + [animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="blue") for path in jax.vmap(get_path)(high_posterior)] - + poses_to_plots(path_low_deviation, fill=Plot.constantly("low deviation path"), opacity=0.2) - + poses_to_plots(path_high_deviation, fill=Plot.constantly("high deviation path"), opacity=0.2) - + poses_to_plots(path_integrated, fill=Plot.constantly("integrated path"), opacity=0.2) - + Plot.color_map({"low deviation path": "green", "high deviation path": "blue", "integrated path": "black"})) + +def path_to_polyline(path, **options): + if len(path.p.shape) > 1: + x_coords = path.p[:, 0] + y_coords = path.p[:, 1] + return Plot.line({"x": x_coords, "y": y_coords}, {"curve": "linear", **options}) + else: + return Plot.dot([path.p], fill=options["stroke"], r=2, **options) + + +# +( + world_plot + + [ + path_to_polyline(path, opacity=0.2, strokeWidth=2, stroke="green") + for path in jax.vmap(get_path)(low_posterior) + ] + + [ + path_to_polyline(path, opacity=0.2, strokeWidth=2, stroke="blue") + for path in jax.vmap(get_path)(high_posterior) + ] + + poses_to_plots( + path_low_deviation, fill=Plot.constantly("low deviation path"), opacity=0.2 + ) + + poses_to_plots( + path_high_deviation, fill=Plot.constantly("high deviation path"), opacity=0.2 + ) + + poses_to_plots( + path_integrated, fill=Plot.constantly("integrated path"), opacity=0.2 + ) + + Plot.color_map( + { + "low deviation path": "green", + "high deviation path": "blue", + "integrated path": "black", + } + ) +) # %% [markdown] # Let's pause a moment to examine this chart. If the robot had no sensors, it would have no alternative but to estimate its position by integrating the control inputs to produce the integrated path in gray. In the low deviation setting, Gen has helped the robot to see that about halfway through its journey, noise in the control-effector relationship has caused the robot to deviate to the south slightly, and *the sensor data combined with importance sampling is enough* to give accurate results in the low deviation setting. # But in the high deviation setting, the loose nature of the paths in the blue posterior indicate that the robot has not discovered its true position by using importance sampling with the noisy sensor data. In the high deviation setting, more refined inference technique will be required. 
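# A small worked example of the index bookkeeping inside `importance_sample`
# above (hypothetical values): each of the K categorical draws is made within
# its own row of the (K, N) weight matrix, so it returns an index in [0, N);
# adding the row offsets 0, N, 2N, ... converts them into indices into the
# flat batch of N*K samples.
import jax.numpy as jnp

N, K = 4, 3
winners_within_rows = jnp.array([2, 0, 3])               # per-row categorical picks
flat_indices = winners_within_rows + jnp.arange(0, N * K, N)
# flat_indices == [2, 4, 11]: element 2 of row 0, element 0 of row 1, element 3 of row 2.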
-# %% [markdown] -# One way to approach this task would be to pause after each step and use importance sampling to refine our estimate of pose using the sensor data available at that step. Instead of a global importance sample, which generated the traces above, we will use a stepwise importance sampling technique that accumulates the information we have learned about previous steps. -# %% -# Take one trace from the high deviation collection, and extract the step pose data from it. -t0 = jax.tree.map(lambda v: v[0], high_posterior) -t0_ch = t0.get_choices() -pz.ts.display(t0_ch['steps', ..., 'pose', 'p'], t0_ch['steps', ..., 'pose', 'hd']) -# %% [markdown] -# That's the collection of pose data for one trace. Suppose we could get a measure of what influence a small perturbation might make to the overall likelihood of the path given the sensor data? This is what `update` is for! Using GenJAX, we can use acceleration to propose many updates in parallel, if you have a GPU, and select from among the results properly weighted by their likelihooods. -# In order to make this happen, we will write a function that make *one* update propasal for *one* trace, and use the power of `vmap` to perform a local exploration of probability space. +# +# Let's approach the problem step by step instead of trying to infer the whole path at once. +# The technique we will use is called Sequential Importance Sampling or a +# [Particle Filter](https://en.wikipedia.org/wiki/Particle_filter). It works like this. +# +# When we designed the step model for the robot, we arranged things so that the model +# could be used with `scan`: the model takes a *state* and a *control input* to produce +# a new *state*. Imagine at some time step $t$ that we use importance sampling with this +# model at a pose $\mathbf{z}_t$ and control input $\mathbf{u}_t$, scored with respect to the +# sensor observations $\mathbf{y}_t$ observed at that time. We will get a weighted collection +# of possible updated poses $\mathbf{z}_t^N$ and weights $w^N$. +# +# The particle filter "winnows" this set by replacing it with $N$ weighted selections +# *with replacement* from this collection. This may select better candidates several +# times, and is likely to drop poor candidates from the collection. We can arrange to +# to this at each time step with a little preparation: we start by "cloning" our idea +# of the robot's initial position into an N vector and this becomes the initial particle +# collection. At each step, we generate an importance sample and winnow it. +# +# This can also be done as a scan. Our previous attempt used `scan` to produce candidate +# paths from start to end, and these were scored for importance using all of the sensor +# readings at once. The results were better than guesses, but not accurate, in the +# high deviation setting. +# +# The technique we will use here discards steps with low likelihood at each step, and +# reinforces steps with high likelihood, allowing better particles to proportionately +# search more of the probability space while discarding unpromising particles. +# +# The following class attempts to generatlize this idea: # %% -def gaussian_drift(key, trace: genjax.Trace, motion_settings): - k1, k2, k3 = jax.random.split(key, 3) - ch = trace.get_choices() - # Get existing path data. 
- ps = ch['steps', ..., 'pose', 'p'] - hds = ch['steps', ..., 'pose', 'hd'] - ps += jax.random.normal(key=k1, shape=ps.shape) * motion_settings['p_noise'] - hds += jax.random.normal(key=k2, shape=hds.shape) * motion_settings['hd_noise'] - N = ps.shape[0] - update_choices = ( C['steps', jnp.arange(N), 'pose', 'p'].set(ps) + C['steps', jnp.arange(N), 'pose', 'hd'].set(hds)) - return trace.update(k3, update_choices) - -# Let's try it on the test path -key, sub_key = jax.random.split(key) -gaussian_drift(sub_key, t0, motion_settings_high_deviation) +StateT = TypeVar("StateT") +ControlT = TypeVar("ControlT") -# %% [markdown] -# Now that we can do one, we can do 1000. -# %% -key, sub_key = jax.random.split(key) -N_updates = 1000 -drift_traces, log_weights, _, _ = jax.vmap(gaussian_drift, in_axes=(0, None, None))(jax.random.split(sub_key, 1000), t0, motion_settings_high_deviation) -# %% [markdown] -# Let's weightedly-select 10 from among those and see if there's any improvement -# %% -key, sub_key = jax.random.split(key) -N_selection = 10 -selected_indices = jax.vmap(categorical_sampler, in_axes=(0, None))(jax.random.split(sub_key, N_selection), log_weights) -selected_indices -# %% [markdown] -# Do you notice that many (or all) the selected indices are repeats? This is because we are searching a probability space of high dimension: it's unlikely that there will be many traces producing a dramatic improvement. Even if there's only one, we'll write the plotting function for a selection of drifted traces: after that, we will fix the problem of repeated selections. -# %% +class SequentialImportanceSampling(Generic[StateT, ControlT]): + """ + Given: + - a functional wrapper for the importance method of a generative function + - an initial state of type StateT, which should be a PyTree $z_0$ + - a vector of control inputs, also a PyTree $u_i, of shape $(T, \ldots)$ + - an array of observations $y_i$, also of shape $(T, \ldots)$ + perform the inference technique known as Sequential Importance Sampling. + + The signature of the GFI importance method is + key -> constraint -> args -> (trace, weight) + For importance sampling, this is vmapped over key to get + [keys] -> constraint -> args -> ([trace], [weight]) + The functional wrapper's purpose is to maneuver the state and control + inputs into whatever argument shape the underlying model is expecting, + and to turn the observation at step $t$ into a choicemap asserting + that constraint. + + After the object is constructed, SIS can be performed at any importance + depth with the `run` method, which will perform the following steps: + + - inflate the initial value to a vector of size N of identical initial + values + - vmap over N keys generated from the supplied key + - each vmap cell will scan over the control inputs and observations + + Between each step, categorical sampling with replacement is formed to + create a particle filter. Favorable importance draws are likely to + be replicated, and unfavorable ones discarded. The resampled vector of + states is sent the the next step, while the values drawn from the + importance sample and the indices chosen are emitted from teh scan step, + where, at the end of the process, they will be available as matrices + of shape (N, T). 
+ """ -selected_traces = jax.tree.map(lambda v: v[selected_indices], drift_traces) + def __init__( + self, + importance: Callable[ + [PRNGKey, StateT, ControlT, Array], tuple[genjax.Trace[StateT], float] + ], + init: StateT, + controls: ControlT, + observations: Array, + ): + self.importance = jax.jit(importance) + self.init = init + self.controls = controls + self.observations = observations + + class Result(Generic[StateT]): + """This object contains all of the information generated by the SIS scan, + and offers some convenient methods to reconstruct the paths explored + (`flood_fill`) or ultimately chosen (`backtrack`). + """ -def plot_traces(traces): - return (world_plot - + [animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="green") for path in jax.vmap(get_path)(traces)] - + poses_to_plots(path_high_deviation, fill=Plot.constantly("high deviation path"), opacity=0.2) - + Plot.color_map({"low deviation path": "green", "high deviation path": "blue", "integrated path": "black"})) + def __init__( + self, N: int, end: StateT, samples: genjax.Trace[StateT], indices: IntArray + ): + self.N = N + self.end = end + self.samples = samples + self.indices = indices + + def flood_fill(self) -> list[list[StateT]]: + samples = self.samples.get_retval() + active_paths = [[p] for p in samples[0]] + complete_paths = [] + for i in range(1, len(samples)): + indices = self.indices[i - 1] + counts = jnp.bincount(indices, length=self.N) + new_active_paths = self.N * [None] + for j in range(self.N): + if counts[j] == 0: + complete_paths.append(active_paths[j]) + new_active_paths[j] = active_paths[indices[j]] + [samples[i][j]] + active_paths = new_active_paths + + return complete_paths + active_paths + + def backtrack(self) -> list[list[StateT]]: + paths = [[p] for p in self.end] + samples = self.samples.get_retval() + for i in reversed(range(len(samples))): + for j in range(len(paths)): + paths[j].append(samples[i][self.indices[i][j].item()]) + for p in paths: + p.reverse() + return paths + + def run(self, key: PRNGKey, N: int) -> dict: + def step(state, update): + key, control, observation = update + ks = jax.random.split(key, (2, N)) + sample, log_weights = jax.vmap(self.importance, in_axes=(0, 0, None, None))( + ks[0], state, control, observation + ) + indices = jax.vmap(genjax.categorical.sampler, in_axes=(0, None))( + ks[1], log_weights + ) + resample = jax.tree.map(lambda v: v[indices], sample) + return resample.get_retval(), (sample, indices) -plot_traces(selected_traces) + init_array = jax.tree.map( + lambda a: jnp.broadcast_to(a, (N,) + a.shape), self.init + ) + end, (samples, indices) = jax.lax.scan( + step, + init_array, + ( + jax.random.split(key, len(self.controls)), + self.controls, + self.observations, + ), + ) + return SequentialImportanceSampling.Result(N, end, samples, indices) -# %% [markdown] -# That looks promising, but there may only be one path in that output, since one of the drifted traces is probabilistically dominant. How can we get more candidate traces? We can use `vmap` *again*, to provide a fresh batch of drift samples for each desired trace. That will give us a weighted sample of potentially-improved traces to work with. # %% -# Generate K drifted samples, by generating N importance samples for each K and making a weighted selection from each batch. 
-def multi_drift(key, trace: genjax.Trace, scale, K: int, N: int): - k1, k2 = jax.random.split(key) - kn_samples, log_weights, _, _ = jax.vmap(gaussian_drift, in_axes=(0, None, None))(jax.random.split(k1, N*K), trace, scale) - batched_weights = log_weights.reshape((K, N)) - winners = jax.vmap(categorical_sampler)(jax.random.split(k2, K), batched_weights) - # The winning indices are relative to the batch from which they were drawn. Reset the indices to linear form. - winners += jnp.arange(0, N*K, N) - return jax.tree.map(lambda v: v[winners], kn_samples) +def localization_sis(motion_settings, observations): + return SequentialImportanceSampling( + lambda key, pose, control, observation: full_model_kernel.importance( + key, + C["sensor", :, "distance"].set(observation), + (motion_settings, pose, control), + ), + robot_inputs["start"], + robot_inputs["controls"], + observations, + ) # %% + key, sub_key = jax.random.split(key) -drifted_traces = multi_drift(sub_key, t0, motion_settings_high_deviation, 20, 1000) -plot_traces(drifted_traces) -# %% [markdown] -# We can see some improvement in the density of the paths selected. It's possible to imagine improving the search by repeating this drift process on all of the samples retured by the original importance sample. But we must face one important fact: we have used acceleration to improve what amounts to a brute-force search. The next inference step should take advantage of the information we have about the control steps, iteratively improving the path from the starting point, combining the control step and sensor data information to refine the selection of each step as it is made. +smc_result = localization_sis( + motion_settings_high_deviation, observations_high_deviation +).run(sub_key, 100) +( + world_plot + + path_to_polyline(path_high_deviation, stroke="blue", strokeWidth=2) + + [ + path_to_polyline(pose_list_to_plural_pose(p), opacity=0.1, stroke="green") + for p in smc_result.flood_fill() + ] +) # %% +# Try it in the low deviation setting +key, sub_key = jax.random.split(key) +low_smc_result = localization_sis( + motion_settings_low_deviation, observations_low_deviation +).run(sub_key, 20) +( + world_plot + + path_to_polyline(path_low_deviation, stroke="blue", strokeWidth=2) + + [ + path_to_polyline(pose_list_to_plural_pose(p), opacity=0.1, stroke="green") + for p in low_smc_result.flood_fill() + ] +) diff --git a/poetry.lock b/poetry.lock index 43f9fcb..cfcb83f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -13,13 +13,13 @@ files = [ [[package]] name = "anyio" -version = "4.4.0" +version = "4.6.2.post1" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, - {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, + {file = "anyio-4.6.2.post1-py3-none-any.whl", hash = "sha256:6d170c36fba3bdd840c73d3868c1e777e33676a69c3a72cf0a0d5d6d8009b61d"}, + {file = "anyio-4.6.2.post1.tar.gz", hash = "sha256:4c8bc31ccdb51c7f7bd251f51c609e038d63e34219b44aa86e47576389880b4c"}, ] [package.dependencies] @@ -27,9 +27,9 @@ idna = ">=2.8" sniffio = ">=1.1" [package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis 
(>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.23)"] +doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +trio = ["trio (>=0.26.1)"] [[package]] name = "anywidget" @@ -82,22 +82,22 @@ test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] [[package]] name = "attrs" -version = "23.2.0" +version = "24.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, - {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, ] [package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] -tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "beartype" @@ -119,74 +119,89 @@ test-tox-coverage = ["coverage (>=5.5)"] [[package]] name = "certifi" -version = "2024.7.4" +version = "2024.8.30" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, - {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, + {file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"}, + {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"}, ] [[package]] name = "cffi" -version = "1.16.0" +version = "1.17.1" description = "Foreign Function Interface for Python calling C code." 
optional = false python-versions = ">=3.8" files = [ - {file = "cffi-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088"}, - {file = "cffi-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7"}, - {file = "cffi-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614"}, - {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743"}, - {file = "cffi-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d"}, - {file = "cffi-1.16.0-cp310-cp310-win32.whl", hash = "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a"}, - {file = "cffi-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1"}, - {file = "cffi-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404"}, - {file = "cffi-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56"}, - {file = "cffi-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e"}, - {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc"}, - {file = "cffi-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb"}, - {file = "cffi-1.16.0-cp311-cp311-win32.whl", hash = "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab"}, - {file = "cffi-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba"}, - {file = "cffi-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956"}, - {file = "cffi-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6"}, - {file = "cffi-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969"}, - {file = "cffi-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520"}, - {file = "cffi-1.16.0-cp312-cp312-win32.whl", hash = "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b"}, - {file = "cffi-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235"}, - {file = "cffi-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b"}, - {file = "cffi-1.16.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324"}, - {file = "cffi-1.16.0-cp38-cp38-win32.whl", hash = "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a"}, - {file = "cffi-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36"}, - {file = "cffi-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed"}, - {file = "cffi-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4"}, - {file = "cffi-1.16.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098"}, - {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000"}, - {file = "cffi-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe"}, - {file = "cffi-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4"}, - {file = "cffi-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8"}, - {file = "cffi-1.16.0.tar.gz", hash = "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"}, + {file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"}, + {file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"}, + {file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"}, + {file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"}, + {file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"}, + {file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"}, + {file = 
"cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"}, + {file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"}, + {file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"}, + {file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"}, + {file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"}, + {file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"}, + {file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"}, + {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, + {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, ] [package.dependencies] @@ -194,112 +209,127 @@ pycparser = "*" [[package]] name = "charset-normalizer" -version = "3.3.2" +version = "3.4.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7.0" files = [ - {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, - {file = 
"charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, - {file = 
"charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, - {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4f9fc98dad6c2eaa32fc3af1417d95b5e3d08aff968df0cd320066def971f9a6"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0de7b687289d3c1b3e8660d0741874abe7888100efe14bd0f9fd7141bcbda92b"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5ed2e36c3e9b4f21dd9422f6893dec0abf2cca553af509b10cd630f878d3eb99"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d3ff7fc90b98c637bda91c89d51264a3dcf210cade3a2c6f838c7268d7a4ca"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1110e22af8ca26b90bd6364fe4c763329b0ebf1ee213ba32b68c73de5752323d"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86f4e8cca779080f66ff4f191a685ced73d2f72d50216f7112185dc02b90b9b7"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f683ddc7eedd742e2889d2bfb96d69573fde1d92fcb811979cdb7165bb9c7d3"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27623ba66c183eca01bf9ff833875b459cad267aeeb044477fedac35e19ba907"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f606a1881d2663630ea5b8ce2efe2111740df4b687bd78b34a8131baa007f79b"}, + {file = 
"charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0b309d1747110feb25d7ed6b01afdec269c647d382c857ef4663bbe6ad95a912"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:136815f06a3ae311fae551c3df1f998a1ebd01ddd424aa5603a4336997629e95"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:14215b71a762336254351b00ec720a8e85cada43b987da5a042e4ce3e82bd68e"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:79983512b108e4a164b9c8d34de3992f76d48cadc9554c9e60b43f308988aabe"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win32.whl", hash = "sha256:c94057af19bc953643a33581844649a7fdab902624d2eb739738a30e2b3e60fc"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55f56e2ebd4e3bc50442fbc0888c9d8c94e4e06a933804e2af3e89e2f9c1c749"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0d99dd8ff461990f12d6e42c7347fd9ab2532fb70e9621ba520f9e8637161d7c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c57516e58fd17d03ebe67e181a4e4e2ccab1168f8c2976c6a334d4f819fe5944"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dba5d19c4dfab08e58d5b36304b3f92f3bd5d42c1a3fa37b5ba5cdf6dfcbcee"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf4475b82be41b07cc5e5ff94810e6a01f276e37c2d55571e3fe175e467a1a1c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce031db0408e487fd2775d745ce30a7cd2923667cf3b69d48d219f1d8f5ddeb6"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ff4e7cdfdb1ab5698e675ca622e72d58a6fa2a8aa58195de0c0061288e6e3ea"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3710a9751938947e6327ea9f3ea6332a09bf0ba0c09cae9cb1f250bd1f1549bc"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82357d85de703176b5587dbe6ade8ff67f9f69a41c0733cf2425378b49954de5"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47334db71978b23ebcf3c0f9f5ee98b8d65992b65c9c4f2d34c2eaf5bcaf0594"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8ce7fd6767a1cc5a92a639b391891bf1c268b03ec7e021c7d6d902285259685c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f1a2f519ae173b5b6a2c9d5fa3116ce16e48b3462c8b96dfdded11055e3d6365"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:63bc5c4ae26e4bc6be6469943b8253c0fd4e4186c43ad46e713ea61a0ba49129"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bcb4f8ea87d03bc51ad04add8ceaf9b0f085ac045ab4d74e73bbc2dc033f0236"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win32.whl", hash = "sha256:9ae4ef0b3f6b41bad6366fb0ea4fc1d7ed051528e113a60fa2a65a9abb5b1d99"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cee4373f4d3ad28f1ab6290684d8e2ebdb9e7a1b74fdc39e4c211995f77bec27"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0713f3adb9d03d49d365b70b84775d0a0d18e4ab08d12bc46baa6132ba78aaf6"}, 
+ {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:de7376c29d95d6719048c194a9cf1a1b0393fbe8488a22008610b0361d834ecf"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a51b48f42d9358460b78725283f04bddaf44a9358197b889657deba38f329db"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b295729485b06c1a0683af02a9e42d2caa9db04a373dc38a6a58cdd1e8abddf1"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee803480535c44e7f5ad00788526da7d85525cfefaf8acf8ab9a310000be4b03"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d59d125ffbd6d552765510e3f31ed75ebac2c7470c7274195b9161a32350284"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cda06946eac330cbe6598f77bb54e690b4ca93f593dee1568ad22b04f347c15"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07afec21bbbbf8a5cc3651aa96b980afe2526e7f048fdfb7f1014d84acc8b6d8"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6b40e8d38afe634559e398cc32b1472f376a4099c75fe6299ae607e404c033b2"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b8dcd239c743aa2f9c22ce674a145e0a25cb1566c495928440a181ca1ccf6719"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:84450ba661fb96e9fd67629b93d2941c871ca86fc38d835d19d4225ff946a631"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44aeb140295a2f0659e113b31cfe92c9061622cadbc9e2a2f7b8ef6b1e29ef4b"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1db4e7fefefd0f548d73e2e2e041f9df5c59e178b4c72fbac4cc6f535cfb1565"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win32.whl", hash = "sha256:5726cf76c982532c1863fb64d8c6dd0e4c90b6ece9feb06c9f202417a31f7dd7"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b197e7094f232959f8f20541ead1d9862ac5ebea1d58e9849c1bf979255dfac9"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc"}, + {file = 
"charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dbe03226baf438ac4fda9e2d0715022fd579cb641c4cf639fa40d53b2fe6f3e2"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd9a8bd8900e65504a305bf8ae6fa9fbc66de94178c420791d0293702fce2df7"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8831399554b92b72af5932cdbbd4ddc55c55f631bb13ff8fe4e6536a06c5c51"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a14969b8691f7998e74663b77b4c36c0337cb1df552da83d5c9004a93afdb574"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaf7c1524c0542ee2fc82cc8ec337f7a9f7edee2532421ab200d2b920fc97cf"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425c5f215d0eecee9a56cdb703203dda90423247421bf0d67125add85d0c4455"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:d5b054862739d276e09928de37c79ddeec42a6e1bfc55863be96a36ba22926f6"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:f3e73a4255342d4eb26ef6df01e3962e73aa29baa3124a8e824c5d3364a65748"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:2f6c34da58ea9c1a9515621f4d9ac379871a8f21168ba1b5e09d74250de5ad62"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f09cb5a7bbe1ecae6e87901a2eb23e0256bb524a79ccc53eb0b7629fbe7677c4"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:0099d79bdfcf5c1f0c2c72f91516702ebf8b0b8ddd8905f97a8aecf49712c621"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win32.whl", hash = "sha256:9c98230f5042f4945f957d006edccc2af1e03ed5e37ce7c373f00a5a4daa6149"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62f60aebecfc7f4b82e3f639a7d1433a20ec32824db2199a11ad4f5e146ef5ee"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = 
"sha256:af73657b7a68211996527dbfeffbb0864e043d270580c5aef06dc4b659a4b578"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cab5d0b79d987c67f3b9e9c53f54a61360422a5a0bc075f43cab5621d530c3b6"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9289fd5dddcf57bab41d044f1756550f9e7cf0c8e373b8cdf0ce8773dc4bd417"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b493a043635eb376e50eedf7818f2f322eabbaa974e948bd8bdd29eb7ef2a51"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fa2566ca27d67c86569e8c85297aaf413ffab85a8960500f12ea34ff98e4c41"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8e538f46104c815be19c975572d74afb53f29650ea2025bbfaef359d2de2f7f"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fd30dc99682dc2c603c2b315bded2799019cea829f8bf57dc6b61efde6611c8"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2006769bd1640bdf4d5641c69a3d63b71b81445473cac5ded39740a226fa88ab"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc15e99b2d8a656f8e666854404f1ba54765871104e50c8e9813af8a7db07f12"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ab2e5bef076f5a235c3774b4f4028a680432cded7cad37bba0fd90d64b187d19"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:4ec9dd88a5b71abfc74e9df5ebe7921c35cbb3b641181a531ca65cdb5e8e4dea"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:43193c5cda5d612f247172016c4bb71251c784d7a4d9314677186a838ad34858"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:aa693779a8b50cd97570e5a0f343538a8dbd3e496fa5dcb87e29406ad0299654"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win32.whl", hash = "sha256:7706f5850360ac01d80c89bcef1640683cc12ed87f42579dab6c5d3ed6888613"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:c3e446d253bd88f6377260d07c895816ebf33ffffd56c1c792b13bff9c3e1ade"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:980b4f289d1d90ca5efcf07958d3eb38ed9c0b7676bf2831a54d4f66f9c27dfa"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f28f891ccd15c514a0981f3b9db9aa23d62fe1a99997512b0491d2ed323d229a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8aacce6e2e1edcb6ac625fb0f8c3a9570ccc7bfba1f63419b3769ccf6a00ed0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7af3717683bea4c87acd8c0d3d5b44d56120b26fd3f8a692bdd2d5260c620a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff2ed8194587faf56555927b3aa10e6fb69d931e33953943bc4f837dfee2242"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e91f541a85298cf35433bf66f3fab2a4a2cff05c127eeca4af174f6d497f0d4b"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309a7de0a0ff3040acaebb35ec45d18db4b28232f21998851cfa709eeff49d62"}, + 
{file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:285e96d9d53422efc0d7a17c60e59f37fbf3dfa942073f666db4ac71e8d726d0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5d447056e2ca60382d460a604b6302d8db69476fd2015c81e7c35417cfabe4cd"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:20587d20f557fe189b7947d8e7ec5afa110ccf72a3128d61a2a387c3313f46be"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:130272c698667a982a5d0e626851ceff662565379baf0ff2cc58067b81d4f11d"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ab22fbd9765e6954bc0bcff24c25ff71dcbfdb185fcdaca49e81bac68fe724d3"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7782afc9b6b42200f7362858f9e73b1f8316afb276d316336c0ec3bd73312742"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win32.whl", hash = "sha256:2de62e8801ddfff069cd5c504ce3bc9672b23266597d4e4f50eda28846c322f2"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:95c3c157765b031331dd4db3c775e58deaee050a3042fcad72cbc4189d7c8dca"}, + {file = "charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079"}, + {file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"}, ] [[package]] name = "cloudpickle" -version = "3.0.0" +version = "3.1.0" description = "Pickler class to extend the standard pickle.Pickler functionality" optional = false python-versions = ">=3.8" files = [ - {file = "cloudpickle-3.0.0-py3-none-any.whl", hash = "sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7"}, - {file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"}, + {file = "cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e"}, + {file = "cloudpickle-3.1.0.tar.gz", hash = "sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b"}, ] [[package]] @@ -332,66 +362,87 @@ test = ["pytest"] [[package]] name = "contourpy" -version = "1.2.1" +version = "1.3.0" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false python-versions = ">=3.9" files = [ - {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, - {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", 
hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"}, - {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"}, - {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"}, - {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"}, - {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"}, - {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"}, - {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"}, - {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"}, - {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = 
"sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083"}, - {file = "contourpy-1.2.1-cp39-cp39-win32.whl", hash = "sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba"}, - {file = "contourpy-1.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"}, - {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"}, + {file = "contourpy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7"}, + {file = "contourpy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41"}, + {file = "contourpy-1.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d"}, + {file = "contourpy-1.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223"}, + {file = "contourpy-1.3.0-cp310-cp310-win32.whl", hash = 
"sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f"}, + {file = "contourpy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb"}, + {file = "contourpy-1.3.0-cp311-cp311-win32.whl", hash = "sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c"}, + {file = "contourpy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35"}, + {file = "contourpy-1.3.0-cp312-cp312-win32.whl", hash = "sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb"}, + {file = "contourpy-1.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8"}, + {file = "contourpy-1.3.0-cp313-cp313-win32.whl", hash = "sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294"}, + {file = "contourpy-1.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800"}, + {file = "contourpy-1.3.0-cp39-cp39-win32.whl", hash = "sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5"}, + {file = "contourpy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb"}, + {file = "contourpy-1.3.0.tar.gz", hash = "sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4"}, ] [package.dependencies] -numpy = ">=1.20" +numpy = ">=1.23" [package.extras] bokeh = ["bokeh", "selenium"] docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] -mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pillow"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.11.1)", "types-Pillow"] test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] -test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] +test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] [[package]] name = "cycler" @@ -410,33 +461,37 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"] [[package]] name = "debugpy" -version = "1.8.2" +version = "1.8.7" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" files = [ - {file = "debugpy-1.8.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7ee2e1afbf44b138c005e4380097d92532e1001580853a7cb40ed84e0ef1c3d2"}, - {file = "debugpy-1.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f8c3f7c53130a070f0fc845a0f2cee8ed88d220d6b04595897b66605df1edd6"}, - {file = "debugpy-1.8.2-cp310-cp310-win32.whl", hash = "sha256:f179af1e1bd4c88b0b9f0fa153569b24f6b6f3de33f94703336363ae62f4bf47"}, - {file = "debugpy-1.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:0600faef1d0b8d0e85c816b8bb0cb90ed94fc611f308d5fde28cb8b3d2ff0fe3"}, - {file = "debugpy-1.8.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:8a13417ccd5978a642e91fb79b871baded925d4fadd4dfafec1928196292aa0a"}, - {file = "debugpy-1.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acdf39855f65c48ac9667b2801234fc64d46778021efac2de7e50907ab90c634"}, - {file = "debugpy-1.8.2-cp311-cp311-win32.whl", hash = 
"sha256:2cbd4d9a2fc5e7f583ff9bf11f3b7d78dfda8401e8bb6856ad1ed190be4281ad"}, - {file = "debugpy-1.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:d3408fddd76414034c02880e891ea434e9a9cf3a69842098ef92f6e809d09afa"}, - {file = "debugpy-1.8.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:5d3ccd39e4021f2eb86b8d748a96c766058b39443c1f18b2dc52c10ac2757835"}, - {file = "debugpy-1.8.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62658aefe289598680193ff655ff3940e2a601765259b123dc7f89c0239b8cd3"}, - {file = "debugpy-1.8.2-cp312-cp312-win32.whl", hash = "sha256:bd11fe35d6fd3431f1546d94121322c0ac572e1bfb1f6be0e9b8655fb4ea941e"}, - {file = "debugpy-1.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:15bc2f4b0f5e99bf86c162c91a74c0631dbd9cef3c6a1d1329c946586255e859"}, - {file = "debugpy-1.8.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:5a019d4574afedc6ead1daa22736c530712465c0c4cd44f820d803d937531b2d"}, - {file = "debugpy-1.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40f062d6877d2e45b112c0bbade9a17aac507445fd638922b1a5434df34aed02"}, - {file = "debugpy-1.8.2-cp38-cp38-win32.whl", hash = "sha256:c78ba1680f1015c0ca7115671fe347b28b446081dada3fedf54138f44e4ba031"}, - {file = "debugpy-1.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cf327316ae0c0e7dd81eb92d24ba8b5e88bb4d1b585b5c0d32929274a66a5210"}, - {file = "debugpy-1.8.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:1523bc551e28e15147815d1397afc150ac99dbd3a8e64641d53425dba57b0ff9"}, - {file = "debugpy-1.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e24ccb0cd6f8bfaec68d577cb49e9c680621c336f347479b3fce060ba7c09ec1"}, - {file = "debugpy-1.8.2-cp39-cp39-win32.whl", hash = "sha256:7f8d57a98c5a486c5c7824bc0b9f2f11189d08d73635c326abef268f83950326"}, - {file = "debugpy-1.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:16c8dcab02617b75697a0a925a62943e26a0330da076e2a10437edd9f0bf3755"}, - {file = "debugpy-1.8.2-py2.py3-none-any.whl", hash = "sha256:16e16df3a98a35c63c3ab1e4d19be4cbc7fdda92d9ddc059294f18910928e0ca"}, - {file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"}, + {file = "debugpy-1.8.7-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:95fe04a573b8b22896c404365e03f4eda0ce0ba135b7667a1e57bd079793b96b"}, + {file = "debugpy-1.8.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:628a11f4b295ffb4141d8242a9bb52b77ad4a63a2ad19217a93be0f77f2c28c9"}, + {file = "debugpy-1.8.7-cp310-cp310-win32.whl", hash = "sha256:85ce9c1d0eebf622f86cc68618ad64bf66c4fc3197d88f74bb695a416837dd55"}, + {file = "debugpy-1.8.7-cp310-cp310-win_amd64.whl", hash = "sha256:29e1571c276d643757ea126d014abda081eb5ea4c851628b33de0c2b6245b037"}, + {file = "debugpy-1.8.7-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:caf528ff9e7308b74a1749c183d6808ffbedbb9fb6af78b033c28974d9b8831f"}, + {file = "debugpy-1.8.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cba1d078cf2e1e0b8402e6bda528bf8fda7ccd158c3dba6c012b7897747c41a0"}, + {file = "debugpy-1.8.7-cp311-cp311-win32.whl", hash = "sha256:171899588bcd412151e593bd40d9907133a7622cd6ecdbdb75f89d1551df13c2"}, + {file = "debugpy-1.8.7-cp311-cp311-win_amd64.whl", hash = "sha256:6e1c4ffb0c79f66e89dfd97944f335880f0d50ad29525dc792785384923e2211"}, + {file = "debugpy-1.8.7-cp312-cp312-macosx_14_0_universal2.whl", hash = 
"sha256:4d27d842311353ede0ad572600c62e4bcd74f458ee01ab0dd3a1a4457e7e3706"}, + {file = "debugpy-1.8.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:703c1fd62ae0356e194f3e7b7a92acd931f71fe81c4b3be2c17a7b8a4b546ec2"}, + {file = "debugpy-1.8.7-cp312-cp312-win32.whl", hash = "sha256:2f729228430ef191c1e4df72a75ac94e9bf77413ce5f3f900018712c9da0aaca"}, + {file = "debugpy-1.8.7-cp312-cp312-win_amd64.whl", hash = "sha256:45c30aaefb3e1975e8a0258f5bbd26cd40cde9bfe71e9e5a7ac82e79bad64e39"}, + {file = "debugpy-1.8.7-cp313-cp313-macosx_14_0_universal2.whl", hash = "sha256:d050a1ec7e925f514f0f6594a1e522580317da31fbda1af71d1530d6ea1f2b40"}, + {file = "debugpy-1.8.7-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2f4349a28e3228a42958f8ddaa6333d6f8282d5edaea456070e48609c5983b7"}, + {file = "debugpy-1.8.7-cp313-cp313-win32.whl", hash = "sha256:11ad72eb9ddb436afb8337891a986302e14944f0f755fd94e90d0d71e9100bba"}, + {file = "debugpy-1.8.7-cp313-cp313-win_amd64.whl", hash = "sha256:2efb84d6789352d7950b03d7f866e6d180284bc02c7e12cb37b489b7083d81aa"}, + {file = "debugpy-1.8.7-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:4b908291a1d051ef3331484de8e959ef3e66f12b5e610c203b5b75d2725613a7"}, + {file = "debugpy-1.8.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da8df5b89a41f1fd31503b179d0a84a5fdb752dddd5b5388dbd1ae23cda31ce9"}, + {file = "debugpy-1.8.7-cp38-cp38-win32.whl", hash = "sha256:b12515e04720e9e5c2216cc7086d0edadf25d7ab7e3564ec8b4521cf111b4f8c"}, + {file = "debugpy-1.8.7-cp38-cp38-win_amd64.whl", hash = "sha256:93176e7672551cb5281577cdb62c63aadc87ec036f0c6a486f0ded337c504596"}, + {file = "debugpy-1.8.7-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:90d93e4f2db442f8222dec5ec55ccfc8005821028982f1968ebf551d32b28907"}, + {file = "debugpy-1.8.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6db2a370e2700557a976eaadb16243ec9c91bd46f1b3bb15376d7aaa7632c81"}, + {file = "debugpy-1.8.7-cp39-cp39-win32.whl", hash = "sha256:a6cf2510740e0c0b4a40330640e4b454f928c7b99b0c9dbf48b11efba08a8cda"}, + {file = "debugpy-1.8.7-cp39-cp39-win_amd64.whl", hash = "sha256:6a9d9d6d31846d8e34f52987ee0f1a904c7baa4912bf4843ab39dadf9b8f3e0d"}, + {file = "debugpy-1.8.7-py2.py3-none-any.whl", hash = "sha256:57b00de1c8d2c84a61b90880f7e5b6deaf4c312ecbde3a0e8912f2a56c4ac9ae"}, + {file = "debugpy-1.8.7.zip", hash = "sha256:18b8f731ed3e2e1df8e9cdaa23fb1fc9c24e570cd0081625308ec51c82efe42e"}, ] [[package]] @@ -491,13 +546,6 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, - {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, - {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, - {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, - {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, - {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, - {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, - {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -522,31 +570,15 @@ files = [ {file = "dm_tree-0.1.8-cp39-cp39-win_amd64.whl", hash = "sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368"}, ] -[[package]] -name = "equinox" -version = "0.11.4" -description = "Elegant easy-to-use neural networks in JAX." -optional = false -python-versions = "~=3.9" -files = [ - {file = "equinox-0.11.4-py3-none-any.whl", hash = "sha256:a9527b1fe0c4778c3c959d9091b1eea28c3fdcca01790a47e71b47df94889315"}, - {file = "equinox-0.11.4.tar.gz", hash = "sha256:0033d9731083f402a855b12a0777a80aa8507651f7aa86d9f0f9503bcddfd320"}, -] - -[package.dependencies] -jax = ">=0.4.13" -jaxtyping = ">=0.2.20" -typing-extensions = ">=4.5.0" - [[package]] name = "executing" -version = "2.0.1" +version = "2.1.0" description = "Get the currently executing AST node of a frame, and other information" optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, - {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, + {file = "executing-2.1.0-py2.py3-none-any.whl", hash = "sha256:8d63781349375b5ebccc3142f4b30350c0cd9c79f921cde38be2be4637e98eaf"}, + {file = "executing-2.1.0.tar.gz", hash = "sha256:8ea27ddd260da8150fa5a708269c4a10e76161e2496ec3e587da9e3c0fe4b9ab"}, ] [package.extras] @@ -568,53 +600,59 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "fonttools" -version = "4.53.1" +version = "4.54.1" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" files = [ - {file = "fonttools-4.53.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397"}, - {file = "fonttools-4.53.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8bf06b94694251861ba7fdeea15c8ec0967f84c3d4143ae9daf42bbc7717fe3"}, - {file = "fonttools-4.53.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b96cd370a61f4d083c9c0053bf634279b094308d52fdc2dd9a22d8372fdd590d"}, - {file = "fonttools-4.53.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1c7c5aa18dd3b17995898b4a9b5929d69ef6ae2af5b96d585ff4005033d82f0"}, - 
{file = "fonttools-4.53.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e013aae589c1c12505da64a7d8d023e584987e51e62006e1bb30d72f26522c41"}, - {file = "fonttools-4.53.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9efd176f874cb6402e607e4cc9b4a9cd584d82fc34a4b0c811970b32ba62501f"}, - {file = "fonttools-4.53.1-cp310-cp310-win32.whl", hash = "sha256:c8696544c964500aa9439efb6761947393b70b17ef4e82d73277413f291260a4"}, - {file = "fonttools-4.53.1-cp310-cp310-win_amd64.whl", hash = "sha256:8959a59de5af6d2bec27489e98ef25a397cfa1774b375d5787509c06659b3671"}, - {file = "fonttools-4.53.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:da33440b1413bad53a8674393c5d29ce64d8c1a15ef8a77c642ffd900d07bfe1"}, - {file = "fonttools-4.53.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5ff7e5e9bad94e3a70c5cd2fa27f20b9bb9385e10cddab567b85ce5d306ea923"}, - {file = "fonttools-4.53.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6e7170d675d12eac12ad1a981d90f118c06cf680b42a2d74c6c931e54b50719"}, - {file = "fonttools-4.53.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee32ea8765e859670c4447b0817514ca79054463b6b79784b08a8df3a4d78e3"}, - {file = "fonttools-4.53.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6e08f572625a1ee682115223eabebc4c6a2035a6917eac6f60350aba297ccadb"}, - {file = "fonttools-4.53.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b21952c092ffd827504de7e66b62aba26fdb5f9d1e435c52477e6486e9d128b2"}, - {file = "fonttools-4.53.1-cp311-cp311-win32.whl", hash = "sha256:9dfdae43b7996af46ff9da520998a32b105c7f098aeea06b2226b30e74fbba88"}, - {file = "fonttools-4.53.1-cp311-cp311-win_amd64.whl", hash = "sha256:d4d0096cb1ac7a77b3b41cd78c9b6bc4a400550e21dc7a92f2b5ab53ed74eb02"}, - {file = "fonttools-4.53.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d92d3c2a1b39631a6131c2fa25b5406855f97969b068e7e08413325bc0afba58"}, - {file = "fonttools-4.53.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3b3c8ebafbee8d9002bd8f1195d09ed2bd9ff134ddec37ee8f6a6375e6a4f0e8"}, - {file = "fonttools-4.53.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32f029c095ad66c425b0ee85553d0dc326d45d7059dbc227330fc29b43e8ba60"}, - {file = "fonttools-4.53.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10f5e6c3510b79ea27bb1ebfcc67048cde9ec67afa87c7dd7efa5c700491ac7f"}, - {file = "fonttools-4.53.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f677ce218976496a587ab17140da141557beb91d2a5c1a14212c994093f2eae2"}, - {file = "fonttools-4.53.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9e6ceba2a01b448e36754983d376064730690401da1dd104ddb543519470a15f"}, - {file = "fonttools-4.53.1-cp312-cp312-win32.whl", hash = "sha256:791b31ebbc05197d7aa096bbc7bd76d591f05905d2fd908bf103af4488e60670"}, - {file = "fonttools-4.53.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ed170b5e17da0264b9f6fae86073be3db15fa1bd74061c8331022bca6d09bab"}, - {file = "fonttools-4.53.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c818c058404eb2bba05e728d38049438afd649e3c409796723dfc17cd3f08749"}, - {file = "fonttools-4.53.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:651390c3b26b0c7d1f4407cad281ee7a5a85a31a110cbac5269de72a51551ba2"}, - {file = "fonttools-4.53.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e54f1bba2f655924c1138bbc7fa91abd61f45c68bd65ab5ed985942712864bbb"}, - {file = 
"fonttools-4.53.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9cd19cf4fe0595ebdd1d4915882b9440c3a6d30b008f3cc7587c1da7b95be5f"}, - {file = "fonttools-4.53.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2af40ae9cdcb204fc1d8f26b190aa16534fcd4f0df756268df674a270eab575d"}, - {file = "fonttools-4.53.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:35250099b0cfb32d799fb5d6c651220a642fe2e3c7d2560490e6f1d3f9ae9169"}, - {file = "fonttools-4.53.1-cp38-cp38-win32.whl", hash = "sha256:f08df60fbd8d289152079a65da4e66a447efc1d5d5a4d3f299cdd39e3b2e4a7d"}, - {file = "fonttools-4.53.1-cp38-cp38-win_amd64.whl", hash = "sha256:7b6b35e52ddc8fb0db562133894e6ef5b4e54e1283dff606fda3eed938c36fc8"}, - {file = "fonttools-4.53.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75a157d8d26c06e64ace9df037ee93a4938a4606a38cb7ffaf6635e60e253b7a"}, - {file = "fonttools-4.53.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4824c198f714ab5559c5be10fd1adf876712aa7989882a4ec887bf1ef3e00e31"}, - {file = "fonttools-4.53.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:becc5d7cb89c7b7afa8321b6bb3dbee0eec2b57855c90b3e9bf5fb816671fa7c"}, - {file = "fonttools-4.53.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84ec3fb43befb54be490147b4a922b5314e16372a643004f182babee9f9c3407"}, - {file = "fonttools-4.53.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:73379d3ffdeecb376640cd8ed03e9d2d0e568c9d1a4e9b16504a834ebadc2dfb"}, - {file = "fonttools-4.53.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122"}, - {file = "fonttools-4.53.1-cp39-cp39-win32.whl", hash = "sha256:aae7bd54187e8bf7fd69f8ab87b2885253d3575163ad4d669a262fe97f0136cb"}, - {file = "fonttools-4.53.1-cp39-cp39-win_amd64.whl", hash = "sha256:e5b708073ea3d684235648786f5f6153a48dc8762cdfe5563c57e80787c29fbb"}, - {file = "fonttools-4.53.1-py3-none-any.whl", hash = "sha256:f1f8758a2ad110bd6432203a344269f445a2907dc24ef6bccfd0ac4e14e0d71d"}, - {file = "fonttools-4.53.1.tar.gz", hash = "sha256:e128778a8e9bc11159ce5447f76766cefbd876f44bd79aff030287254e4752c4"}, + {file = "fonttools-4.54.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ed7ee041ff7b34cc62f07545e55e1468808691dddfd315d51dd82a6b37ddef2"}, + {file = "fonttools-4.54.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41bb0b250c8132b2fcac148e2e9198e62ff06f3cc472065dff839327945c5882"}, + {file = "fonttools-4.54.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7965af9b67dd546e52afcf2e38641b5be956d68c425bef2158e95af11d229f10"}, + {file = "fonttools-4.54.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:278913a168f90d53378c20c23b80f4e599dca62fbffae4cc620c8eed476b723e"}, + {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0e88e3018ac809b9662615072dcd6b84dca4c2d991c6d66e1970a112503bba7e"}, + {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4aa4817f0031206e637d1e685251ac61be64d1adef111060df84fdcbc6ab6c44"}, + {file = "fonttools-4.54.1-cp310-cp310-win32.whl", hash = "sha256:7e3b7d44e18c085fd8c16dcc6f1ad6c61b71ff463636fcb13df7b1b818bd0c02"}, + {file = "fonttools-4.54.1-cp310-cp310-win_amd64.whl", hash = "sha256:dd9cc95b8d6e27d01e1e1f1fae8559ef3c02c76317da650a19047f249acd519d"}, + {file = "fonttools-4.54.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5419771b64248484299fa77689d4f3aeed643ea6630b2ea750eeab219588ba20"}, + 
{file = "fonttools-4.54.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:301540e89cf4ce89d462eb23a89464fef50915255ece765d10eee8b2bf9d75b2"}, + {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76ae5091547e74e7efecc3cbf8e75200bc92daaeb88e5433c5e3e95ea8ce5aa7"}, + {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82834962b3d7c5ca98cb56001c33cf20eb110ecf442725dc5fdf36d16ed1ab07"}, + {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d26732ae002cc3d2ecab04897bb02ae3f11f06dd7575d1df46acd2f7c012a8d8"}, + {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:58974b4987b2a71ee08ade1e7f47f410c367cdfc5a94fabd599c88165f56213a"}, + {file = "fonttools-4.54.1-cp311-cp311-win32.whl", hash = "sha256:ab774fa225238986218a463f3fe151e04d8c25d7de09df7f0f5fce27b1243dbc"}, + {file = "fonttools-4.54.1-cp311-cp311-win_amd64.whl", hash = "sha256:07e005dc454eee1cc60105d6a29593459a06321c21897f769a281ff2d08939f6"}, + {file = "fonttools-4.54.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:54471032f7cb5fca694b5f1a0aaeba4af6e10ae989df408e0216f7fd6cdc405d"}, + {file = "fonttools-4.54.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fa92cb248e573daab8d032919623cc309c005086d743afb014c836636166f08"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a911591200114969befa7f2cb74ac148bce5a91df5645443371aba6d222e263"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93d458c8a6a354dc8b48fc78d66d2a8a90b941f7fec30e94c7ad9982b1fa6bab"}, + {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5eb2474a7c5be8a5331146758debb2669bf5635c021aee00fd7c353558fc659d"}, + {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c9c563351ddc230725c4bdf7d9e1e92cbe6ae8553942bd1fb2b2ff0884e8b714"}, + {file = "fonttools-4.54.1-cp312-cp312-win32.whl", hash = "sha256:fdb062893fd6d47b527d39346e0c5578b7957dcea6d6a3b6794569370013d9ac"}, + {file = "fonttools-4.54.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4564cf40cebcb53f3dc825e85910bf54835e8a8b6880d59e5159f0f325e637e"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6e37561751b017cf5c40fce0d90fd9e8274716de327ec4ffb0df957160be3bff"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:357cacb988a18aace66e5e55fe1247f2ee706e01debc4b1a20d77400354cddeb"}, + {file = "fonttools-4.54.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e953cc0bddc2beaf3a3c3b5dd9ab7554677da72dfaf46951e193c9653e515a"}, + {file = "fonttools-4.54.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:58d29b9a294573d8319f16f2f79e42428ba9b6480442fa1836e4eb89c4d9d61c"}, + {file = "fonttools-4.54.1-cp313-cp313-win32.whl", hash = "sha256:9ef1b167e22709b46bf8168368b7b5d3efeaaa746c6d39661c1b4405b6352e58"}, + {file = "fonttools-4.54.1-cp313-cp313-win_amd64.whl", hash = "sha256:262705b1663f18c04250bd1242b0515d3bbae177bee7752be67c979b7d47f43d"}, + {file = "fonttools-4.54.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ed2f80ca07025551636c555dec2b755dd005e2ea8fbeb99fc5cdff319b70b23b"}, + {file = "fonttools-4.54.1-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:9dc080e5a1c3b2656caff2ac2633d009b3a9ff7b5e93d0452f40cd76d3da3b3c"}, + {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d152d1be65652fc65e695e5619e0aa0982295a95a9b29b52b85775243c06556"}, + {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8583e563df41fdecef31b793b4dd3af8a9caa03397be648945ad32717a92885b"}, + {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0d1d353ef198c422515a3e974a1e8d5b304cd54a4c2eebcae708e37cd9eeffb1"}, + {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fda582236fee135d4daeca056c8c88ec5f6f6d88a004a79b84a02547c8f57386"}, + {file = "fonttools-4.54.1-cp38-cp38-win32.whl", hash = "sha256:e7d82b9e56716ed32574ee106cabca80992e6bbdcf25a88d97d21f73a0aae664"}, + {file = "fonttools-4.54.1-cp38-cp38-win_amd64.whl", hash = "sha256:ada215fd079e23e060157aab12eba0d66704316547f334eee9ff26f8c0d7b8ab"}, + {file = "fonttools-4.54.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f5b8a096e649768c2f4233f947cf9737f8dbf8728b90e2771e2497c6e3d21d13"}, + {file = "fonttools-4.54.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e10d2e0a12e18f4e2dd031e1bf7c3d7017be5c8dbe524d07706179f355c5dac"}, + {file = "fonttools-4.54.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31c32d7d4b0958600eac75eaf524b7b7cb68d3a8c196635252b7a2c30d80e986"}, + {file = "fonttools-4.54.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c39287f5c8f4a0c5a55daf9eaf9ccd223ea59eed3f6d467133cc727d7b943a55"}, + {file = "fonttools-4.54.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a7a310c6e0471602fe3bf8efaf193d396ea561486aeaa7adc1f132e02d30c4b9"}, + {file = "fonttools-4.54.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d3b659d1029946f4ff9b6183984578041b520ce0f8fb7078bb37ec7445806b33"}, + {file = "fonttools-4.54.1-cp39-cp39-win32.whl", hash = "sha256:e96bc94c8cda58f577277d4a71f51c8e2129b8b36fd05adece6320dd3d57de8a"}, + {file = "fonttools-4.54.1-cp39-cp39-win_amd64.whl", hash = "sha256:e8a4b261c1ef91e7188a30571be6ad98d1c6d9fa2427244c545e2fa0a2494dd7"}, + {file = "fonttools-4.54.1-py3-none-any.whl", hash = "sha256:37cddd62d83dc4f72f7c3f3c2bcf2697e89a30efb152079896544a93907733bd"}, + {file = "fonttools-4.54.1.tar.gz", hash = "sha256:957f669d4922f92c171ba01bef7f29410668db09f6c02111e22b2bce446f3285"}, ] [package.extras] @@ -643,27 +681,28 @@ files = [ [[package]] name = "genjax" -version = "0.5.0.post30.dev0+2df4f579" +version = "0.7.0.post4.dev0+eacb241e" description = "Probabilistic programming with Gen, built on top of JAX." 
optional = false python-versions = ">=3.10,<3.13" files = [ - {file = "genjax-0.5.0.post30.dev0+2df4f579-py3-none-any.whl", hash = "sha256:d34589d12278343d0066263789c9584077fc79d50d4096781e62cb205acfe757"}, - {file = "genjax-0.5.0.post30.dev0+2df4f579.tar.gz", hash = "sha256:495f5aae0be9ef3a68b6ad483a6472a00cf6bf1f9d366ed747d9313d489a9055"}, + {file = "genjax-0.7.0.post4.dev0+eacb241e-py3-none-any.whl", hash = "sha256:c6374155c6b772e65919115613264e372606844438e75a6ef1d3db0350d5c79f"}, + {file = "genjax-0.7.0.post4.dev0+eacb241e.tar.gz", hash = "sha256:d738d029a7a5a40390236ab35b8d1e8745c7ba02cc5fc1d2723382c1f8b0cb01"}, ] [package.dependencies] beartype = ">=0.18.5,<0.19.0" deprecated = ">=1.2.14,<2.0.0" -jax = ">=0.4.28,<0.5.0" -jaxtyping = ">=0.2.28,<0.3.0" -penzai = ">=0.1.1,<0.2.0" +jax = ">=0.4.24,<0.5.0" +jaxtyping = ">=0.2.24,<0.3.0" +numpy = ">=1.22,<2.0.0" +penzai = ">=0.2.2,<0.3.0" tensorflow-probability = ">=0.23.0,<0.24.0" +treescope = ">=0.1.5,<0.2.0" [package.extras] -all = ["genstudio (==2024.06.07.0931.49)", "msgpack (>=1.0.8,<2.0.0)"] -genstudio = ["genstudio (==2024.06.07.0931.49)"] -msgpack = ["msgpack (>=1.0.8,<2.0.0)"] +all = ["genstudio (==2024.09.003)"] +genstudio = ["genstudio (==2024.09.003)"] [package.source] type = "legacy" @@ -672,23 +711,19 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.7.29.2335" +version = "2024.10.1" description = "" optional = false python-versions = ">=3.10,<3.13" files = [ - {file = "genstudio-2024.7.29.2335-py3-none-any.whl", hash = "sha256:720d76d14af58e128a68f0bddf675e4604741a07c3ca02e658636e5b8a8d54e0"}, - {file = "genstudio-2024.7.29.2335.tar.gz", hash = "sha256:c7ddb15e94655cec26c84d1c62505513a695d2b154edbee2434101c0f21f868b"}, + {file = "genstudio-2024.10.1-py3-none-any.whl", hash = "sha256:c95cffb1e3d9ca8d9424a535ba227c3e8ecbdc95673e907f9da78e89d6c77b3c"}, + {file = "genstudio-2024.10.1.tar.gz", hash = "sha256:279d461dbec2c6d58f27c99216d9199f40f233a7add506f1c909cf48e9aff8e7"}, ] [package.dependencies] anywidget = ">=0.9.10,<0.10.0" html2image = ">=2.0.4.3,<3.0.0.0" -jax = ">=0.4.25,<0.5.0" -jaxlib = ">=0.4.28,<0.5.0" -numpy = ">=1.26.4,<2.0.0" orjson = ">=3.10.6,<4.0.0" -penzai = ">=0.1.1,<0.2.0" pillow = ">=10.4.0,<11.0.0" traitlets = ">=5.14.3,<6.0.0" @@ -699,30 +734,33 @@ reference = "gcp" [[package]] name = "html2image" -version = "2.0.4.3" +version = "2.0.5" description = "Package acting as a wrapper around the headless mode of existing web browsers to generate images from URLs and from HTML+CSS strings or files." 
optional = false -python-versions = ">=3.6,<4.0" +python-versions = "<4.0,>=3.6" files = [ - {file = "html2image-2.0.4.3-py3-none-any.whl", hash = "sha256:e39bc1be8cb39bd36a1b9412d22f5db88d56e762f9ad3461124fa05fa7982945"}, - {file = "html2image-2.0.4.3.tar.gz", hash = "sha256:878e69122eabf8263415784888c4366f04a8b301516fc5d13b9e0acf8db591e7"}, + {file = "html2image-2.0.5-py3-none-any.whl", hash = "sha256:71593b6b1e100a585201833656cc7446ce1d2d5c7ec101b507c5c2b71bca6ecc"}, + {file = "html2image-2.0.5.tar.gz", hash = "sha256:8d3cf5ee805647d1fb21442137349b3ab0e352b4cac7e4280e2a0f841466e87c"}, ] [package.dependencies] requests = "*" -websocket-client = ">=1.0.0,<2.0.0" +websocket-client = "==1.*" [[package]] name = "idna" -version = "3.7" +version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false -python-versions = ">=3.5" +python-versions = ">=3.6" files = [ - {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, - {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, + {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, + {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, ] +[package.extras] +all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] + [[package]] name = "ipykernel" version = "6.29.5" @@ -758,13 +796,13 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio [[package]] name = "ipython" -version = "8.26.0" +version = "8.28.0" description = "IPython: Productive Interactive Computing" optional = false python-versions = ">=3.10" files = [ - {file = "ipython-8.26.0-py3-none-any.whl", hash = "sha256:e6b347c27bdf9c32ee9d31ae85defc525755a1869f14057e900675b9e8d6e6ff"}, - {file = "ipython-8.26.0.tar.gz", hash = "sha256:1cec0fbba8404af13facebe83d04436a7434c7400e59f47acf467c64abd0956c"}, + {file = "ipython-8.28.0-py3-none-any.whl", hash = "sha256:530ef1e7bb693724d3cdc37287c80b07ad9b25986c007a53aa1857272dac3f35"}, + {file = "ipython-8.28.0.tar.gz", hash = "sha256:0d0d15ca1e01faeb868ef56bc7ee5a0de5bd66885735682e8a322ae289a13d1a"}, ] [package.dependencies] @@ -795,104 +833,105 @@ test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "num [[package]] name = "ipywidgets" -version = "8.1.3" +version = "8.1.5" description = "Jupyter interactive widgets" optional = false python-versions = ">=3.7" files = [ - {file = "ipywidgets-8.1.3-py3-none-any.whl", hash = "sha256:efafd18f7a142248f7cb0ba890a68b96abd4d6e88ddbda483c9130d12667eaf2"}, - {file = "ipywidgets-8.1.3.tar.gz", hash = "sha256:f5f9eeaae082b1823ce9eac2575272952f40d748893972956dc09700a6392d9c"}, + {file = "ipywidgets-8.1.5-py3-none-any.whl", hash = "sha256:3290f526f87ae6e77655555baba4f36681c555b8bdbbff430b70e52c34c86245"}, + {file = "ipywidgets-8.1.5.tar.gz", hash = "sha256:870e43b1a35656a80c18c9503bbf2d16802db1cb487eec6fab27d683381dde17"}, ] [package.dependencies] comm = ">=0.1.3" ipython = ">=6.1.0" -jupyterlab-widgets = ">=3.0.11,<3.1.0" +jupyterlab-widgets = ">=3.0.12,<3.1.0" traitlets = ">=4.3.1" -widgetsnbextension = ">=4.0.11,<4.1.0" +widgetsnbextension = ">=4.0.12,<4.1.0" [package.extras] test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] [[package]] name = "jax" -version = "0.4.30" +version = "0.4.35" description = "Differentiate, compile, 
and transform Numpy code." optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "jax-0.4.30-py3-none-any.whl", hash = "sha256:289b30ae03b52f7f4baf6ef082a9f4e3e29c1080e22d13512c5ecf02d5f1a55b"}, - {file = "jax-0.4.30.tar.gz", hash = "sha256:94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577"}, + {file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"}, + {file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"}, ] [package.dependencies] -jaxlib = ">=0.4.27,<=0.4.30" -ml-dtypes = ">=0.2.0" +jaxlib = ">=0.4.34,<=0.4.35" +ml-dtypes = ">=0.4.0" numpy = [ {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.24", markers = "python_version < \"3.12\""}, ] opt-einsum = "*" scipy = [ {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, - {version = ">=1.9", markers = "python_version < \"3.12\""}, + {version = ">=1.10", markers = "python_version < \"3.12\""}, ] [package.extras] -ci = ["jaxlib (==0.4.29)"] -cuda = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"] -cuda12 = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"] -cuda12-local = ["jax-cuda12-plugin (==0.4.30)", "jaxlib (==0.4.30)"] -cuda12-pip = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"] -minimum-jaxlib = ["jaxlib (==0.4.27)"] -tpu = ["jaxlib (==0.4.30)", "libtpu-nightly (==0.1.dev20240617)", "requests"] +ci = ["jaxlib (==0.4.34)"] +cuda = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] +cuda12 = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] +cuda12-local = ["jax-cuda12-plugin (==0.4.34)", "jaxlib (==0.4.34)"] +cuda12-pip = ["jax-cuda12-plugin[with-cuda] (>=0.4.34,<=0.4.35)", "jaxlib (==0.4.34)"] +k8s = ["kubernetes"] +minimum-jaxlib = ["jaxlib (==0.4.34)"] +tpu = ["jaxlib (>=0.4.34,<=0.4.35)", "libtpu (==0.0.2)", "libtpu-nightly (==0.1.dev20241010+nightly.cleanup)", "requests"] [[package]] name = "jaxlib" -version = "0.4.30" +version = "0.4.35" description = "XLA library for JAX" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "jaxlib-0.4.30-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:c40856e28f300938c6824ab1a615166193d6997dec946578823f6d402ad454e5"}, - {file = "jaxlib-0.4.30-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bdfda6a3c7a2b0cc0a7131009eb279e98ca4a6f25679fabb5302dd135a5e349"}, - {file = "jaxlib-0.4.30-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:28e032c9b394ab7624d89b0d9d3bbcf4d1d71694fe8b3e09d3fe64122eda7b0c"}, - {file = "jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d83f36ef42a403bbf7c7f2da526b34ba286988e170f4df5e58b3bb735417868c"}, - {file = "jaxlib-0.4.30-cp310-cp310-win_amd64.whl", hash = "sha256:a56678b28f96b524ded6da8ef4b38e72a532356d139cfd434da804abf4234e14"}, - {file = "jaxlib-0.4.30-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:bfb5d85b69c29c3c6e8051a0ea715ac1e532d6e54494c8d9c3813dcc00deac30"}, - {file = "jaxlib-0.4.30-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:974998cd8a78550402e6c09935c1f8d850cad9cc19ccd7488bde45b6f7f99c12"}, - {file = "jaxlib-0.4.30-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e93eb0646b41ba213252b51b0b69096b9cd1d81a35ea85c9d06663b5d11efe45"}, - {file = 
"jaxlib-0.4.30-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:16b2ab18ea90d2e15941bcf45de37afc2f289a029129c88c8d7aba0404dd0043"}, - {file = "jaxlib-0.4.30-cp311-cp311-win_amd64.whl", hash = "sha256:3a2e2c11c179f8851a72249ba1ae40ae817dfaee9877d23b3b8f7c6b7a012f76"}, - {file = "jaxlib-0.4.30-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7704db5962b32a2be3cc07185433cbbcc94ed90ee50c84021a3f8a1ecfd66ee3"}, - {file = "jaxlib-0.4.30-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:57090d33477fd0f0c99dc686274882ea75c44c7d712ae42dd2460b10f896131d"}, - {file = "jaxlib-0.4.30-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:0a3850e76278038e21685975a62b622bcf3708485f13125757a0561ee4512940"}, - {file = "jaxlib-0.4.30-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:c58a8071c4e00898282118169f6a5a97eb15a79c2897858f3a732b17891c99ab"}, - {file = "jaxlib-0.4.30-cp312-cp312-win_amd64.whl", hash = "sha256:b7079a5b1ab6864a7d4f2afaa963841451186d22c90f39719a3ff85735ce3915"}, - {file = "jaxlib-0.4.30-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ea3a00005faafbe3c18b178d3b534208b3b4027b2be6230227e7b87ce399fc29"}, - {file = "jaxlib-0.4.30-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d31e01191ce8052bd611aaf16ff967d8d0ec0b63f1ea4b199020cecb248d667"}, - {file = "jaxlib-0.4.30-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:11602d5556e8baa2f16314c36518e9be4dfae0c2c256a361403fb29dc9dc79a4"}, - {file = "jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:f74a6b0e09df4b5e2ee399ebb9f0e01190e26e84ccb0a758fadb516415c07f18"}, - {file = "jaxlib-0.4.30-cp39-cp39-win_amd64.whl", hash = "sha256:54987e97a22db70f3829b437b9329e4799d653634bacc8b398554d3b90c76b2a"}, + {file = "jaxlib-0.4.35-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:907e548ad6ce53b242a55c5f36c2a2a4c37d38f6cd8c356fc550a2f18ab0e82f"}, + {file = "jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f8c499644660aefd0ae2ee31039da6d4df0f26d0ee67ba9fb316183a5304288"}, + {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5d2d8a5b89d334b875ede98d7fcee946bebef1a1b5abd118ff543bcef4ab09f5"}, + {file = "jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:91a283a72263feebe0d110d1136df96950744e47530f12df42c03f36888c971e"}, + {file = "jaxlib-0.4.35-cp310-cp310-win_amd64.whl", hash = "sha256:d210bab7e1ce0b2f2e568548b3903ea6aec349019fc1398cd2a0c069e8342e62"}, + {file = "jaxlib-0.4.35-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:7f8bfc90f68857b223b7e38a9bdf466a4f1cb405c9a4aa11698dc9ab7b35c29b"}, + {file = "jaxlib-0.4.35-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261570c94b169dc90f3af903282eeec856b52736c0944d243504ced93d19b217"}, + {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e1cee6dc291251f3fb6b0127fdd96c0439ac1ea97e01571d06910df72d6ac6e1"}, + {file = "jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc9eafba001ff8569cfa252fe7f04ba553622702b4b473b656dd0866edf6b8d4"}, + {file = "jaxlib-0.4.35-cp311-cp311-win_amd64.whl", hash = "sha256:0fd990354d5623d3a34493fcd7213493390dbf5039bea19b62e2aaee1049eda9"}, + {file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"}, + {file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"}, + {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = 
"sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"}, + {file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"}, + {file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"}, + {file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"}, + {file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"}, + {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"}, + {file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"}, + {file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"}, ] [package.dependencies] ml-dtypes = ">=0.2.0" -numpy = ">=1.22" +numpy = ">=1.24" scipy = [ {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, - {version = ">=1.9", markers = "python_version < \"3.12\""}, + {version = ">=1.10", markers = "python_version < \"3.12\""}, ] [[package]] name = "jaxtyping" -version = "0.2.33" +version = "0.2.34" description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees." optional = false python-versions = "~=3.9" files = [ - {file = "jaxtyping-0.2.33-py3-none-any.whl", hash = "sha256:918d6094c73f28d3196185ef55d1832cbcd2804d1d388f180060c4366a9e2107"}, - {file = "jaxtyping-0.2.33.tar.gz", hash = "sha256:9a9cfccae4fe05114b9fb27a5ea5440be4971a5a075bbd0526f6dd7d2730f83e"}, + {file = "jaxtyping-0.2.34-py3-none-any.whl", hash = "sha256:2f81fb6d1586e497a6ea2d28c06dcab37b108a096cbb36ea3fe4fa2e1c1f32e5"}, + {file = "jaxtyping-0.2.34.tar.gz", hash = "sha256:eed9a3458ec8726c84ea5457ebde53c964f65d2c22c0ec40d0555ae3fed5bbaf"}, ] [package.dependencies] @@ -940,13 +979,13 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- [[package]] name = "jsonschema-specifications" -version = "2023.12.1" +version = "2024.10.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, - {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, + {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"}, + {file = "jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272"}, ] [package.dependencies] @@ -954,13 +993,13 @@ referencing = ">=0.31.0" [[package]] name = "jupyter-client" -version = "8.6.2" +version = "8.6.3" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_client-8.6.2-py3-none-any.whl", hash = "sha256:50cbc5c66fd1b8f65ecb66bc490ab73217993632809b6e505687de18e9dea39f"}, - {file = "jupyter_client-8.6.2.tar.gz", hash = 
"sha256:2bda14d55ee5ba58552a8c53ae43d215ad9868853489213f37da060ced54d8df"}, + {file = "jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f"}, + {file = "jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419"}, ] [package.dependencies] @@ -996,24 +1035,24 @@ test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout" [[package]] name = "jupyterlab-widgets" -version = "3.0.11" +version = "3.0.13" description = "Jupyter interactive widgets for JupyterLab" optional = false python-versions = ">=3.7" files = [ - {file = "jupyterlab_widgets-3.0.11-py3-none-any.whl", hash = "sha256:78287fd86d20744ace330a61625024cf5521e1c012a352ddc0a3cdc2348becd0"}, - {file = "jupyterlab_widgets-3.0.11.tar.gz", hash = "sha256:dd5ac679593c969af29c9bed054c24f26842baa51352114736756bc035deee27"}, + {file = "jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54"}, + {file = "jupyterlab_widgets-3.0.13.tar.gz", hash = "sha256:a2966d385328c1942b683a8cd96b89b8dd82c8b8f81dda902bb2bc06d46f5bed"}, ] [[package]] name = "jupytext" -version = "1.16.3" +version = "1.16.4" description = "Jupyter notebooks as Markdown documents, Julia, Python or R scripts" optional = false python-versions = ">=3.8" files = [ - {file = "jupytext-1.16.3-py3-none-any.whl", hash = "sha256:870e0d7a716dcb1303df6ad1cec65e3315a20daedd808a55cb3dae2d56e4ed20"}, - {file = "jupytext-1.16.3.tar.gz", hash = "sha256:1ebac990461dd9f477ff7feec9e3003fa1acc89f3c16ba01b73f79fd76f01a98"}, + {file = "jupytext-1.16.4-py3-none-any.whl", hash = "sha256:76989d2690e65667ea6fb411d8056abe7cd0437c07bd774660b83d62acf9490a"}, + {file = "jupytext-1.16.4.tar.gz", hash = "sha256:28e33f46f2ce7a41fb9d677a4a2c95327285579b64ca104437c4b9eb1e4174e9"}, ] [package.dependencies] @@ -1035,126 +1074,136 @@ test-ui = ["calysto-bash"] [[package]] name = "kiwisolver" -version = "1.4.5" +version = "1.4.7" description = "A fast implementation of the Cassowary constraint solver" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"}, - {file = 
"kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"}, - {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"}, - {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"}, - {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"}, - {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"}, - {file = 
"kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"}, - {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"}, - {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a"}, - {file = 
"kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-win32.whl", hash = "sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-win_amd64.whl", hash = "sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250"}, - {file = "kiwisolver-1.4.5-cp38-cp38-win32.whl", hash = "sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e"}, - {file = "kiwisolver-1.4.5-cp38-cp38-win_amd64.whl", hash = "sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = 
"sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77"}, - {file = "kiwisolver-1.4.5-cp39-cp39-win32.whl", hash = "sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f"}, - {file = "kiwisolver-1.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"}, - {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win32.whl", hash = "sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_amd64.whl", hash = "sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_arm64.whl", hash = "sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54"}, + {file = 
"kiwisolver-1.4.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win32.whl", hash = "sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_amd64.whl", hash = "sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_arm64.whl", hash = "sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258"}, + {file = 
"kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win32.whl", hash = "sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win32.whl", hash = "sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07"}, 
+ {file = "kiwisolver-1.4.7-cp313-cp313-win_amd64.whl", hash = "sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_arm64.whl", hash = "sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win32.whl", hash = "sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win_amd64.whl", hash = "sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8"}, + {file = 
"kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win32.whl", hash = "sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_amd64.whl", hash = "sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_arm64.whl", hash = "sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", 
hash = "sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0"}, + {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, ] [[package]] name = "markdown" -version = "3.6" +version = "3.7" description = "Python implementation of John Gruber's Markdown." optional = false python-versions = ">=3.8" files = [ - {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, - {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, + {file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"}, + {file = "markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2"}, ] [package.extras] @@ -1187,40 +1236,51 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "matplotlib" -version = "3.9.1" +version = "3.9.2" description = "Python plotting package" optional = false python-versions = ">=3.9" files = [ - {file = "matplotlib-3.9.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:7ccd6270066feb9a9d8e0705aa027f1ff39f354c72a87efe8fa07632f30fc6bb"}, - {file = "matplotlib-3.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:591d3a88903a30a6d23b040c1e44d1afdd0d778758d07110eb7596f811f31842"}, - {file = "matplotlib-3.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd2a59ff4b83d33bca3b5ec58203cc65985367812cb8c257f3e101632be86d92"}, - {file = "matplotlib-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fc001516ffcf1a221beb51198b194d9230199d6842c540108e4ce109ac05cc0"}, - {file = "matplotlib-3.9.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:83c6a792f1465d174c86d06f3ae85a8fe36e6f5964633ae8106312ec0921fdf5"}, - {file = "matplotlib-3.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:421851f4f57350bcf0811edd754a708d2275533e84f52f6760b740766c6747a7"}, - {file = "matplotlib-3.9.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b3fce58971b465e01b5c538f9d44915640c20ec5ff31346e963c9e1cd66fa812"}, - {file = "matplotlib-3.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a973c53ad0668c53e0ed76b27d2eeeae8799836fd0d0caaa4ecc66bf4e6676c0"}, - {file = "matplotlib-3.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82cd5acf8f3ef43f7532c2f230249720f5dc5dd40ecafaf1c60ac8200d46d7eb"}, - {file = "matplotlib-3.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ab38a4f3772523179b2f772103d8030215b318fef6360cb40558f585bf3d017f"}, - {file = "matplotlib-3.9.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2315837485ca6188a4b632c5199900e28d33b481eb083663f6a44cfc8987ded3"}, - {file = "matplotlib-3.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:a0c977c5c382f6696caf0bd277ef4f936da7e2aa202ff66cad5f0ac1428ee15b"}, - {file = "matplotlib-3.9.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:565d572efea2b94f264dd86ef27919515aa6d629252a169b42ce5f570db7f37b"}, - {file = "matplotlib-3.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d397fd8ccc64af2ec0af1f0efc3bacd745ebfb9d507f3f552e8adb689ed730a"}, - {file = "matplotlib-3.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26040c8f5121cd1ad712abffcd4b5222a8aec3a0fe40bc8542c94331deb8780d"}, - {file = "matplotlib-3.9.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d12cb1837cffaac087ad6b44399d5e22b78c729de3cdae4629e252067b705e2b"}, - {file = "matplotlib-3.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0e835c6988edc3d2d08794f73c323cc62483e13df0194719ecb0723b564e0b5c"}, - {file = "matplotlib-3.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:44a21d922f78ce40435cb35b43dd7d573cf2a30138d5c4b709d19f00e3907fd7"}, - {file = "matplotlib-3.9.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0c584210c755ae921283d21d01f03a49ef46d1afa184134dd0f95b0202ee6f03"}, - {file = "matplotlib-3.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:11fed08f34fa682c2b792942f8902e7aefeed400da71f9e5816bea40a7ce28fe"}, - {file = "matplotlib-3.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0000354e32efcfd86bda75729716b92f5c2edd5b947200be9881f0a671565c33"}, - {file = "matplotlib-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4db17fea0ae3aceb8e9ac69c7e3051bae0b3d083bfec932240f9bf5d0197a049"}, - {file = "matplotlib-3.9.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:208cbce658b72bf6a8e675058fbbf59f67814057ae78165d8a2f87c45b48d0ff"}, - {file = "matplotlib-3.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:dc23f48ab630474264276be156d0d7710ac6c5a09648ccdf49fef9200d8cbe80"}, - {file = "matplotlib-3.9.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:3fda72d4d472e2ccd1be0e9ccb6bf0d2eaf635e7f8f51d737ed7e465ac020cb3"}, - {file = "matplotlib-3.9.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:84b3ba8429935a444f1fdc80ed930babbe06725bcf09fbeb5c8757a2cd74af04"}, - {file = "matplotlib-3.9.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b918770bf3e07845408716e5bbda17eadfc3fcbd9307dc67f37d6cf834bb3d98"}, - {file = "matplotlib-3.9.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f1f2e5d29e9435c97ad4c36fb6668e89aee13d48c75893e25cef064675038ac9"}, - {file = "matplotlib-3.9.1.tar.gz", hash = "sha256:de06b19b8db95dd33d0dc17c926c7c9ebed9f572074b6fac4f65068a6814d010"}, + {file = "matplotlib-3.9.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb"}, + {file = "matplotlib-3.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66"}, + {file = "matplotlib-3.9.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a"}, + {file = "matplotlib-3.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41"}, + {file = "matplotlib-3.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f"}, + {file = "matplotlib-3.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447"}, + {file = "matplotlib-3.9.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e"}, + {file = "matplotlib-3.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c"}, + {file = "matplotlib-3.9.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e"}, + {file = "matplotlib-3.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413"}, + {file = "matplotlib-3.9.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b"}, + {file = "matplotlib-3.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c"}, + {file = "matplotlib-3.9.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca"}, + {file = "matplotlib-3.9.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea"}, + {file = "matplotlib-3.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697"}, + {file = "matplotlib-3.9.2.tar.gz", hash = "sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92"}, ] [package.dependencies] @@ -1253,13 +1313,13 @@ traitlets = "*" [[package]] name = "mdit-py-plugins" -version = "0.4.1" +version = "0.4.2" description = "Collection of plugins for markdown-it-py" optional = false python-versions = ">=3.8" files = [ - {file = "mdit_py_plugins-0.4.1-py3-none-any.whl", hash = "sha256:1020dfe4e6bfc2c79fb49ae4e3f5b297f5ccd20f010187acc52af2921e27dc6a"}, - {file = "mdit_py_plugins-0.4.1.tar.gz", hash = "sha256:834b8ac23d1cd60cec703646ffd22ae97b7955a6d596eb1d304be1e251ae499c"}, + {file = "mdit_py_plugins-0.4.2-py3-none-any.whl", hash = "sha256:0c673c3f889399a33b95e88d2f0d111b4447bdfea7f237dab2d488f459835636"}, + {file = "mdit_py_plugins-0.4.2.tar.gz", hash = "sha256:5f2cd1fdb606ddf152d37ec30e46101a60512bc0e5fa1a7002c36647b09e26b5"}, ] [package.dependencies] @@ -1283,28 +1343,32 @@ files = [ [[package]] name = "ml-dtypes" -version = "0.4.0" +version = "0.5.0" description = "" optional = false python-versions = ">=3.9" files = [ - {file = "ml_dtypes-0.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81"}, - {file = "ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d"}, - {file = "ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5"}, - {file = 
"ml_dtypes-0.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6"}, - {file = "ml_dtypes-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259"}, - {file = "ml_dtypes-0.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1"}, - {file = "ml_dtypes-0.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e"}, - {file = "ml_dtypes-0.4.0.tar.gz", hash = "sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb"}, + {file = "ml_dtypes-0.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4"}, + {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8"}, + {file = "ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856"}, + {file = "ml_dtypes-0.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f"}, + {file = "ml_dtypes-0.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7"}, + {file = "ml_dtypes-0.5.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0"}, 
+ {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff"}, + {file = "ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a"}, + {file = "ml_dtypes-0.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab"}, + {file = "ml_dtypes-0.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859"}, + {file = "ml_dtypes-0.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf"}, + {file = "ml_dtypes-0.5.0.tar.gz", hash = "sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128"}, ] [package.dependencies] @@ -1395,22 +1459,15 @@ files = [ [[package]] name = "opt-einsum" -version = "3.3.0" -description = "Optimizing numpys einsum function" +version = "3.4.0" +description = "Path optimization of einsum functions." 
optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, - {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, + {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"}, + {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"}, ] -[package.dependencies] -numpy = ">=1.7" - -[package.extras] -docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] -tests = ["pytest", "pytest-cov", "pytest-pep8"] - [[package]] name = "ordered-set" version = "4.1.0" @@ -1427,62 +1484,69 @@ dev = ["black", "mypy", "pytest"] [[package]] name = "orjson" -version = "3.10.6" +version = "3.10.10" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.6-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:fb0ee33124db6eaa517d00890fc1a55c3bfe1cf78ba4a8899d71a06f2d6ff5c7"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c1c4b53b24a4c06547ce43e5fee6ec4e0d8fe2d597f4647fc033fd205707365"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eadc8fd310edb4bdbd333374f2c8fec6794bbbae99b592f448d8214a5e4050c0"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61272a5aec2b2661f4fa2b37c907ce9701e821b2c1285d5c3ab0207ebd358d38"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57985ee7e91d6214c837936dc1608f40f330a6b88bb13f5a57ce5257807da143"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:633a3b31d9d7c9f02d49c4ab4d0a86065c4a6f6adc297d63d272e043472acab5"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1c680b269d33ec444afe2bdc647c9eb73166fa47a16d9a75ee56a374f4a45f43"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f759503a97a6ace19e55461395ab0d618b5a117e8d0fbb20e70cfd68a47327f2"}, - {file = "orjson-3.10.6-cp310-none-win32.whl", hash = "sha256:95a0cce17f969fb5391762e5719575217bd10ac5a189d1979442ee54456393f3"}, - {file = "orjson-3.10.6-cp310-none-win_amd64.whl", hash = "sha256:df25d9271270ba2133cc88ee83c318372bdc0f2cd6f32e7a450809a111efc45c"}, - {file = "orjson-3.10.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b1ec490e10d2a77c345def52599311849fc063ae0e67cf4f84528073152bb2ba"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55d43d3feb8f19d07e9f01e5b9be4f28801cf7c60d0fa0d279951b18fae1932b"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3045267e98fe749408eee1593a142e02357c5c99be0802185ef2170086a863"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c27bc6a28ae95923350ab382c57113abd38f3928af3c80be6f2ba7eb8d8db0b0"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:d27456491ca79532d11e507cadca37fb8c9324a3976294f68fb1eff2dc6ced5a"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05ac3d3916023745aa3b3b388e91b9166be1ca02b7c7e41045da6d12985685f0"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1335d4ef59ab85cab66fe73fd7a4e881c298ee7f63ede918b7faa1b27cbe5212"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4bbc6d0af24c1575edc79994c20e1b29e6fb3c6a570371306db0993ecf144dc5"}, - {file = "orjson-3.10.6-cp311-none-win32.whl", hash = "sha256:450e39ab1f7694465060a0550b3f6d328d20297bf2e06aa947b97c21e5241fbd"}, - {file = "orjson-3.10.6-cp311-none-win_amd64.whl", hash = "sha256:227df19441372610b20e05bdb906e1742ec2ad7a66ac8350dcfd29a63014a83b"}, - {file = "orjson-3.10.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ea2977b21f8d5d9b758bb3f344a75e55ca78e3ff85595d248eee813ae23ecdfb"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6f3d167d13a16ed263b52dbfedff52c962bfd3d270b46b7518365bcc2121eed"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f710f346e4c44a4e8bdf23daa974faede58f83334289df80bc9cd12fe82573c7"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7275664f84e027dcb1ad5200b8b18373e9c669b2a9ec33d410c40f5ccf4b257e"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0943e4c701196b23c240b3d10ed8ecd674f03089198cf503105b474a4f77f21f"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:446dee5a491b5bc7d8f825d80d9637e7af43f86a331207b9c9610e2f93fee22a"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:64c81456d2a050d380786413786b057983892db105516639cb5d3ee3c7fd5148"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, - {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, - {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, - {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2c116072a8533f2fec435fde4d134610f806bdac20188c7bd2081f3e9e0133f"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6eeb13218c8cf34c61912e9df2de2853f1d009de0e46ea09ccdf3d757896af0a"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:965a916373382674e323c957d560b953d81d7a8603fbeee26f7b8248638bd48b"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = 
"sha256:03c95484d53ed8e479cade8628c9cea00fd9d67f5554764a1110e0d5aa2de96e"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e060748a04cccf1e0a6f2358dffea9c080b849a4a68c28b1b907f272b5127e9b"}, - {file = "orjson-3.10.6-cp38-none-win32.whl", hash = "sha256:738dbe3ef909c4b019d69afc19caf6b5ed0e2f1c786b5d6215fbb7539246e4c6"}, - {file = "orjson-3.10.6-cp38-none-win_amd64.whl", hash = "sha256:d40f839dddf6a7d77114fe6b8a70218556408c71d4d6e29413bb5f150a692ff7"}, - {file = "orjson-3.10.6-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:697a35a083c4f834807a6232b3e62c8b280f7a44ad0b759fd4dce748951e70db"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd502f96bf5ea9a61cbc0b2b5900d0dd68aa0da197179042bdd2be67e51a1e4b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f215789fb1667cdc874c1b8af6a84dc939fd802bf293a8334fce185c79cd359b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2debd8ddce948a8c0938c8c93ade191d2f4ba4649a54302a7da905a81f00b56"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5410111d7b6681d4b0d65e0f58a13be588d01b473822483f77f513c7f93bd3b2"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb1f28a137337fdc18384079fa5726810681055b32b92253fa15ae5656e1dddb"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bf2fbbce5fe7cd1aa177ea3eab2b8e6a6bc6e8592e4279ed3db2d62e57c0e1b2"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:79b9b9e33bd4c517445a62b90ca0cc279b0f1f3970655c3df9e608bc3f91741a"}, - {file = "orjson-3.10.6-cp39-none-win32.whl", hash = "sha256:30b0a09a2014e621b1adf66a4f705f0809358350a757508ee80209b2d8dae219"}, - {file = "orjson-3.10.6-cp39-none-win_amd64.whl", hash = "sha256:49e3bc615652617d463069f91b867a4458114c5b104e13b7ae6872e5f79d0844"}, - {file = "orjson-3.10.6.tar.gz", hash = "sha256:e54b63d0a7c6c54a5f5f726bc93a2078111ef060fec4ecbf34c5db800ca3b3a7"}, + {file = "orjson-3.10.10-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b788a579b113acf1c57e0a68e558be71d5d09aa67f62ca1f68e01117e550a998"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:804b18e2b88022c8905bb79bd2cbe59c0cd014b9328f43da8d3b28441995cda4"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9972572a1d042ec9ee421b6da69f7cc823da5962237563fa548ab17f152f0b9b"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc6993ab1c2ae7dd0711161e303f1db69062955ac2668181bfdf2dd410e65258"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d78e4cacced5781b01d9bc0f0cd8b70b906a0e109825cb41c1b03f9c41e4ce86"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6eb2598df518281ba0cbc30d24c5b06124ccf7e19169e883c14e0831217a0bc"}, + {file = "orjson-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:23776265c5215ec532de6238a52707048401a568f0fa0d938008e92a147fe2c7"}, + {file = "orjson-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8cc2a654c08755cef90b468ff17c102e2def0edd62898b2486767204a7f5cc9c"}, + {file = 
"orjson-3.10.10-cp310-none-win32.whl", hash = "sha256:081b3fc6a86d72efeb67c13d0ea7c030017bd95f9868b1e329a376edc456153b"}, + {file = "orjson-3.10.10-cp310-none-win_amd64.whl", hash = "sha256:ff38c5fb749347768a603be1fb8a31856458af839f31f064c5aa74aca5be9efe"}, + {file = "orjson-3.10.10-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:879e99486c0fbb256266c7c6a67ff84f46035e4f8749ac6317cc83dacd7f993a"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019481fa9ea5ff13b5d5d95e6fd5ab25ded0810c80b150c2c7b1cc8660b662a7"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0dd57eff09894938b4c86d4b871a479260f9e156fa7f12f8cad4b39ea8028bb5"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dbde6d70cd95ab4d11ea8ac5e738e30764e510fc54d777336eec09bb93b8576c"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2625cb37b8fb42e2147404e5ff7ef08712099197a9cd38895006d7053e69d6"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbf3c20c6a7db69df58672a0d5815647ecf78c8e62a4d9bd284e8621c1fe5ccb"}, + {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:75c38f5647e02d423807d252ce4528bf6a95bd776af999cb1fb48867ed01d1f6"}, + {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23458d31fa50ec18e0ec4b0b4343730928296b11111df5f547c75913714116b2"}, + {file = "orjson-3.10.10-cp311-none-win32.whl", hash = "sha256:2787cd9dedc591c989f3facd7e3e86508eafdc9536a26ec277699c0aa63c685b"}, + {file = "orjson-3.10.10-cp311-none-win_amd64.whl", hash = "sha256:6514449d2c202a75183f807bc755167713297c69f1db57a89a1ef4a0170ee269"}, + {file = "orjson-3.10.10-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8564f48f3620861f5ef1e080ce7cd122ee89d7d6dacf25fcae675ff63b4d6e05"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bf161a32b479034098c5b81f2608f09167ad2fa1c06abd4e527ea6bf4837a9"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:68b65c93617bcafa7f04b74ae8bc2cc214bd5cb45168a953256ff83015c6747d"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8e28406f97fc2ea0c6150f4c1b6e8261453318930b334abc419214c82314f85"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4d0d9fe174cc7a5bdce2e6c378bcdb4c49b2bf522a8f996aa586020e1b96cee"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3be81c42f1242cbed03cbb3973501fcaa2675a0af638f8be494eaf37143d999"}, + {file = "orjson-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:65f9886d3bae65be026219c0a5f32dbbe91a9e6272f56d092ab22561ad0ea33b"}, + {file = "orjson-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:730ed5350147db7beb23ddaf072f490329e90a1d059711d364b49fe352ec987b"}, + {file = "orjson-3.10.10-cp312-none-win32.whl", hash = "sha256:a8f4bf5f1c85bea2170800020d53a8877812892697f9c2de73d576c9307a8a5f"}, + {file = "orjson-3.10.10-cp312-none-win_amd64.whl", hash = "sha256:384cd13579a1b4cd689d218e329f459eb9ddc504fa48c5a83ef4889db7fd7a4f"}, + {file = 
"orjson-3.10.10-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:44bffae68c291f94ff5a9b4149fe9d1bdd4cd0ff0fb575bcea8351d48db629a1"}, + {file = "orjson-3.10.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e27b4c6437315df3024f0835887127dac2a0a3ff643500ec27088d2588fa5ae1"}, + {file = "orjson-3.10.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca84df16d6b49325a4084fd8b2fe2229cb415e15c46c529f868c3387bb1339d"}, + {file = "orjson-3.10.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c14ce70e8f39bd71f9f80423801b5d10bf93d1dceffdecd04df0f64d2c69bc01"}, + {file = "orjson-3.10.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:24ac62336da9bda1bd93c0491eff0613003b48d3cb5d01470842e7b52a40d5b4"}, + {file = "orjson-3.10.10-cp313-none-win32.whl", hash = "sha256:eb0a42831372ec2b05acc9ee45af77bcaccbd91257345f93780a8e654efc75db"}, + {file = "orjson-3.10.10-cp313-none-win_amd64.whl", hash = "sha256:f0c4f37f8bf3f1075c6cc8dd8a9f843689a4b618628f8812d0a71e6968b95ffd"}, + {file = "orjson-3.10.10-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:829700cc18503efc0cf502d630f612884258020d98a317679cd2054af0259568"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0ceb5e0e8c4f010ac787d29ae6299846935044686509e2f0f06ed441c1ca949"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0c25908eb86968613216f3db4d3003f1c45d78eb9046b71056ca327ff92bdbd4"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:218cb0bc03340144b6328a9ff78f0932e642199ac184dd74b01ad691f42f93ff"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2277ec2cea3775640dc81ab5195bb5b2ada2fe0ea6eee4677474edc75ea6785"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:848ea3b55ab5ccc9d7bbd420d69432628b691fba3ca8ae3148c35156cbd282aa"}, + {file = "orjson-3.10.10-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e3e67b537ac0c835b25b5f7d40d83816abd2d3f4c0b0866ee981a045287a54f3"}, + {file = "orjson-3.10.10-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:7948cfb909353fce2135dcdbe4521a5e7e1159484e0bb024c1722f272488f2b8"}, + {file = "orjson-3.10.10-cp38-none-win32.whl", hash = "sha256:78bee66a988f1a333dc0b6257503d63553b1957889c17b2c4ed72385cd1b96ae"}, + {file = "orjson-3.10.10-cp38-none-win_amd64.whl", hash = "sha256:f1d647ca8d62afeb774340a343c7fc023efacfd3a39f70c798991063f0c681dd"}, + {file = "orjson-3.10.10-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:5a059afddbaa6dd733b5a2d76a90dbc8af790b993b1b5cb97a1176ca713b5df8"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f9b5c59f7e2a1a410f971c5ebc68f1995822837cd10905ee255f96074537ee6"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d5ef198bafdef4aa9d49a4165ba53ffdc0a9e1c7b6f76178572ab33118afea25"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aaf29ce0bb5d3320824ec3d1508652421000ba466abd63bdd52c64bcce9eb1fa"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dddd5516bcc93e723d029c1633ae79c4417477b4f57dad9bfeeb6bc0315e654a"}, + {file = 
"orjson-3.10.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12f2003695b10817f0fa8b8fca982ed7f5761dcb0d93cff4f2f9f6709903fd7"}, + {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:672f9874a8a8fb9bb1b771331d31ba27f57702c8106cdbadad8bda5d10bc1019"}, + {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1dcbb0ca5fafb2b378b2c74419480ab2486326974826bbf6588f4dc62137570a"}, + {file = "orjson-3.10.10-cp39-none-win32.whl", hash = "sha256:d9bbd3a4b92256875cb058c3381b782649b9a3c68a4aa9a2fff020c2f9cfc1be"}, + {file = "orjson-3.10.10-cp39-none-win_amd64.whl", hash = "sha256:766f21487a53aee8524b97ca9582d5c6541b03ab6210fbaf10142ae2f3ced2aa"}, + {file = "orjson-3.10.10.tar.gz", hash = "sha256:37949383c4df7b4337ce82ee35b6d7471e55195efa7dcb45ab8226ceadb0fe3b"}, ] [[package]] @@ -1513,21 +1577,21 @@ testing = ["docopt", "pytest"] [[package]] name = "penzai" -version = "0.1.5" +version = "0.2.2" description = "Penzai: A JAX research toolkit for building, editing, and visualizing neural networks." optional = false python-versions = ">=3.10" files = [ - {file = "penzai-0.1.5-py3-none-any.whl", hash = "sha256:bc46c5f0ed2d0e74b558c1dc39dfb94c168eaddfe5ca446627e9d7affda78e20"}, - {file = "penzai-0.1.5.tar.gz", hash = "sha256:2feba0079af6689dac9e44623744858e74b144688ad42c368561a6851fe404e4"}, + {file = "penzai-0.2.2-py3-none-any.whl", hash = "sha256:387caf4a0af4067658528ae931648421b8fe2d62d0b32eb0c9254b2d95771e7e"}, + {file = "penzai-0.2.2.tar.gz", hash = "sha256:f08b1c7151ea07dfe80b99abc5c749942fb0da112ed2ba82d5588ec255f8e8be"}, ] [package.dependencies] absl-py = ">=1.4.0" -equinox = ">=0.11.3" jax = ">=0.4.23" numpy = ">=1.25.2" ordered_set = ">=4.1.0" +treescope = ">=0.1.3" typing_extensions = ">=4.2" [package.extras] @@ -1649,29 +1713,29 @@ xmp = ["defusedxml"] [[package]] name = "platformdirs" -version = "4.2.2" +version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, - {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, + {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, + {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] -type = ["mypy (>=1.8)"] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] [[package]] name = "prompt-toolkit" -version = "3.0.47" +version = "3.0.48" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, - {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"}, + {file = "prompt_toolkit-3.0.48-py3-none-any.whl", hash = "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e"}, + {file = "prompt_toolkit-3.0.48.tar.gz", hash = "sha256:d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90"}, ] [package.dependencies] @@ -1679,32 +1743,33 @@ wcwidth = "*" [[package]] name = "psutil" -version = "6.0.0" +version = "6.1.0" description = "Cross-platform lib for process and system monitoring in Python." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ - {file = "psutil-6.0.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6"}, - {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0"}, - {file = "psutil-6.0.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c"}, - {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3"}, - {file = "psutil-6.0.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c"}, - {file = "psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35"}, - {file = "psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1"}, - {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, - {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, - {file = "psutil-6.0.0-cp36-cp36m-win32.whl", hash = "sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14"}, - {file = "psutil-6.0.0-cp36-cp36m-win_amd64.whl", hash = "sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c"}, - {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, - {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, - {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, - {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, + {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:000d1d1ebd634b4efb383f4034437384e44a6d455260aaee2eca1e9c1b55f047"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:5cd2bcdc75b452ba2e10f0e8ecc0b57b827dd5d7aaffbc6821b2a9a242823a76"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:045f00a43c737f960d273a83973b2511430d61f283a44c96bf13a6e829ba8fdc"}, + {file = "psutil-6.1.0-cp27-none-win32.whl", hash = "sha256:9118f27452b70bb1d9ab3198c1f626c2499384935aaf55388211ad982611407e"}, + {file = "psutil-6.1.0-cp27-none-win_amd64.whl", hash = 
"sha256:a8506f6119cff7015678e2bce904a4da21025cc70ad283a53b099e7620061d85"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6e2dcd475ce8b80522e51d923d10c7871e45f20918e027ab682f94f1c6351688"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0895b8414afafc526712c498bd9de2b063deaac4021a3b3c34566283464aff8e"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dcbfce5d89f1d1f2546a2090f4fcf87c7f669d1d90aacb7d7582addece9fb38"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498c6979f9c6637ebc3a73b3f87f9eb1ec24e1ce53a7c5173b8508981614a90b"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d905186d647b16755a800e7263d43df08b790d709d575105d419f8b6ef65423a"}, + {file = "psutil-6.1.0-cp36-cp36m-win32.whl", hash = "sha256:6d3fbbc8d23fcdcb500d2c9f94e07b1342df8ed71b948a2649b5cb060a7c94ca"}, + {file = "psutil-6.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1209036fbd0421afde505a4879dee3b2fd7b1e14fee81c0069807adcbbcca747"}, + {file = "psutil-6.1.0-cp37-abi3-win32.whl", hash = "sha256:1ad45a1f5d0b608253b11508f80940985d1d0c8f6111b5cb637533a0e6ddc13e"}, + {file = "psutil-6.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be"}, + {file = "psutil-6.1.0.tar.gz", hash = "sha256:353815f59a7f64cdaca1c0307ee13558a0512f6db064e92fe833784f08539c7a"}, ] [package.extras] -test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["pytest", "pytest-xdist", "setuptools"] [[package]] name = "psygnal" @@ -1797,13 +1862,13 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pyparsing" -version = "3.1.2" +version = "3.2.0" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = false -python-versions = ">=3.6.8" +python-versions = ">=3.9" files = [ - {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, - {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, + {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"}, + {file = "pyparsing-3.2.0.tar.gz", hash = "sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"}, ] [package.extras] @@ -1825,182 +1890,209 @@ six = ">=1.5" [[package]] name = "pywin32" -version = "306" +version = "308" description = "Python for Window Extensions" optional = false python-versions = "*" files = [ - {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, - {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, - {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, - {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, - {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash 
= "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, - {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, - {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, - {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, - {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, - {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, - {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, - {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, - {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, - {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, + {file = "pywin32-308-cp310-cp310-win32.whl", hash = "sha256:796ff4426437896550d2981b9c2ac0ffd75238ad9ea2d3bfa67a1abd546d262e"}, + {file = "pywin32-308-cp310-cp310-win_amd64.whl", hash = "sha256:4fc888c59b3c0bef905ce7eb7e2106a07712015ea1c8234b703a088d46110e8e"}, + {file = "pywin32-308-cp310-cp310-win_arm64.whl", hash = "sha256:a5ab5381813b40f264fa3495b98af850098f814a25a63589a8e9eb12560f450c"}, + {file = "pywin32-308-cp311-cp311-win32.whl", hash = "sha256:5d8c8015b24a7d6855b1550d8e660d8daa09983c80e5daf89a273e5c6fb5095a"}, + {file = "pywin32-308-cp311-cp311-win_amd64.whl", hash = "sha256:575621b90f0dc2695fec346b2d6302faebd4f0f45c05ea29404cefe35d89442b"}, + {file = "pywin32-308-cp311-cp311-win_arm64.whl", hash = "sha256:100a5442b7332070983c4cd03f2e906a5648a5104b8a7f50175f7906efd16bb6"}, + {file = "pywin32-308-cp312-cp312-win32.whl", hash = "sha256:587f3e19696f4bf96fde9d8a57cec74a57021ad5f204c9e627e15c33ff568897"}, + {file = "pywin32-308-cp312-cp312-win_amd64.whl", hash = "sha256:00b3e11ef09ede56c6a43c71f2d31857cf7c54b0ab6e78ac659497abd2834f47"}, + {file = "pywin32-308-cp312-cp312-win_arm64.whl", hash = "sha256:9b4de86c8d909aed15b7011182c8cab38c8850de36e6afb1f0db22b8959e3091"}, + {file = "pywin32-308-cp313-cp313-win32.whl", hash = "sha256:1c44539a37a5b7b21d02ab34e6a4d314e0788f1690d65b48e9b0b89f31abbbed"}, + {file = "pywin32-308-cp313-cp313-win_amd64.whl", hash = "sha256:fd380990e792eaf6827fcb7e187b2b4b1cede0585e3d0c9e84201ec27b9905e4"}, + {file = "pywin32-308-cp313-cp313-win_arm64.whl", hash = "sha256:ef313c46d4c18dfb82a2431e3051ac8f112ccee1a34f29c263c583c568db63cd"}, + {file = "pywin32-308-cp37-cp37m-win32.whl", hash = "sha256:1f696ab352a2ddd63bd07430080dd598e6369152ea13a25ebcdd2f503a38f1ff"}, + {file = "pywin32-308-cp37-cp37m-win_amd64.whl", hash = "sha256:13dcb914ed4347019fbec6697a01a0aec61019c1046c2b905410d197856326a6"}, + {file = "pywin32-308-cp38-cp38-win32.whl", hash = "sha256:5794e764ebcabf4ff08c555b31bd348c9025929371763b2183172ff4708152f0"}, + {file = "pywin32-308-cp38-cp38-win_amd64.whl", hash = "sha256:3b92622e29d651c6b783e368ba7d6722b1634b8e70bd376fd7610fe1992e19de"}, + {file = "pywin32-308-cp39-cp39-win32.whl", hash = "sha256:7873ca4dc60ab3287919881a7d4f88baee4a6e639aa6962de25a98ba6b193341"}, + {file = "pywin32-308-cp39-cp39-win_amd64.whl", hash 
= "sha256:71b3322d949b4cc20776436a9c9ba0eeedcbc9c650daa536df63f0ff111bb920"}, ] [[package]] name = "pyyaml" -version = "6.0.1" +version = "6.0.2" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, - {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, - {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, - {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, - {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, - {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, - {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, - {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, - {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, - {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, - {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, - {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, - {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, - {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, - {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, - {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, - {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, - {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, - {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, - {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, - {file = 
"PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, - {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, - {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, - {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, - {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = 
"PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = 
"sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] [[package]] name = "pyzmq" -version = "26.0.3" +version = "26.2.0" description = "Python bindings for 0MQ" optional = false python-versions = ">=3.7" files = [ - {file = "pyzmq-26.0.3-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:44dd6fc3034f1eaa72ece33588867df9e006a7303725a12d64c3dff92330f625"}, - {file = "pyzmq-26.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:acb704195a71ac5ea5ecf2811c9ee19ecdc62b91878528302dd0be1b9451cc90"}, - {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dbb9c997932473a27afa93954bb77a9f9b786b4ccf718d903f35da3232317de"}, - {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6bcb34f869d431799c3ee7d516554797f7760cb2198ecaa89c3f176f72d062be"}, - {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38ece17ec5f20d7d9b442e5174ae9f020365d01ba7c112205a4d59cf19dc38ee"}, - {file = "pyzmq-26.0.3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ba6e5e6588e49139a0979d03a7deb9c734bde647b9a8808f26acf9c547cab1bf"}, - {file = "pyzmq-26.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3bf8b000a4e2967e6dfdd8656cd0757d18c7e5ce3d16339e550bd462f4857e59"}, - {file = "pyzmq-26.0.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2136f64fbb86451dbbf70223635a468272dd20075f988a102bf8a3f194a411dc"}, - {file = "pyzmq-26.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e8918973fbd34e7814f59143c5f600ecd38b8038161239fd1a3d33d5817a38b8"}, - {file = "pyzmq-26.0.3-cp310-cp310-win32.whl", hash = "sha256:0aaf982e68a7ac284377d051c742610220fd06d330dcd4c4dbb4cdd77c22a537"}, - {file = "pyzmq-26.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:f1a9b7d00fdf60b4039f4455afd031fe85ee8305b019334b72dcf73c567edc47"}, - {file = "pyzmq-26.0.3-cp310-cp310-win_arm64.whl", hash = 
"sha256:80b12f25d805a919d53efc0a5ad7c0c0326f13b4eae981a5d7b7cc343318ebb7"}, - {file = "pyzmq-26.0.3-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:a72a84570f84c374b4c287183debc776dc319d3e8ce6b6a0041ce2e400de3f32"}, - {file = "pyzmq-26.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7ca684ee649b55fd8f378127ac8462fb6c85f251c2fb027eb3c887e8ee347bcd"}, - {file = "pyzmq-26.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e222562dc0f38571c8b1ffdae9d7adb866363134299264a1958d077800b193b7"}, - {file = "pyzmq-26.0.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f17cde1db0754c35a91ac00b22b25c11da6eec5746431d6e5092f0cd31a3fea9"}, - {file = "pyzmq-26.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b7c0c0b3244bb2275abe255d4a30c050d541c6cb18b870975553f1fb6f37527"}, - {file = "pyzmq-26.0.3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac97a21de3712afe6a6c071abfad40a6224fd14fa6ff0ff8d0c6e6cd4e2f807a"}, - {file = "pyzmq-26.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:88b88282e55fa39dd556d7fc04160bcf39dea015f78e0cecec8ff4f06c1fc2b5"}, - {file = "pyzmq-26.0.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:72b67f966b57dbd18dcc7efbc1c7fc9f5f983e572db1877081f075004614fcdd"}, - {file = "pyzmq-26.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f4b6cecbbf3b7380f3b61de3a7b93cb721125dc125c854c14ddc91225ba52f83"}, - {file = "pyzmq-26.0.3-cp311-cp311-win32.whl", hash = "sha256:eed56b6a39216d31ff8cd2f1d048b5bf1700e4b32a01b14379c3b6dde9ce3aa3"}, - {file = "pyzmq-26.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:3191d312c73e3cfd0f0afdf51df8405aafeb0bad71e7ed8f68b24b63c4f36500"}, - {file = "pyzmq-26.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:b6907da3017ef55139cf0e417c5123a84c7332520e73a6902ff1f79046cd3b94"}, - {file = "pyzmq-26.0.3-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:068ca17214038ae986d68f4a7021f97e187ed278ab6dccb79f837d765a54d753"}, - {file = "pyzmq-26.0.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7821d44fe07335bea256b9f1f41474a642ca55fa671dfd9f00af8d68a920c2d4"}, - {file = "pyzmq-26.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eeb438a26d87c123bb318e5f2b3d86a36060b01f22fbdffd8cf247d52f7c9a2b"}, - {file = "pyzmq-26.0.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:69ea9d6d9baa25a4dc9cef5e2b77b8537827b122214f210dd925132e34ae9b12"}, - {file = "pyzmq-26.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7daa3e1369355766dea11f1d8ef829905c3b9da886ea3152788dc25ee6079e02"}, - {file = "pyzmq-26.0.3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6ca7a9a06b52d0e38ccf6bca1aeff7be178917893f3883f37b75589d42c4ac20"}, - {file = "pyzmq-26.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1b7d0e124948daa4d9686d421ef5087c0516bc6179fdcf8828b8444f8e461a77"}, - {file = "pyzmq-26.0.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e746524418b70f38550f2190eeee834db8850088c834d4c8406fbb9bc1ae10b2"}, - {file = "pyzmq-26.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:6b3146f9ae6af82c47a5282ac8803523d381b3b21caeae0327ed2f7ecb718798"}, - {file = "pyzmq-26.0.3-cp312-cp312-win32.whl", hash = "sha256:2b291d1230845871c00c8462c50565a9cd6026fe1228e77ca934470bb7d70ea0"}, - {file = "pyzmq-26.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:926838a535c2c1ea21c903f909a9a54e675c2126728c21381a94ddf37c3cbddf"}, - {file = "pyzmq-26.0.3-cp312-cp312-win_arm64.whl", 
hash = "sha256:5bf6c237f8c681dfb91b17f8435b2735951f0d1fad10cc5dfd96db110243370b"}, - {file = "pyzmq-26.0.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0c0991f5a96a8e620f7691e61178cd8f457b49e17b7d9cfa2067e2a0a89fc1d5"}, - {file = "pyzmq-26.0.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dbf012d8fcb9f2cf0643b65df3b355fdd74fc0035d70bb5c845e9e30a3a4654b"}, - {file = "pyzmq-26.0.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:01fbfbeb8249a68d257f601deb50c70c929dc2dfe683b754659569e502fbd3aa"}, - {file = "pyzmq-26.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c8eb19abe87029c18f226d42b8a2c9efdd139d08f8bf6e085dd9075446db450"}, - {file = "pyzmq-26.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5344b896e79800af86ad643408ca9aa303a017f6ebff8cee5a3163c1e9aec987"}, - {file = "pyzmq-26.0.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:204e0f176fd1d067671157d049466869b3ae1fc51e354708b0dc41cf94e23a3a"}, - {file = "pyzmq-26.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a42db008d58530efa3b881eeee4991146de0b790e095f7ae43ba5cc612decbc5"}, - {file = "pyzmq-26.0.3-cp37-cp37m-win32.whl", hash = "sha256:8d7a498671ca87e32b54cb47c82a92b40130a26c5197d392720a1bce1b3c77cf"}, - {file = "pyzmq-26.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:3b4032a96410bdc760061b14ed6a33613ffb7f702181ba999df5d16fb96ba16a"}, - {file = "pyzmq-26.0.3-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:2cc4e280098c1b192c42a849de8de2c8e0f3a84086a76ec5b07bfee29bda7d18"}, - {file = "pyzmq-26.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5bde86a2ed3ce587fa2b207424ce15b9a83a9fa14422dcc1c5356a13aed3df9d"}, - {file = "pyzmq-26.0.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:34106f68e20e6ff253c9f596ea50397dbd8699828d55e8fa18bd4323d8d966e6"}, - {file = "pyzmq-26.0.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ebbbd0e728af5db9b04e56389e2299a57ea8b9dd15c9759153ee2455b32be6ad"}, - {file = "pyzmq-26.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6b1d1c631e5940cac5a0b22c5379c86e8df6a4ec277c7a856b714021ab6cfad"}, - {file = "pyzmq-26.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e891ce81edd463b3b4c3b885c5603c00141151dd9c6936d98a680c8c72fe5c67"}, - {file = "pyzmq-26.0.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9b273ecfbc590a1b98f014ae41e5cf723932f3b53ba9367cfb676f838038b32c"}, - {file = "pyzmq-26.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b32bff85fb02a75ea0b68f21e2412255b5731f3f389ed9aecc13a6752f58ac97"}, - {file = "pyzmq-26.0.3-cp38-cp38-win32.whl", hash = "sha256:f6c21c00478a7bea93caaaef9e7629145d4153b15a8653e8bb4609d4bc70dbfc"}, - {file = "pyzmq-26.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:3401613148d93ef0fd9aabdbddb212de3db7a4475367f49f590c837355343972"}, - {file = "pyzmq-26.0.3-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:2ed8357f4c6e0daa4f3baf31832df8a33334e0fe5b020a61bc8b345a3db7a606"}, - {file = "pyzmq-26.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c1c8f2a2ca45292084c75bb6d3a25545cff0ed931ed228d3a1810ae3758f975f"}, - {file = "pyzmq-26.0.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:b63731993cdddcc8e087c64e9cf003f909262b359110070183d7f3025d1c56b5"}, - {file = "pyzmq-26.0.3-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b3cd31f859b662ac5d7f4226ec7d8bd60384fa037fc02aee6ff0b53ba29a3ba8"}, - {file = 
"pyzmq-26.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:115f8359402fa527cf47708d6f8a0f8234f0e9ca0cab7c18c9c189c194dbf620"}, - {file = "pyzmq-26.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:715bdf952b9533ba13dfcf1f431a8f49e63cecc31d91d007bc1deb914f47d0e4"}, - {file = "pyzmq-26.0.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:e1258c639e00bf5e8a522fec6c3eaa3e30cf1c23a2f21a586be7e04d50c9acab"}, - {file = "pyzmq-26.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:15c59e780be8f30a60816a9adab900c12a58d79c1ac742b4a8df044ab2a6d920"}, - {file = "pyzmq-26.0.3-cp39-cp39-win32.whl", hash = "sha256:d0cdde3c78d8ab5b46595054e5def32a755fc028685add5ddc7403e9f6de9879"}, - {file = "pyzmq-26.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:ce828058d482ef860746bf532822842e0ff484e27f540ef5c813d516dd8896d2"}, - {file = "pyzmq-26.0.3-cp39-cp39-win_arm64.whl", hash = "sha256:788f15721c64109cf720791714dc14afd0f449d63f3a5487724f024345067381"}, - {file = "pyzmq-26.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2c18645ef6294d99b256806e34653e86236eb266278c8ec8112622b61db255de"}, - {file = "pyzmq-26.0.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e6bc96ebe49604df3ec2c6389cc3876cabe475e6bfc84ced1bf4e630662cb35"}, - {file = "pyzmq-26.0.3-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:971e8990c5cc4ddcff26e149398fc7b0f6a042306e82500f5e8db3b10ce69f84"}, - {file = "pyzmq-26.0.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8416c23161abd94cc7da80c734ad7c9f5dbebdadfdaa77dad78244457448223"}, - {file = "pyzmq-26.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:082a2988364b60bb5de809373098361cf1dbb239623e39e46cb18bc035ed9c0c"}, - {file = "pyzmq-26.0.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d57dfbf9737763b3a60d26e6800e02e04284926329aee8fb01049635e957fe81"}, - {file = "pyzmq-26.0.3-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:77a85dca4c2430ac04dc2a2185c2deb3858a34fe7f403d0a946fa56970cf60a1"}, - {file = "pyzmq-26.0.3-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4c82a6d952a1d555bf4be42b6532927d2a5686dd3c3e280e5f63225ab47ac1f5"}, - {file = "pyzmq-26.0.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4496b1282c70c442809fc1b151977c3d967bfb33e4e17cedbf226d97de18f709"}, - {file = "pyzmq-26.0.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:e4946d6bdb7ba972dfda282f9127e5756d4f299028b1566d1245fa0d438847e6"}, - {file = "pyzmq-26.0.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:03c0ae165e700364b266876d712acb1ac02693acd920afa67da2ebb91a0b3c09"}, - {file = "pyzmq-26.0.3-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:3e3070e680f79887d60feeda051a58d0ac36622e1759f305a41059eff62c6da7"}, - {file = "pyzmq-26.0.3-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6ca08b840fe95d1c2bd9ab92dac5685f949fc6f9ae820ec16193e5ddf603c3b2"}, - {file = "pyzmq-26.0.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e76654e9dbfb835b3518f9938e565c7806976c07b37c33526b574cc1a1050480"}, - {file = "pyzmq-26.0.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:871587bdadd1075b112e697173e946a07d722459d20716ceb3d1bd6c64bd08ce"}, - {file = "pyzmq-26.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:d0a2d1bd63a4ad79483049b26514e70fa618ce6115220da9efdff63688808b17"}, - {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0270b49b6847f0d106d64b5086e9ad5dc8a902413b5dbbb15d12b60f9c1747a4"}, - {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:703c60b9910488d3d0954ca585c34f541e506a091a41930e663a098d3b794c67"}, - {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74423631b6be371edfbf7eabb02ab995c2563fee60a80a30829176842e71722a"}, - {file = "pyzmq-26.0.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4adfbb5451196842a88fda3612e2c0414134874bffb1c2ce83ab4242ec9e027d"}, - {file = "pyzmq-26.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:3516119f4f9b8671083a70b6afaa0a070f5683e431ab3dc26e9215620d7ca1ad"}, - {file = "pyzmq-26.0.3.tar.gz", hash = "sha256:dba7d9f2e047dfa2bca3b01f4f84aa5246725203d6284e3790f2ca15fba6b40a"}, + {file = "pyzmq-26.2.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:ddf33d97d2f52d89f6e6e7ae66ee35a4d9ca6f36eda89c24591b0c40205a3629"}, + {file = "pyzmq-26.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dacd995031a01d16eec825bf30802fceb2c3791ef24bcce48fa98ce40918c27b"}, + {file = "pyzmq-26.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89289a5ee32ef6c439086184529ae060c741334b8970a6855ec0b6ad3ff28764"}, + {file = "pyzmq-26.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5506f06d7dc6ecf1efacb4a013b1f05071bb24b76350832c96449f4a2d95091c"}, + {file = "pyzmq-26.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ea039387c10202ce304af74def5021e9adc6297067f3441d348d2b633e8166a"}, + {file = "pyzmq-26.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a2224fa4a4c2ee872886ed00a571f5e967c85e078e8e8c2530a2fb01b3309b88"}, + {file = "pyzmq-26.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:28ad5233e9c3b52d76196c696e362508959741e1a005fb8fa03b51aea156088f"}, + {file = "pyzmq-26.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:1c17211bc037c7d88e85ed8b7d8f7e52db6dc8eca5590d162717c654550f7282"}, + {file = "pyzmq-26.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b8f86dd868d41bea9a5f873ee13bf5551c94cf6bc51baebc6f85075971fe6eea"}, + {file = "pyzmq-26.2.0-cp310-cp310-win32.whl", hash = "sha256:46a446c212e58456b23af260f3d9fb785054f3e3653dbf7279d8f2b5546b21c2"}, + {file = "pyzmq-26.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:49d34ab71db5a9c292a7644ce74190b1dd5a3475612eefb1f8be1d6961441971"}, + {file = "pyzmq-26.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:bfa832bfa540e5b5c27dcf5de5d82ebc431b82c453a43d141afb1e5d2de025fa"}, + {file = "pyzmq-26.2.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:8f7e66c7113c684c2b3f1c83cdd3376103ee0ce4c49ff80a648643e57fb22218"}, + {file = "pyzmq-26.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3a495b30fc91db2db25120df5847d9833af237546fd59170701acd816ccc01c4"}, + {file = "pyzmq-26.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77eb0968da535cba0470a5165468b2cac7772cfb569977cff92e240f57e31bef"}, + {file = "pyzmq-26.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ace4f71f1900a548f48407fc9be59c6ba9d9aaf658c2eea6cf2779e72f9f317"}, + {file = "pyzmq-26.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:92a78853d7280bffb93df0a4a6a2498cba10ee793cc8076ef797ef2f74d107cf"}, + {file = "pyzmq-26.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:689c5d781014956a4a6de61d74ba97b23547e431e9e7d64f27d4922ba96e9d6e"}, + {file = "pyzmq-26.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aca98bc423eb7d153214b2df397c6421ba6373d3397b26c057af3c904452e37"}, + {file = "pyzmq-26.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f3496d76b89d9429a656293744ceca4d2ac2a10ae59b84c1da9b5165f429ad3"}, + {file = "pyzmq-26.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5c2b3bfd4b9689919db068ac6c9911f3fcb231c39f7dd30e3138be94896d18e6"}, + {file = "pyzmq-26.2.0-cp311-cp311-win32.whl", hash = "sha256:eac5174677da084abf378739dbf4ad245661635f1600edd1221f150b165343f4"}, + {file = "pyzmq-26.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:5a509df7d0a83a4b178d0f937ef14286659225ef4e8812e05580776c70e155d5"}, + {file = "pyzmq-26.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:c0e6091b157d48cbe37bd67233318dbb53e1e6327d6fc3bb284afd585d141003"}, + {file = "pyzmq-26.2.0-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:ded0fc7d90fe93ae0b18059930086c51e640cdd3baebdc783a695c77f123dcd9"}, + {file = "pyzmq-26.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:17bf5a931c7f6618023cdacc7081f3f266aecb68ca692adac015c383a134ca52"}, + {file = "pyzmq-26.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55cf66647e49d4621a7e20c8d13511ef1fe1efbbccf670811864452487007e08"}, + {file = "pyzmq-26.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4661c88db4a9e0f958c8abc2b97472e23061f0bc737f6f6179d7a27024e1faa5"}, + {file = "pyzmq-26.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea7f69de383cb47522c9c208aec6dd17697db7875a4674c4af3f8cfdac0bdeae"}, + {file = "pyzmq-26.2.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7f98f6dfa8b8ccaf39163ce872bddacca38f6a67289116c8937a02e30bbe9711"}, + {file = "pyzmq-26.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e3e0210287329272539eea617830a6a28161fbbd8a3271bf4150ae3e58c5d0e6"}, + {file = "pyzmq-26.2.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6b274e0762c33c7471f1a7471d1a2085b1a35eba5cdc48d2ae319f28b6fc4de3"}, + {file = "pyzmq-26.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:29c6a4635eef69d68a00321e12a7d2559fe2dfccfa8efae3ffb8e91cd0b36a8b"}, + {file = "pyzmq-26.2.0-cp312-cp312-win32.whl", hash = "sha256:989d842dc06dc59feea09e58c74ca3e1678c812a4a8a2a419046d711031f69c7"}, + {file = "pyzmq-26.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:2a50625acdc7801bc6f74698c5c583a491c61d73c6b7ea4dee3901bb99adb27a"}, + {file = "pyzmq-26.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:4d29ab8592b6ad12ebbf92ac2ed2bedcfd1cec192d8e559e2e099f648570e19b"}, + {file = "pyzmq-26.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9dd8cd1aeb00775f527ec60022004d030ddc51d783d056e3e23e74e623e33726"}, + {file = "pyzmq-26.2.0-cp313-cp313-macosx_10_15_universal2.whl", hash = "sha256:28c812d9757fe8acecc910c9ac9dafd2ce968c00f9e619db09e9f8f54c3a68a3"}, + {file = "pyzmq-26.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d80b1dd99c1942f74ed608ddb38b181b87476c6a966a88a950c7dee118fdf50"}, + {file = "pyzmq-26.2.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c997098cc65e3208eca09303630e84d42718620e83b733d0fd69543a9cab9cb"}, + {file = "pyzmq-26.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:7ad1bc8d1b7a18497dda9600b12dc193c577beb391beae5cd2349184db40f187"}, + {file = "pyzmq-26.2.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:bea2acdd8ea4275e1278350ced63da0b166421928276c7c8e3f9729d7402a57b"}, + {file = "pyzmq-26.2.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:23f4aad749d13698f3f7b64aad34f5fc02d6f20f05999eebc96b89b01262fb18"}, + {file = "pyzmq-26.2.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:a4f96f0d88accc3dbe4a9025f785ba830f968e21e3e2c6321ccdfc9aef755115"}, + {file = "pyzmq-26.2.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ced65e5a985398827cc9276b93ef6dfabe0273c23de8c7931339d7e141c2818e"}, + {file = "pyzmq-26.2.0-cp313-cp313-win32.whl", hash = "sha256:31507f7b47cc1ead1f6e86927f8ebb196a0bab043f6345ce070f412a59bf87b5"}, + {file = "pyzmq-26.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:70fc7fcf0410d16ebdda9b26cbd8bf8d803d220a7f3522e060a69a9c87bf7bad"}, + {file = "pyzmq-26.2.0-cp313-cp313-win_arm64.whl", hash = "sha256:c3789bd5768ab5618ebf09cef6ec2b35fed88709b104351748a63045f0ff9797"}, + {file = "pyzmq-26.2.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:034da5fc55d9f8da09015d368f519478a52675e558c989bfcb5cf6d4e16a7d2a"}, + {file = "pyzmq-26.2.0-cp313-cp313t-macosx_10_15_universal2.whl", hash = "sha256:c92d73464b886931308ccc45b2744e5968cbaade0b1d6aeb40d8ab537765f5bc"}, + {file = "pyzmq-26.2.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:794a4562dcb374f7dbbfb3f51d28fb40123b5a2abadee7b4091f93054909add5"}, + {file = "pyzmq-26.2.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aee22939bb6075e7afededabad1a56a905da0b3c4e3e0c45e75810ebe3a52672"}, + {file = "pyzmq-26.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ae90ff9dad33a1cfe947d2c40cb9cb5e600d759ac4f0fd22616ce6540f72797"}, + {file = "pyzmq-26.2.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:43a47408ac52647dfabbc66a25b05b6a61700b5165807e3fbd40063fcaf46386"}, + {file = "pyzmq-26.2.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:25bf2374a2a8433633c65ccb9553350d5e17e60c8eb4de4d92cc6bd60f01d306"}, + {file = "pyzmq-26.2.0-cp313-cp313t-musllinux_1_1_i686.whl", hash = "sha256:007137c9ac9ad5ea21e6ad97d3489af654381324d5d3ba614c323f60dab8fae6"}, + {file = "pyzmq-26.2.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:470d4a4f6d48fb34e92d768b4e8a5cc3780db0d69107abf1cd7ff734b9766eb0"}, + {file = "pyzmq-26.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3b55a4229ce5da9497dd0452b914556ae58e96a4381bb6f59f1305dfd7e53fc8"}, + {file = "pyzmq-26.2.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9cb3a6460cdea8fe8194a76de8895707e61ded10ad0be97188cc8463ffa7e3a8"}, + {file = "pyzmq-26.2.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8ab5cad923cc95c87bffee098a27856c859bd5d0af31bd346035aa816b081fe1"}, + {file = "pyzmq-26.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ed69074a610fad1c2fda66180e7b2edd4d31c53f2d1872bc2d1211563904cd9"}, + {file = "pyzmq-26.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:cccba051221b916a4f5e538997c45d7d136a5646442b1231b916d0164067ea27"}, + {file = "pyzmq-26.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:0eaa83fc4c1e271c24eaf8fb083cbccef8fde77ec8cd45f3c35a9a123e6da097"}, + {file = "pyzmq-26.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9edda2df81daa129b25a39b86cb57dfdfe16f7ec15b42b19bfac503360d27a93"}, + {file = 
"pyzmq-26.2.0-cp37-cp37m-win32.whl", hash = "sha256:ea0eb6af8a17fa272f7b98d7bebfab7836a0d62738e16ba380f440fceca2d951"}, + {file = "pyzmq-26.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:4ff9dc6bc1664bb9eec25cd17506ef6672d506115095411e237d571e92a58231"}, + {file = "pyzmq-26.2.0-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:2eb7735ee73ca1b0d71e0e67c3739c689067f055c764f73aac4cc8ecf958ee3f"}, + {file = "pyzmq-26.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a534f43bc738181aa7cbbaf48e3eca62c76453a40a746ab95d4b27b1111a7d2"}, + {file = "pyzmq-26.2.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:aedd5dd8692635813368e558a05266b995d3d020b23e49581ddd5bbe197a8ab6"}, + {file = "pyzmq-26.2.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8be4700cd8bb02cc454f630dcdf7cfa99de96788b80c51b60fe2fe1dac480289"}, + {file = "pyzmq-26.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fcc03fa4997c447dce58264e93b5aa2d57714fbe0f06c07b7785ae131512732"}, + {file = "pyzmq-26.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:402b190912935d3db15b03e8f7485812db350d271b284ded2b80d2e5704be780"}, + {file = "pyzmq-26.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8685fa9c25ff00f550c1fec650430c4b71e4e48e8d852f7ddcf2e48308038640"}, + {file = "pyzmq-26.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:76589c020680778f06b7e0b193f4b6dd66d470234a16e1df90329f5e14a171cd"}, + {file = "pyzmq-26.2.0-cp38-cp38-win32.whl", hash = "sha256:8423c1877d72c041f2c263b1ec6e34360448decfb323fa8b94e85883043ef988"}, + {file = "pyzmq-26.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:76589f2cd6b77b5bdea4fca5992dc1c23389d68b18ccc26a53680ba2dc80ff2f"}, + {file = "pyzmq-26.2.0-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:b1d464cb8d72bfc1a3adc53305a63a8e0cac6bc8c5a07e8ca190ab8d3faa43c2"}, + {file = "pyzmq-26.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4da04c48873a6abdd71811c5e163bd656ee1b957971db7f35140a2d573f6949c"}, + {file = "pyzmq-26.2.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d049df610ac811dcffdc147153b414147428567fbbc8be43bb8885f04db39d98"}, + {file = "pyzmq-26.2.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05590cdbc6b902101d0e65d6a4780af14dc22914cc6ab995d99b85af45362cc9"}, + {file = "pyzmq-26.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c811cfcd6a9bf680236c40c6f617187515269ab2912f3d7e8c0174898e2519db"}, + {file = "pyzmq-26.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6835dd60355593de10350394242b5757fbbd88b25287314316f266e24c61d073"}, + {file = "pyzmq-26.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc6bee759a6bddea5db78d7dcd609397449cb2d2d6587f48f3ca613b19410cfc"}, + {file = "pyzmq-26.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c530e1eecd036ecc83c3407f77bb86feb79916d4a33d11394b8234f3bd35b940"}, + {file = "pyzmq-26.2.0-cp39-cp39-win32.whl", hash = "sha256:367b4f689786fca726ef7a6c5ba606958b145b9340a5e4808132cc65759abd44"}, + {file = "pyzmq-26.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:e6fa2e3e683f34aea77de8112f6483803c96a44fd726d7358b9888ae5bb394ec"}, + {file = "pyzmq-26.2.0-cp39-cp39-win_arm64.whl", hash = "sha256:7445be39143a8aa4faec43b076e06944b8f9d0701b669df4af200531b21e40bb"}, + {file = "pyzmq-26.2.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:706e794564bec25819d21a41c31d4df2d48e1cc4b061e8d345d7fb4dd3e94072"}, + {file = 
"pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b435f2753621cd36e7c1762156815e21c985c72b19135dac43a7f4f31d28dd1"}, + {file = "pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:160c7e0a5eb178011e72892f99f918c04a131f36056d10d9c1afb223fc952c2d"}, + {file = "pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4a71d5d6e7b28a47a394c0471b7e77a0661e2d651e7ae91e0cab0a587859ca"}, + {file = "pyzmq-26.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:90412f2db8c02a3864cbfc67db0e3dcdbda336acf1c469526d3e869394fe001c"}, + {file = "pyzmq-26.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2ea4ad4e6a12e454de05f2949d4beddb52460f3de7c8b9d5c46fbb7d7222e02c"}, + {file = "pyzmq-26.2.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fc4f7a173a5609631bb0c42c23d12c49df3966f89f496a51d3eb0ec81f4519d6"}, + {file = "pyzmq-26.2.0-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:878206a45202247781472a2d99df12a176fef806ca175799e1c6ad263510d57c"}, + {file = "pyzmq-26.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17c412bad2eb9468e876f556eb4ee910e62d721d2c7a53c7fa31e643d35352e6"}, + {file = "pyzmq-26.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:0d987a3ae5a71c6226b203cfd298720e0086c7fe7c74f35fa8edddfbd6597eed"}, + {file = "pyzmq-26.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:39887ac397ff35b7b775db7201095fc6310a35fdbae85bac4523f7eb3b840e20"}, + {file = "pyzmq-26.2.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fdb5b3e311d4d4b0eb8b3e8b4d1b0a512713ad7e6a68791d0923d1aec433d919"}, + {file = "pyzmq-26.2.0-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:226af7dcb51fdb0109f0016449b357e182ea0ceb6b47dfb5999d569e5db161d5"}, + {file = "pyzmq-26.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bed0e799e6120b9c32756203fb9dfe8ca2fb8467fed830c34c877e25638c3fc"}, + {file = "pyzmq-26.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:29c7947c594e105cb9e6c466bace8532dc1ca02d498684128b339799f5248277"}, + {file = "pyzmq-26.2.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:cdeabcff45d1c219636ee2e54d852262e5c2e085d6cb476d938aee8d921356b3"}, + {file = "pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35cffef589bcdc587d06f9149f8d5e9e8859920a071df5a2671de2213bef592a"}, + {file = "pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18c8dc3b7468d8b4bdf60ce9d7141897da103c7a4690157b32b60acb45e333e6"}, + {file = "pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7133d0a1677aec369d67dd78520d3fa96dd7f3dcec99d66c1762870e5ea1a50a"}, + {file = "pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6a96179a24b14fa6428cbfc08641c779a53f8fcec43644030328f44034c7f1f4"}, + {file = "pyzmq-26.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:4f78c88905461a9203eac9faac157a2a0dbba84a0fd09fd29315db27be40af9f"}, + {file = "pyzmq-26.2.0.tar.gz", hash = "sha256:070672c258581c8e4f640b5159297580a9974b026043bd4ab0470be9ed324f1f"}, ] [package.dependencies] @@ -2044,148 +2136,156 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "rpds-py" -version = "0.19.1" +version = "0.20.0" description = "Python bindings to Rust's 
persistent data structures (rpds)" optional = false python-versions = ">=3.8" files = [ - {file = "rpds_py-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:aaf71f95b21f9dc708123335df22e5a2fef6307e3e6f9ed773b2e0938cc4d491"}, - {file = "rpds_py-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca0dda0c5715efe2ab35bb83f813f681ebcd2840d8b1b92bfc6fe3ab382fae4a"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81db2e7282cc0487f500d4db203edc57da81acde9e35f061d69ed983228ffe3b"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a8dfa125b60ec00c7c9baef945bb04abf8ac772d8ebefd79dae2a5f316d7850"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:271accf41b02687cef26367c775ab220372ee0f4925591c6796e7c148c50cab5"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9bc4161bd3b970cd6a6fcda70583ad4afd10f2750609fb1f3ca9505050d4ef3"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0cf2a0dbb5987da4bd92a7ca727eadb225581dd9681365beba9accbe5308f7d"}, - {file = "rpds_py-0.19.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b5e28e56143750808c1c79c70a16519e9bc0a68b623197b96292b21b62d6055c"}, - {file = "rpds_py-0.19.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c7af6f7b80f687b33a4cdb0a785a5d4de1fb027a44c9a049d8eb67d5bfe8a687"}, - {file = "rpds_py-0.19.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e429fc517a1c5e2a70d576077231538a98d59a45dfc552d1ac45a132844e6dfb"}, - {file = "rpds_py-0.19.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d2dbd8f4990d4788cb122f63bf000357533f34860d269c1a8e90ae362090ff3a"}, - {file = "rpds_py-0.19.1-cp310-none-win32.whl", hash = "sha256:e0f9d268b19e8f61bf42a1da48276bcd05f7ab5560311f541d22557f8227b866"}, - {file = "rpds_py-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:df7c841813f6265e636fe548a49664c77af31ddfa0085515326342a751a6ba51"}, - {file = "rpds_py-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:902cf4739458852fe917104365ec0efbea7d29a15e4276c96a8d33e6ed8ec137"}, - {file = "rpds_py-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f3d73022990ab0c8b172cce57c69fd9a89c24fd473a5e79cbce92df87e3d9c48"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3837c63dd6918a24de6c526277910e3766d8c2b1627c500b155f3eecad8fad65"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cdb7eb3cf3deb3dd9e7b8749323b5d970052711f9e1e9f36364163627f96da58"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26ab43b6d65d25b1a333c8d1b1c2f8399385ff683a35ab5e274ba7b8bb7dc61c"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75130df05aae7a7ac171b3b5b24714cffeabd054ad2ebc18870b3aa4526eba23"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c34f751bf67cab69638564eee34023909380ba3e0d8ee7f6fe473079bf93f09b"}, - {file = "rpds_py-0.19.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f2671cb47e50a97f419a02cd1e0c339b31de017b033186358db92f4d8e2e17d8"}, - {file = "rpds_py-0.19.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = 
"sha256:3c73254c256081704dba0a333457e2fb815364018788f9b501efe7c5e0ada401"}, - {file = "rpds_py-0.19.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4383beb4a29935b8fa28aca8fa84c956bf545cb0c46307b091b8d312a9150e6a"}, - {file = "rpds_py-0.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dbceedcf4a9329cc665452db1aaf0845b85c666e4885b92ee0cddb1dbf7e052a"}, - {file = "rpds_py-0.19.1-cp311-none-win32.whl", hash = "sha256:f0a6d4a93d2a05daec7cb885157c97bbb0be4da739d6f9dfb02e101eb40921cd"}, - {file = "rpds_py-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:c149a652aeac4902ecff2dd93c3b2681c608bd5208c793c4a99404b3e1afc87c"}, - {file = "rpds_py-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:56313be667a837ff1ea3508cebb1ef6681d418fa2913a0635386cf29cff35165"}, - {file = "rpds_py-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d1d7539043b2b31307f2c6c72957a97c839a88b2629a348ebabe5aa8b626d6b"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e1dc59a5e7bc7f44bd0c048681f5e05356e479c50be4f2c1a7089103f1621d5"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8f78398e67a7227aefa95f876481485403eb974b29e9dc38b307bb6eb2315ea"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ef07a0a1d254eeb16455d839cef6e8c2ed127f47f014bbda64a58b5482b6c836"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8124101e92c56827bebef084ff106e8ea11c743256149a95b9fd860d3a4f331f"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08ce9c95a0b093b7aec75676b356a27879901488abc27e9d029273d280438505"}, - {file = "rpds_py-0.19.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0b02dd77a2de6e49078c8937aadabe933ceac04b41c5dde5eca13a69f3cf144e"}, - {file = "rpds_py-0.19.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4dd02e29c8cbed21a1875330b07246b71121a1c08e29f0ee3db5b4cfe16980c4"}, - {file = "rpds_py-0.19.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9c7042488165f7251dc7894cd533a875d2875af6d3b0e09eda9c4b334627ad1c"}, - {file = "rpds_py-0.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f809a17cc78bd331e137caa25262b507225854073fd319e987bd216bed911b7c"}, - {file = "rpds_py-0.19.1-cp312-none-win32.whl", hash = "sha256:3ddab996807c6b4227967fe1587febade4e48ac47bb0e2d3e7858bc621b1cace"}, - {file = "rpds_py-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:32e0db3d6e4f45601b58e4ac75c6f24afbf99818c647cc2066f3e4b192dabb1f"}, - {file = "rpds_py-0.19.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:747251e428406b05fc86fee3904ee19550c4d2d19258cef274e2151f31ae9d38"}, - {file = "rpds_py-0.19.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:dc733d35f861f8d78abfaf54035461e10423422999b360966bf1c443cbc42705"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbda75f245caecff8faa7e32ee94dfaa8312a3367397975527f29654cd17a6ed"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bd04d8cab16cab5b0a9ffc7d10f0779cf1120ab16c3925404428f74a0a43205a"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2d66eb41ffca6cc3c91d8387509d27ba73ad28371ef90255c50cb51f8953301"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:fdf4890cda3b59170009d012fca3294c00140e7f2abe1910e6a730809d0f3f9b"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1fa67ef839bad3815124f5f57e48cd50ff392f4911a9f3cf449d66fa3df62a5"}, - {file = "rpds_py-0.19.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b82c9514c6d74b89a370c4060bdb80d2299bc6857e462e4a215b4ef7aa7b090e"}, - {file = "rpds_py-0.19.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c7b07959866a6afb019abb9564d8a55046feb7a84506c74a6f197cbcdf8a208e"}, - {file = "rpds_py-0.19.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4f580ae79d0b861dfd912494ab9d477bea535bfb4756a2269130b6607a21802e"}, - {file = "rpds_py-0.19.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c6d20c8896c00775e6f62d8373aba32956aa0b850d02b5ec493f486c88e12859"}, - {file = "rpds_py-0.19.1-cp313-none-win32.whl", hash = "sha256:afedc35fe4b9e30ab240b208bb9dc8938cb4afe9187589e8d8d085e1aacb8309"}, - {file = "rpds_py-0.19.1-cp313-none-win_amd64.whl", hash = "sha256:1d4af2eb520d759f48f1073ad3caef997d1bfd910dc34e41261a595d3f038a94"}, - {file = "rpds_py-0.19.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:34bca66e2e3eabc8a19e9afe0d3e77789733c702c7c43cd008e953d5d1463fde"}, - {file = "rpds_py-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:24f8ae92c7fae7c28d0fae9b52829235df83f34847aa8160a47eb229d9666c7b"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71157f9db7f6bc6599a852852f3389343bea34315b4e6f109e5cbc97c1fb2963"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1d494887d40dc4dd0d5a71e9d07324e5c09c4383d93942d391727e7a40ff810b"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7b3661e6d4ba63a094138032c1356d557de5b3ea6fd3cca62a195f623e381c76"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97fbb77eaeb97591efdc654b8b5f3ccc066406ccfb3175b41382f221ecc216e8"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4cc4bc73e53af8e7a42c8fd7923bbe35babacfa7394ae9240b3430b5dcf16b2a"}, - {file = "rpds_py-0.19.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:35af5e4d5448fa179fd7fff0bba0fba51f876cd55212f96c8bbcecc5c684ae5c"}, - {file = "rpds_py-0.19.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3511f6baf8438326e351097cecd137eb45c5f019944fe0fd0ae2fea2fd26be39"}, - {file = "rpds_py-0.19.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:57863d16187995c10fe9cf911b897ed443ac68189179541734502353af33e693"}, - {file = "rpds_py-0.19.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9e318e6786b1e750a62f90c6f7fa8b542102bdcf97c7c4de2a48b50b61bd36ec"}, - {file = "rpds_py-0.19.1-cp38-none-win32.whl", hash = "sha256:53dbc35808c6faa2ce3e48571f8f74ef70802218554884787b86a30947842a14"}, - {file = "rpds_py-0.19.1-cp38-none-win_amd64.whl", hash = "sha256:8df1c283e57c9cb4d271fdc1875f4a58a143a2d1698eb0d6b7c0d7d5f49c53a1"}, - {file = "rpds_py-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e76c902d229a3aa9d5ceb813e1cbcc69bf5bda44c80d574ff1ac1fa3136dea71"}, - {file = "rpds_py-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:de1f7cd5b6b351e1afd7568bdab94934d656abe273d66cda0ceea43bbc02a0c2"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:24fc5a84777cb61692d17988989690d6f34f7f95968ac81398d67c0d0994a897"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:74129d5ffc4cde992d89d345f7f7d6758320e5d44a369d74d83493429dad2de5"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e360188b72f8080fefa3adfdcf3618604cc8173651c9754f189fece068d2a45"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13e6d4840897d4e4e6b2aa1443e3a8eca92b0402182aafc5f4ca1f5e24f9270a"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f09529d2332264a902688031a83c19de8fda5eb5881e44233286b9c9ec91856d"}, - {file = "rpds_py-0.19.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0d4b52811dcbc1aba08fd88d475f75b4f6db0984ba12275d9bed1a04b2cae9b5"}, - {file = "rpds_py-0.19.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dd635c2c4043222d80d80ca1ac4530a633102a9f2ad12252183bcf338c1b9474"}, - {file = "rpds_py-0.19.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f35b34a5184d5e0cc360b61664c1c06e866aab077b5a7c538a3e20c8fcdbf90b"}, - {file = "rpds_py-0.19.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d4ec0046facab83012d821b33cead742a35b54575c4edfb7ed7445f63441835f"}, - {file = "rpds_py-0.19.1-cp39-none-win32.whl", hash = "sha256:f5b8353ea1a4d7dfb59a7f45c04df66ecfd363bb5b35f33b11ea579111d4655f"}, - {file = "rpds_py-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:1fb93d3486f793d54a094e2bfd9cd97031f63fcb5bc18faeb3dd4b49a1c06523"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:7d5c7e32f3ee42f77d8ff1a10384b5cdcc2d37035e2e3320ded909aa192d32c3"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:89cc8921a4a5028d6dd388c399fcd2eef232e7040345af3d5b16c04b91cf3c7e"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bca34e913d27401bda2a6f390d0614049f5a95b3b11cd8eff80fe4ec340a1208"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5953391af1405f968eb5701ebbb577ebc5ced8d0041406f9052638bafe52209d"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:840e18c38098221ea6201f091fc5d4de6128961d2930fbbc96806fb43f69aec1"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6d8b735c4d162dc7d86a9cf3d717f14b6c73637a1f9cd57fe7e61002d9cb1972"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce757c7c90d35719b38fa3d4ca55654a76a40716ee299b0865f2de21c146801c"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a9421b23c85f361a133aa7c5e8ec757668f70343f4ed8fdb5a4a14abd5437244"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:3b823be829407393d84ee56dc849dbe3b31b6a326f388e171555b262e8456cc1"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:5e58b61dcbb483a442c6239c3836696b79f2cd8e7eec11e12155d3f6f2d886d1"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39d67896f7235b2c886fb1ee77b1491b77049dcef6fbf0f401e7b4cbed86bbd4"}, - {file = "rpds_py-0.19.1-pp310-pypy310_pp73-win_amd64.whl", hash = 
"sha256:8b32cd4ab6db50c875001ba4f5a6b30c0f42151aa1fbf9c2e7e3674893fb1dc4"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1c32e41de995f39b6b315d66c27dea3ef7f7c937c06caab4c6a79a5e09e2c415"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a129c02b42d46758c87faeea21a9f574e1c858b9f358b6dd0bbd71d17713175"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:346557f5b1d8fd9966059b7a748fd79ac59f5752cd0e9498d6a40e3ac1c1875f"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:31e450840f2f27699d014cfc8865cc747184286b26d945bcea6042bb6aa4d26e"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:01227f8b3e6c8961490d869aa65c99653df80d2f0a7fde8c64ebddab2b9b02fd"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69084fd29bfeff14816666c93a466e85414fe6b7d236cfc108a9c11afa6f7301"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d2b88efe65544a7d5121b0c3b003ebba92bfede2ea3577ce548b69c5235185"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ea961a674172ed2235d990d7edf85d15d8dfa23ab8575e48306371c070cda67"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:5beffdbe766cfe4fb04f30644d822a1080b5359df7db3a63d30fa928375b2720"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:720f3108fb1bfa32e51db58b832898372eb5891e8472a8093008010911e324c5"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:c2087dbb76a87ec2c619253e021e4fb20d1a72580feeaa6892b0b3d955175a71"}, - {file = "rpds_py-0.19.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ddd50f18ebc05ec29a0d9271e9dbe93997536da3546677f8ca00b76d477680c"}, - {file = "rpds_py-0.19.1.tar.gz", hash = "sha256:31dd5794837f00b46f4096aa8ccaa5972f73a938982e32ed817bb520c465e520"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = 
"sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = 
"sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file 
= "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, ] [[package]] name = "scipy" -version = "1.14.0" +version = "1.14.1" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" files = [ - {file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:07e179dc0205a50721022344fb85074f772eadbda1e1b3eecdc483f8033709b7"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a9c9a9b226d9a21e0a208bdb024c3982932e43811b62d202aaf1bb59af264b1"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:076c27284c768b84a45dcf2e914d4000aac537da74236a0d45d82c6fa4b7b3c0"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42470ea0195336df319741e230626b6225a740fd9dce9642ca13e98f667047c0"}, - {file = "scipy-1.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:176c6f0d0470a32f1b2efaf40c3d37a24876cebf447498a4cefb947a79c21e9d"}, - {file = "scipy-1.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad36af9626d27a4326c8e884917b7ec321d8a1841cd6dacc67d2a9e90c2f0359"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6d056a8709ccda6cf36cdd2eac597d13bc03dba38360f418560a93050c76a16e"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f0a50da861a7ec4573b7c716b2ebdcdf142b66b756a0d392c236ae568b3a93fb"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:94c164a9e2498e68308e6e148646e486d979f7fcdb8b4cf34b5441894bdb9caf"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a7d46c3e0aea5c064e734c3eac5cf9eb1f8c4ceee756262f2c7327c4c2691c86"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eee2989868e274aae26125345584254d97c56194c072ed96cb433f32f692ed8"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3154691b9f7ed73778d746da2df67a19d046a6c8087c8b385bc4cdb2cfca74"}, - {file = "scipy-1.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c40003d880f39c11c1edbae8144e3813904b10514cd3d3d00c277ae996488cdb"}, - {file = "scipy-1.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b083c8940028bb7e0b4172acafda6df762da1927b9091f9611b0bcd8676f2bc"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff2438ea1330e06e53c424893ec0072640dac00f29c6a43a575cbae4c99b2b9"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bbc0471b5f22c11c389075d091d3885693fd3f5e9a54ce051b46308bc787e5d4"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_arm64.whl", hash = 
"sha256:64b2ff514a98cf2bb734a9f90d32dc89dc6ad4a4a36a312cd0d6327170339eb0"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:7d3da42fbbbb860211a811782504f38ae7aaec9de8764a9bef6b262de7a2b50f"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d91db2c41dd6c20646af280355d41dfa1ec7eead235642178bd57635a3f82209"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a01cc03bcdc777c9da3cfdcc74b5a75caffb48a6c39c8450a9a05f82c4250a14"}, - {file = "scipy-1.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65df4da3c12a2bb9ad52b86b4dcf46813e869afb006e58be0f516bc370165159"}, - {file = "scipy-1.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:4c4161597c75043f7154238ef419c29a64ac4a7c889d588ea77690ac4d0d9b20"}, - {file = "scipy-1.14.0.tar.gz", hash = "sha256:b5923f48cb840380f9854339176ef21763118a7300a88203ccd0bdd26e58527b"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = 
"sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, ] [package.dependencies] @@ -2193,8 +2293,8 @@ numpy = ">=1.23.5,<2.3" [package.extras] dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] -doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] -test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "six" @@ -2295,6 +2395,26 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete 
(>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "treescope" +version = "0.1.5" +description = "Treescope: An interactive HTML pretty-printer for ML research in IPython notebooks." +optional = false +python-versions = ">=3.10" +files = [ + {file = "treescope-0.1.5-py3-none-any.whl", hash = "sha256:1b07446291212ceafd28e27a26fabcd3299e056307d35f949241991bf190dd65"}, + {file = "treescope-0.1.5.tar.gz", hash = "sha256:75cff86663dfcac133b922fcc0ee9520f605dc9cbedcb34eb86a80476acbbb32"}, +] + +[package.dependencies] +numpy = ">=1.25.2" + +[package.extras] +dev = ["ipython", "jupyter", "pyink (>=24.3.0)", "pylint (>=2.6.0)"] +docs = ["ipython", "ipython (>=8.8.0)", "jax[cpu] (>=0.4.23)", "matplotlib (>=3.5.0)", "myst-nb (>=1.0.0)", "myst-parser (>=3.0.1)", "palettable (==3.3.3)", "pandas (==2.2.2)", "penzai (>=0.2.0,<0.3.0)", "plotly (==5.22.0)", "sphinx (>=6.0.0,<7.3.0)", "sphinx-book-theme (>=1.0.1)", "sphinx-hoverxref", "sphinx_contributors", "sphinxcontrib-katex", "torch (==2.3.1)"] +notebook = ["ipython", "jax (>=0.4.23)", "palettable"] +test = ["absl-py (>=1.4.0)", "jax (>=0.4.23)", "pytest (>=8.2.2)", "torch (>=2.0.0)"] + [[package]] name = "typeguard" version = "2.13.3" @@ -2323,13 +2443,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.2" +version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, - {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, + {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, + {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] @@ -2454,13 +2574,13 @@ test = ["websockets"] [[package]] name = "widgetsnbextension" -version = "4.0.11" +version = "4.0.13" description = "Jupyter interactive widgets for Jupyter Notebook" optional = false python-versions = ">=3.7" files = [ - {file = "widgetsnbextension-4.0.11-py3-none-any.whl", hash = "sha256:55d4d6949d100e0d08b94948a42efc3ed6dfdc0e9468b2c4b128c9a2ce3a7a36"}, - {file = "widgetsnbextension-4.0.11.tar.gz", hash = "sha256:8b22a8f1910bfd188e596fe7fc05dcbd87e810c8a4ba010bdb3da86637398474"}, + {file = "widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71"}, + {file = "widgetsnbextension-4.0.13.tar.gz", hash = "sha256:ffcb67bc9febd10234a362795f643927f4e0c05d9342c727b65d2384f8feacb6"}, ] [[package]] @@ -2545,4 +2665,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "0e362c176cd02fbd79aadae2b21b277c01f188afe3426d0c8755d7d35f4015ee" +content-hash = "42925b13162b17dff6f9ac5a701991d53a8de8a7f653847c1b18ae6d3a480d1b" diff --git a/pyproject.toml b/pyproject.toml index 3b1f9cc..d53ba76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,8 @@ package-mode = false [tool.poetry.dependencies] python = ">=3.11,<3.13" jupytext = "^1.16.1" -genjax = {version = "0.5.0.post30.dev0+2df4f579" , source = "gcp" } -genstudio = {version = "2024.07.29.2335", source = "gcp"} +genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } +genstudio = {version = "2024.10.1", source = "gcp"} 
ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7" From 2b5a0e8ae0cf154b62cd2909e38c682f85048b8f Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Mon, 21 Oct 2024 09:59:17 -0700 Subject: [PATCH 34/86] upgrade genjax; start work on grid search --- .../probcomp-localization-tutorial.py | 239 ++++++++++++++++-- 1 file changed, 212 insertions(+), 27 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 1ce4d86..5a5a845 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -21,9 +21,9 @@ # if "google.colab" in sys.modules: # from google.colab import auth # pyright: ignore [reportMissingImports] -# auth.authenticate_user() -# %pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa -# %pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa + auth.authenticate_user() + %pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa + %pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa # %% [markdown] # # ProbComp Localization Tutorial # @@ -35,6 +35,10 @@ import json import genstudio.plot as Plot + + + + import itertools import jax import jax.numpy as jnp @@ -393,10 +397,9 @@ def integrate_controls_physical(robot_inputs): # ### Plot such data # %% def pose_plot(p, fill: str | Any = "black", **opts): - z = opts.get("zoom", 1.0) - r = z * 0.15 - wing_opacity = opts.get("opacity", 0.3) - WING_ANGLE, WING_LENGTH = jnp.pi / 12, z * opts.get("wing_length", 0.6) + r = opts.get('r', 0.5) + wing_opacity = opts.get('opacity', 0.3) + WING_ANGLE, WING_LENGTH = jnp.pi/12, opts.get('wing_length', 0.6) center = p.p angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1]) @@ -409,23 +412,23 @@ def pose_plot(p, fill: str | Any = "black", **opts): # Draw wings wings = Plot.line( [wing_ends[0], center, wing_ends[1]], - strokeWidth=opts.get("strokeWidth", 2), + strokeWidth=opts.get('strokeWidth', 2), stroke=fill, - opacity=wing_opacity, + opacity=wing_opacity ) # Draw center dot - dot = Plot.ellipse([center], fill=fill, **({"r": r} | opts)) + dot = Plot.ellipse([center], fill=fill, **opts) return wings + dot walls_plot = Plot.new( Plot.line( - world["wall_verts"], - strokeWidth=2, - stroke="#ccc", - ), + world["wall_verts"], + strokeWidth=2, + stroke="#ccc", + ), {"margin": 0, "inset": 50, "width": 500, "axis": None, "aspectRatio": 1}, Plot.domain([0, 20]), ) @@ -731,7 +734,7 @@ def generate_path(key: PRNGKey) -> Pose: key, sub_key = jax.random.split(key) sample_paths_v = jax.vmap(generate_path)(jax.random.split(sub_key, N_samples)) -Plot.Grid(*[walls_plot + poses_to_plots(path) for path in sample_paths_v]) +Plot.Grid([walls_plot + poses_to_plots(path) for path in sample_paths_v]) # %% # Animation showing a single path with confidence circles @@ -949,7 +952,6 @@ def sensor_model_one(pose, angle): @ "distance" ) - sensor_model = sensor_model_one.vmap(in_axes=(None, 0)) @@ -1665,14 +1667,197 @@ def localization_sis(motion_settings, observations): # %% # Try it in the low deviation setting key, sub_key = jax.random.split(key) -low_smc_result = localization_sis( - motion_settings_low_deviation, observations_low_deviation -).run(sub_key, 20) -( 
-    world_plot
-    + path_to_polyline(path_low_deviation, stroke="blue", strokeWidth=2)
-    + [
-        path_to_polyline(pose_list_to_plural_pose(p), opacity=0.1, stroke="green")
-        for p in low_smc_result.flood_fill()
-    ]
-)
+N_updates = 1000
+drift_traces, log_weights, _, _ = jax.vmap(gaussian_drift, in_axes=(0, None, None))(jax.random.split(sub_key, 1000), t0, motion_settings_high_deviation)
+
+# %% [markdown]
+# Let's select 10 of these by weight and see whether there's any improvement.
+# %%
+key, sub_key = jax.random.split(key)
+N_selection = 10
+selected_indices = jax.vmap(categorical_sampler, in_axes=(0, None))(jax.random.split(sub_key, N_selection), log_weights)
+selected_indices
+# %% [markdown]
+# Do you notice that many (or all) of the selected indices are repeats? This is because we are searching a probability space of high dimension: it's unlikely that there will be many traces producing a dramatic improvement. Even if there's only one, we'll write the plotting function for a selection of drifted traces; after that, we will fix the problem of repeated selections.
+# %%
+
+selected_traces = jax.tree.map(lambda v: v[selected_indices], drift_traces)
+
+def plot_traces(traces):
+    return (world_plot
+            + [animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="green") for path in jax.vmap(get_path)(traces)]
+            + poses_to_plots(path_high_deviation, fill=Plot.constantly("high deviation path"), opacity=0.2)
+            + Plot.color_map({"low deviation path": "green", "high deviation path": "blue", "integrated path": "black"}))
+
+plot_traces(selected_traces)
+
+# %% [markdown]
+# That looks promising, but there may only be one path in that output, since one of the drifted traces is probabilistically dominant. How can we get more candidate traces? We can use `vmap` *again*, to provide a fresh batch of drift samples for each desired trace. That will give us a weighted sample of potentially-improved traces to work with.
+
+# %%
+# Generate K drifted samples, by generating N importance samples for each K and making a weighted selection from each batch.
+def multi_drift(key, trace: genjax.Trace, scale, K: int, N: int):
+    k1, k2 = jax.random.split(key)
+    kn_samples, log_weights, _, _ = jax.vmap(gaussian_drift, in_axes=(0, None, None))(jax.random.split(k1, N*K), trace, scale)
+    batched_weights = log_weights.reshape((K, N))
+    winners = jax.vmap(categorical_sampler)(jax.random.split(k2, K), batched_weights)
+    # The winning indices are relative to the batch from which they were drawn. Reset the indices to linear form.
+    winners += jnp.arange(0, N*K, N)
+    return jax.tree.map(lambda v: v[winners], kn_samples)
+
+
+# %%
+key, sub_key = jax.random.split(key)
+drifted_traces = multi_drift(sub_key, t0, motion_settings_high_deviation, 20, 1000)
+plot_traces(drifted_traces)
+# %% [markdown]
+# We can see some improvement in the density of the paths selected. It's possible to imagine improving the search by repeating this drift process on all of the samples returned by the original importance sample. But we must face one important fact: we have used acceleration to improve what amounts to a brute-force search. The next inference step should take advantage of the information we have about the control steps, iteratively improving the path from the starting point, combining the control step and sensor data information to refine the selection of each step as it is made.
+
+# %%
+# Let's approach the problem step by step instead of trying to infer the whole path.
+# For each given pose, we will use the sensor data to propose a refinement.
+
+@genjax.gen
+def perturb_pose(pose: Pose, motion_settings):
+    d_p = jnp.array((
+        genjax.normal(0.0, motion_settings['p_noise']) @ 'd_x',
+        genjax.normal(0.0, motion_settings['p_noise']) @ 'd_y'
+    ))
+    d_hd = genjax.normal(0.0, motion_settings['hd_noise']) @ 'd_hd'
+    return Pose(pose.p + d_p, pose.hd + d_hd)
+
+@genjax.gen
+def perturb_model(pose: Pose, motion_settings):
+    p1 = perturb_pose(pose, motion_settings) @ 'pose'
+    _ = sensor_model(p1, sensor_angles) @ 'sensor'
+    return p1
+
+# %% [markdown]
+# To get started we'll work with the initial point, and then improve it. Once that's done,
+# we can chain together such improved moves to hopefully get a better inference of the
+# actual path.
+
+# %%
+key, sub_key = jax.random.split(key)
+p0 = start_pose_prior.simulate(sub_key, (robot_inputs['start'], motion_settings_low_deviation))
+key, sub_key = jax.random.split(key)
+tr_p0 = jax.vmap(perturb_model.simulate, in_axes=(0, None))(jax.random.split(sub_key, 100), (p0.get_retval(), motion_settings_low_deviation))
+# %%
+
+# %% [markdown]
+# Create a choicemap that will enforce the observations found at step $i$.
+
+def observations_to_choicemap(observations, i):
+    o_i = observations[i]
+    return C['sensor', jnp.arange(len(o_i)), 'distance'].set(o_i)
+# %% [markdown]
+# The first thing we'll try is a Boltzmann update: generate a cloud of nearby points
+# using the generative function we wrote, and weightedly select a replacement from that.
+# First, let's generate the cloud and visualize it.
+# %%
+def boltzmann_sample(key: PRNGKey, N: int, pose: Pose, motion_settings, observations, i):
+    return jax.vmap(perturb_model.importance, in_axes=(0, None, None))(
+        jax.random.split(key, N),
+        observations_to_choicemap(observations, i),
+        (pose, motion_settings)
+    )
+
+def small_pose_plot(p: Pose, **opts):
+    """This variant of pose_plot is better when we're zoomed in on the vicinity of one pose.
+    TODO: consider scaling r and wing_length based on the size of the plot domain."""
+    opts = {'r': 0.001} | opts
+    return pose_plot(p, wing_length=0.006, **opts)
+
+def weighted_small_pose_plot(target, poses, ws):
+    lse_ws = jnp.log(jnp.sum(jnp.exp(ws)))
+    scaled_ws = jnp.exp(ws - lse_ws)
+    max_scaled_w: FloatArray = jnp.max(scaled_ws)
+    scaled_ws /= max_scaled_w
+    # the following hack "boosts" lower scores a bit, to give us more visibility into
+    # the density of the nearby cloud.
Aesthetically, I found too many points were + # invisible without some adjustment, since the score distribution is concentrated + # closely around 1.0 + scaled_ws = scaled_ws ** 0.3 + return (Plot.new([small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)] + + small_pose_plot(target, r = 0.003, fill='red') + + small_pose_plot(robot_inputs['start'], r=0.003,fill='green')) + + { + "color": {"type":"linear", "scheme":"Purples"}, + "height": 400, + "width": 400, + "aspectRatio": 1 + }) +key, sub_key = jax.random.split(key) +bs = boltzmann_sample(sub_key, 1000, p0.get_retval(), motion_settings_low_deviation, observations_low_deviation, 0) +weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]) + +# %% + +# %% +# def weighted_small_pose_plot_0(poses, ws): +# lse_ws = jnp.log(jnp.sum(jnp.exp(ws))) +# scaled_ws = jnp.exp(ws - lse_ws) +# max_scaled_w: FloatArray = jnp.max(scaled_ws) +# bias_factor = 1.05 # Set > 1 to lift things close to zero +# rescaled_ws = (1 - 1/bias_factor) + scaled_ws / (bias_factor * max_scaled_w) +# # use rescaled_ws as a color density +# print(f'range: {jnp.min(rescaled_ws), jnp.max(rescaled_ws)}') +# return (Plot.new([small_pose_plot(p, opacity=w) for p, w in zip(poses, rescaled_ws)] +# + small_pose_plot(p0.get_retval(), r = 0.003, fill='red') +# + small_pose_plot(robot_inputs['start'], r=0.003,fill='green')) +# + { +# "color": {"type":"linear", "scheme":"Purples"}, +# "height": 400, +# "width": 400, +# "aspectRatio": 1 +# }) + + + + + + +# (weighted_small_pose_plot(trs.get_retval(), ws) + Plot.color_map({"real start": "green", "perceived start": "red", "nearby candidate": "black", "selected candidate": "cyan"}) +# + {"height": 400, "width": 400, "aspectRatio": 1}) + +# %% +# Grid approach (using assess maybe?) +def grid_of_nearby_poses(p, size, n): + grid_ax = jnp.arange(-n, n+1) * size + n_ax = len(grid_ax) + grid = jnp.dstack(jnp.meshgrid(grid_ax, grid_ax)).reshape(n_ax * n_ax, -1) + return Pose(p.p + grid, jnp.repeat(p.hd, n_ax*n_ax)) + +# %% +grid_of_nearby_poses(p0.get_retval(), 0.01, 3) +# %% + +# def grid_plot(g): +# return (Plot.new([small_pose_plot(p) for p in g]) +# + small_pose_plot(p0.get_retval(), fill='red') +# + small_pose_plot(robot_inputs['start'], fill='green')) + +# grid_plot(grid_of_nearby_poses(p0.get_retval(), 0.015, 10)) +# %% + +pose_grid = grid_of_nearby_poses(p0.get_retval(), 0.008, 15) +@genjax.gen +def assess_model(p): + sensor_model(p, sensor_angles) @ 'sensor' + return p +# %% +key, sub_key = jax.random.split(key) +assess_scores, assess_retvals = jax.vmap(lambda k, p: assess_model.assess(k, (p,)), in_axes=(None, 0))(cm, pose_grid) +#assess_scores, assess_retvals = jax.vmap(lambda p: sensor_model.assess(cm, (p, sensor_angles)))(pose_grid) +# %% +#sensor_model.assess(cm, (pose_grid[0], sensor_angles)) +#sensor_model.simulate(sub_key, (pose_grid[0], sensor_angles)) +#sensor_model.importance(sub_key, observations_to_choicemap(observations_low_deviation, 0), (pose_grid[0], sensor_angles)) + +# Since the above calls work... +# I think this ought to work, but doesn't! 
TODO: find a minimal repro and file an issue +#sensor_model.assess(cm, (pose_grid[0], sensor_angles)) + +(weighted_small_pose_plot(p0.get_retval(), assess_retvals, assess_scores) & +(weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]))) +# %% From 32aaaaba621580e79351bb4c2e5de594e9671c85 Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Mon, 21 Oct 2024 18:39:49 -0700 Subject: [PATCH 35/86] four-way plot --- .../probcomp-localization-tutorial.py | 159 +++++++++++------- 1 file changed, 99 insertions(+), 60 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 5a5a845..e40a43e 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -1442,16 +1442,12 @@ def importance_sample( # %% - -def path_to_polyline(path, **options): - if len(path.p.shape) > 1: - x_coords = path.p[:, 0] - y_coords = path.p[:, 1] - return Plot.line({"x": x_coords, "y": y_coords}, {"curve": "linear", **options}) - else: - return Plot.dot([path.p], fill=options["stroke"], r=2, **options) - - +def animate_path_as_line(path, **options): + x_coords = path.p[:, 0] + y_coords = path.p[:, 1] + return Plot.line({"x": x_coords, "y": y_coords}, + {"curve": "catmull-rom", + **options}) # ( world_plot @@ -1742,23 +1738,20 @@ def perturb_model(pose: Pose, motion_settings): p0 = start_pose_prior.simulate(sub_key, (robot_inputs['start'], motion_settings_low_deviation)) key, sub_key = jax.random.split(key) tr_p0 = jax.vmap(perturb_model.simulate, in_axes=(0, None))(jax.random.split(sub_key, 100), (p0.get_retval(), motion_settings_low_deviation)) -# %% - # %% [markdown] -# Create a choicemap that will enforce the observations found at step $i$. +# Create a choicemap that will enforce the given sensor observation -def observations_to_choicemap(observations, i): - o_i = observations[i] - return C['sensor', jnp.arange(len(o_i)), 'distance'].set(o_i) +def observation_to_choicemap(observation): + return C['sensor', jnp.arange(len(observation)), 'distance'].set(observation) # %% [markdown] # The first thing we'll try is a Boltzmann update: generate a cloud of nearby points # using the generative function we wrote, and weightedly select a replacement from that. # First, let's generate the cloud and visualize it. 
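# %% [markdown]
# A quick note on the mechanics before we do: the "weighted selection" used throughout this
# section amounts to sampling an index from a categorical distribution whose logits are the
# unnormalized log weights. The cell below is a minimal, self-contained sketch of that idea
# in plain JAX, separate from the model code; the names `pick_by_log_weight`, `demo_log_ws`,
# and `demo_values` are illustrative only and appear nowhere else in this notebook.

# %%
import jax
import jax.numpy as jnp

def pick_by_log_weight(key, log_ws, values):
    # jax.random.categorical accepts unnormalized log-probabilities ("logits"),
    # so the weights need not be exponentiated or normalized first.
    idx = jax.random.categorical(key, log_ws)
    return values[idx]

demo_key = jax.random.PRNGKey(0)
demo_log_ws = jnp.array([-1.0, -0.5, -3.0])                    # middle candidate is most probable
demo_values = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])  # one row of values per candidate
pick_by_log_weight(demo_key, demo_log_ws, demo_values)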
# %% -def boltzmann_sample(key: PRNGKey, N: int, pose: Pose, motion_settings, observations, i): +def boltzmann_sample(key: PRNGKey, N: int, pose: Pose, motion_settings, observations): return jax.vmap(perturb_model.importance, in_axes=(0, None, None))( jax.random.split(key, N), - observations_to_choicemap(observations, i), + observation_to_choicemap(observations), (pose, motion_settings) ) @@ -1788,38 +1781,8 @@ def weighted_small_pose_plot(target, poses, ws): "aspectRatio": 1 }) key, sub_key = jax.random.split(key) -bs = boltzmann_sample(sub_key, 1000, p0.get_retval(), motion_settings_low_deviation, observations_low_deviation, 0) +bs = boltzmann_sample(sub_key, 1000, p0.get_retval(), motion_settings_low_deviation, observations_low_deviation[0]) weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]) - -# %% - -# %% -# def weighted_small_pose_plot_0(poses, ws): -# lse_ws = jnp.log(jnp.sum(jnp.exp(ws))) -# scaled_ws = jnp.exp(ws - lse_ws) -# max_scaled_w: FloatArray = jnp.max(scaled_ws) -# bias_factor = 1.05 # Set > 1 to lift things close to zero -# rescaled_ws = (1 - 1/bias_factor) + scaled_ws / (bias_factor * max_scaled_w) -# # use rescaled_ws as a color density -# print(f'range: {jnp.min(rescaled_ws), jnp.max(rescaled_ws)}') -# return (Plot.new([small_pose_plot(p, opacity=w) for p, w in zip(poses, rescaled_ws)] -# + small_pose_plot(p0.get_retval(), r = 0.003, fill='red') -# + small_pose_plot(robot_inputs['start'], r=0.003,fill='green')) -# + { -# "color": {"type":"linear", "scheme":"Purples"}, -# "height": 400, -# "width": 400, -# "aspectRatio": 1 -# }) - - - - - - -# (weighted_small_pose_plot(trs.get_retval(), ws) + Plot.color_map({"real start": "green", "perceived start": "red", "nearby candidate": "black", "selected candidate": "cyan"}) -# + {"height": 400, "width": 400, "aspectRatio": 1}) - # %% # Grid approach (using assess maybe?) def grid_of_nearby_poses(p, size, n): @@ -1831,15 +1794,6 @@ def grid_of_nearby_poses(p, size, n): # %% grid_of_nearby_poses(p0.get_retval(), 0.01, 3) # %% - -# def grid_plot(g): -# return (Plot.new([small_pose_plot(p) for p in g]) -# + small_pose_plot(p0.get_retval(), fill='red') -# + small_pose_plot(robot_inputs['start'], fill='green')) - -# grid_plot(grid_of_nearby_poses(p0.get_retval(), 0.015, 10)) -# %% - pose_grid = grid_of_nearby_poses(p0.get_retval(), 0.008, 15) @genjax.gen def assess_model(p): @@ -1847,7 +1801,7 @@ def assess_model(p): return p # %% key, sub_key = jax.random.split(key) -assess_scores, assess_retvals = jax.vmap(lambda k, p: assess_model.assess(k, (p,)), in_axes=(None, 0))(cm, pose_grid) +assess_scores, assess_retvals = jax.vmap(lambda k, p: assess_model.assess(k, (p,)), in_axes=(None, 0))(observation_to_choicemap(observations_low_deviation[0]), pose_grid) #assess_scores, assess_retvals = jax.vmap(lambda p: sensor_model.assess(cm, (p, sensor_angles)))(pose_grid) # %% #sensor_model.assess(cm, (pose_grid[0], sensor_angles)) @@ -1859,5 +1813,90 @@ def assess_model(p): #sensor_model.assess(cm, (pose_grid[0], sensor_angles)) (weighted_small_pose_plot(p0.get_retval(), assess_retvals, assess_scores) & -(weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]))) + weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1])) +# %% [markdown] +# Now let's try doing the whole path. We want to produce something that is ultimately +# scan-compatible, so it should have the form state -> update -> new_state. 
The state +# is obviously the pose; the update will include the sensor readings at the current +# position and the control input for the next step. + +def select_by_weight(key: PRNGKey, weights: FloatArray, things): + chosen = jax.random.categorical(key, weights) + return jax.tree.map(lambda v: v[chosen], things) + +def improved_path(key: PRNGKey, motion_settings, observations): + + def boltzmann_improver(k: PRNGKey, pose, observation): + k1, k2 = jax.random.split(k, 2) + trs, ws = boltzmann_sample(k1, 1000, pose, motion_settings, observation) + return select_by_weight(k2, ws, trs.get_retval()) + + def grid_search_improver(k: PRNGKey, pose, observation): + choicemap = observation_to_choicemap(observation) + nearby_poses = grid_of_nearby_poses(pose, 0.008, 15) + ws, retvals = jax.vmap(lambda p: assess_model.assess(choicemap, (p,)))(nearby_poses) + return select_by_weight(k, ws, nearby_poses) + + def improve_pose_and_step(state, update): + pose = state + observation, control, key = update + k1, k2 = jax.random.split(key) + # improve the step where we are + p1 = grid_search_improver(k1, pose, observation) + # run the step model to advance one step + p2 = step_model.simulate(k2, (p1, control, motion_settings)) + return (p2.get_retval(), p2.get_retval()) + + # We have one fewer control than step, since no step got us to the initial position. + # Our scan step starts at the initial step and applies a control input each time. + # To make things balance, we need to add a zero step to the end of the control input + # array, so that when we arrive at the final step, no more control input is given. + controls = robot_inputs['controls'] + Control(jnp.array([0]), jnp.array([0])) + n_steps = len(controls) + sub_keys = jax.random.split(key, n_steps + 1) + p0 = start_pose_prior.simulate(sub_keys[0], (robot_inputs['start'], motion_settings)).get_retval() + return jax.lax.scan(improve_pose_and_step, p0, ( + observations, + controls, + sub_keys[1:] + )) + + + # Generate initial point from prior + k1, k2 = jax.random.split(key, 2) + sp_tr = start_pose_prior.simulate(k1, (robot_inputs['start'], motion_settings)) + p = sp_tr.get_retval() +# %% +key, k1, k2 = jax.random.split(key, 3) +low_importance = select_by_weight(k1, low_weights, low_deviation_paths) +high_importance = select_by_weight(k2, high_weights, high_deviation_paths) + + +# %% +endpoint_low, improved_low = improved_path(jax.random.PRNGKey(0), motion_settings_low_deviation, observations_low_deviation) + +# %% + +def path_comparison_plot(improved, integrated, importance, true): + return (world_plot + + animate_path_as_line(improved, strokeWidth=2, stroke=Plot.constantly("improved")) + + animate_path_as_line(integrated, strokeWidth=2, stroke=Plot.constantly("integrated")) + + animate_path_as_line(importance, strokeWidth=2, stroke=Plot.constantly("importance")) + + animate_path_as_line(true, strokeWidth=2, stroke=Plot.constantly("true")) + + poses_to_plots(improved, fill=Plot.constantly("improved"), opacity=0.2) + + poses_to_plots(integrated, fill=Plot.constantly("integrated"), opacity=0.2) + + poses_to_plots(importance, fill=Plot.constantly("importance")) + + poses_to_plots(true, fill=Plot.constantly("true")) + + Plot.color_map({"integrated": "green", "improved": "blue", "true": "black", "importance": "red"})) + +# %% +path_comparison_plot(improved_low, path_integrated, low_importance, path_low_deviation) +# %% +endpoint_high, improved_high = improved_path(jax.random.PRNGKey(0), motion_settings_high_deviation, observations_high_deviation) 
+path_comparison_plot(improved_high, path_integrated, high_importance, path_high_deviation) + + + +# %% +robot_inputs['controls'][-1] # %% From 3b027ddab0ac35ac8ceca71a9c998085e0e0859c Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Tue, 22 Oct 2024 15:00:35 -0700 Subject: [PATCH 36/86] doc and graphics fixes --- .../probcomp-localization-tutorial.py | 159 ++++++++++-------- 1 file changed, 87 insertions(+), 72 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index e40a43e..7816558 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -1064,10 +1064,11 @@ def animate_path_and_sensors(path, readings, motion_settings, frame_key=None): def animate_full_trace(trace, frame_key=None): path = get_path(trace) readings = get_sensors(trace) - motion_settings = trace.get_args()[0] - return animate_path_and_sensors( - path, readings, motion_settings, frame_key=frame_key - ) + # since we use make_full_model to curry motion_settings around the scan combinator, + # that object will not be in the outer trace's argument list; but we can be a little + # crafty and find it at a lower level. + motion_settings = trace.get_subtrace(('initial',)).get_subtrace(('pose',)).get_args()[1] + return animate_path_and_sensors(path, readings, motion_settings, frame_key=frame_key) animate_full_trace(tr) @@ -1256,28 +1257,18 @@ def constraint_from_path(path): # %% Plot.Row( - *[ - ( - html("div.f3.b.tc", title) - | animate_full_trace(trace, frame_key="frame") - | html("span.tc", f"score: {score:,.2f}") - ) - for (title, trace, motion_settings, score) in [ - [ - "Low deviation", - trace_path_integrated_observations_low_deviation, - motion_settings_low_deviation, - w_low, - ], - [ - "High deviation", - trace_path_integrated_observations_high_deviation, - motion_settings_high_deviation, - w_high, - ], - ] - ] -) | Plot.Slider("frame", 0, T, fps=2) + *[(html("div.f3.b.tc", title) + | animate_full_trace(trace, frame_key="frame") + | html("span.tc", f"score: {score:,.2f}")) + for (title, trace, motion_settings, score) in + [["Low deviation", + trace_path_integrated_observations_low_deviation, + motion_settings_low_deviation, + w_low], + ["High deviation", + trace_path_integrated_observations_high_deviation, + motion_settings_high_deviation, + w_high]]]) | Plot.Slider("frame", 0, T, fps=2) # %% [markdown] # ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model... @@ -1446,7 +1437,7 @@ def animate_path_as_line(path, **options): x_coords = path.p[:, 0] y_coords = path.p[:, 1] return Plot.line({"x": x_coords, "y": y_coords}, - {"curve": "catmull-rom", + {"curve": "linear", **options}) # ( @@ -1783,26 +1774,53 @@ def weighted_small_pose_plot(target, poses, ws): key, sub_key = jax.random.split(key) bs = boltzmann_sample(sub_key, 1000, p0.get_retval(), motion_settings_low_deviation, observations_low_deviation[0]) weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]) -# %% -# Grid approach (using assess maybe?) 
-def grid_of_nearby_poses(p, size, n): - grid_ax = jnp.arange(-n, n+1) * size - n_ax = len(grid_ax) - grid = jnp.dstack(jnp.meshgrid(grid_ax, grid_ax)).reshape(n_ax * n_ax, -1) - return Pose(p.p + grid, jnp.repeat(p.hd, n_ax*n_ax)) - -# %% -grid_of_nearby_poses(p0.get_retval(), 0.01, 3) -# %% -pose_grid = grid_of_nearby_poses(p0.get_retval(), 0.008, 15) +# %% [markdown] +# Develop a function which will produce a grid of evenly spaced nearby poses given +# an initial pose. $n$ is the number of steps to take in each cardinal direction +# (up/down, left/right and changes in heading). For example, if you say $n = 2$, there +# will be a $5\times 5$ grid of positions with the original pose in the center, and 5 layers +# of this type, each with different heading deltas (including zero), for a total of +# $125 = 5^3$ alternate poses. +# %% +def grid_of_nearby_poses(p, n, motion_settings): + indices = jnp.arange(-n, n+1) + n_indices = len(indices) + grid_ax = indices * 2 * motion_settings['p_noise'] / n + grid = jnp.dstack(jnp.meshgrid(grid_ax, grid_ax)).reshape(n_indices * n_indices, -1) + # That's the position grid. We will now make a 1-d grid for the heading deltas, + # and then form the linear cartesian product. + headings = indices * 2 * motion_settings['hd_noise'] / n + return Pose(jnp.repeat(p.p + grid, n_indices, axis=0), jnp.tile(p.hd + headings, n_indices * n_indices)) +# %% +cube_step_size = 8 +pose_grid = grid_of_nearby_poses(p0.get_retval(), cube_step_size, motion_settings_low_deviation) @genjax.gen def assess_model(p): sensor_model(p, sensor_angles) @ 'sensor' return p # %% key, sub_key = jax.random.split(key) -assess_scores, assess_retvals = jax.vmap(lambda k, p: assess_model.assess(k, (p,)), in_axes=(None, 0))(observation_to_choicemap(observations_low_deviation[0]), pose_grid) +model_assess = jax.jit(assess_model.assess) +assess_scores, assess_retvals = jax.vmap(lambda k, p: model_assess(k, (p,)), in_axes=(None, 0))(observation_to_choicemap(observations_low_deviation[0]), pose_grid) #assess_scores, assess_retvals = jax.vmap(lambda p: sensor_model.assess(cm, (p, sensor_angles)))(pose_grid) +# %% +# Our grid of nearby poses is actually a cube when we take into consideration the +# heading deltas. In order to get a 2d density, we decide to flatten the cube by +# taking the "best" of the headings by score at each point. +def flatten_pose_cube(n, poses, scores): + d = 2 * n + 1 + pose_groups = poses.p.reshape((d, d*d, 2)) + heading_groups = poses.hd.reshape((d, d*d)) + score_groups = scores.reshape((d, d*d)) + # find the best score in each group + best = jnp.argmax(score_groups, axis=1) + # We want to select the best column from every row, so we need to + # explicitly enumerate the rows we want (using : would not have the + # same effect) + return (Pose(pose_groups[jnp.arange(len(pose_groups)), best], + heading_groups[jnp.arange(len(heading_groups)), best]), + score_groups[jnp.arange(len(score_groups)), best]) + # %% #sensor_model.assess(cm, (pose_grid[0], sensor_angles)) #sensor_model.simulate(sub_key, (pose_grid[0], sensor_angles)) @@ -1811,7 +1829,11 @@ def assess_model(p): # Since the above calls work... # I think this ought to work, but doesn't! TODO: find a minimal repro and file an issue #sensor_model.assess(cm, (pose_grid[0], sensor_angles)) - +# %% [markdown] +# Prepare a plot showing the density of nearby improvements available using the grid +# search and importance sampling techniques. 
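# %%
# Before preparing that plot, a quick standalone sanity check of the pose-cube
# bookkeeping described above, using only jax.numpy and none of this notebook's
# helpers: with n steps in each of x, y, and heading there are (2n + 1)**3 candidate
# poses, e.g. 125 = 5**3 for n = 2.
import jax.numpy as jnp

n = 2
offsets = jnp.arange(-n, n + 1)                      # 2n + 1 offsets per axis
xs, ys, hs = jnp.meshgrid(offsets, offsets, offsets) # full cartesian product
assert xs.size == (2 * n + 1) ** 3 == 125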
+# %% +assess_pose_plane, assess_score_plan = flatten_pose_cube(cube_step_size, assess_retvals, assess_scores) (weighted_small_pose_plot(p0.get_retval(), assess_retvals, assess_scores) & weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1])) # %% [markdown] @@ -1824,7 +1846,7 @@ def select_by_weight(key: PRNGKey, weights: FloatArray, things): chosen = jax.random.categorical(key, weights) return jax.tree.map(lambda v: v[chosen], things) -def improved_path(key: PRNGKey, motion_settings, observations): +def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray, mode: str): def boltzmann_improver(k: PRNGKey, pose, observation): k1, k2 = jax.random.split(k, 2) @@ -1833,7 +1855,7 @@ def boltzmann_improver(k: PRNGKey, pose, observation): def grid_search_improver(k: PRNGKey, pose, observation): choicemap = observation_to_choicemap(observation) - nearby_poses = grid_of_nearby_poses(pose, 0.008, 15) + nearby_poses = grid_of_nearby_poses(pose, 15, motion_settings) ws, retvals = jax.vmap(lambda p: assess_model.assess(choicemap, (p,)))(nearby_poses) return select_by_weight(k, ws, nearby_poses) @@ -1842,10 +1864,11 @@ def improve_pose_and_step(state, update): observation, control, key = update k1, k2 = jax.random.split(key) # improve the step where we are - p1 = grid_search_improver(k1, pose, observation) + improver = {"grid": grid_search_improver, "boltzmann": boltzmann_improver}[mode] + p1 = improver(k1, pose, observation) # run the step model to advance one step p2 = step_model.simulate(k2, (p1, control, motion_settings)) - return (p2.get_retval(), p2.get_retval()) + return (p2.get_retval(), p1) # We have one fewer control than step, since no step got us to the initial position. # Our scan step starts at the initial step and applies a control input each time. @@ -1860,43 +1883,35 @@ def improve_pose_and_step(state, update): controls, sub_keys[1:] )) - - - # Generate initial point from prior - k1, k2 = jax.random.split(key, 2) - sp_tr = start_pose_prior.simulate(k1, (robot_inputs['start'], motion_settings)) - p = sp_tr.get_retval() # %% +# Select an importance sample via weight in both the low and high deviation settings. 
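# `select_by_weight` passes the raw importance log-weights straight to
# `jax.random.categorical`, which treats its input as unnormalized log-probabilities,
# so no explicit normalization or exponentiation is needed before sampling.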
key, k1, k2 = jax.random.split(key, 3) low_importance = select_by_weight(k1, low_weights, low_deviation_paths) high_importance = select_by_weight(k2, high_weights, high_deviation_paths) - - # %% -endpoint_low, improved_low = improved_path(jax.random.PRNGKey(0), motion_settings_low_deviation, observations_low_deviation) - +key, sub_key = jax.random.split(key) +endpoint_low, improved_low = improved_path(sub_key, motion_settings_low_deviation, observations_low_deviation, "grid") # %% -def path_comparison_plot(improved, integrated, importance, true): - return (world_plot - + animate_path_as_line(improved, strokeWidth=2, stroke=Plot.constantly("improved")) - + animate_path_as_line(integrated, strokeWidth=2, stroke=Plot.constantly("integrated")) - + animate_path_as_line(importance, strokeWidth=2, stroke=Plot.constantly("importance")) - + animate_path_as_line(true, strokeWidth=2, stroke=Plot.constantly("true")) - + poses_to_plots(improved, fill=Plot.constantly("improved"), opacity=0.2) - + poses_to_plots(integrated, fill=Plot.constantly("integrated"), opacity=0.2) - + poses_to_plots(importance, fill=Plot.constantly("importance")) - + poses_to_plots(true, fill=Plot.constantly("true")) - + Plot.color_map({"integrated": "green", "improved": "blue", "true": "black", "importance": "red"})) +def path_comparison_plot(*plots): + types = ["improved", "integrated", "importance", "true"] + plot = world_plot + plot += [animate_path_as_line(p, strokeWidth=2, stroke=Plot.constantly(t)) for p, t in zip(plots, types)] + plot += [poses_to_plots(p, fill=Plot.constantly(t)) for p, t in zip(plots, types)] + return plot + Plot.color_map({"integrated": "green", "improved": "blue", "true": "black", "importance": "red"}) # %% path_comparison_plot(improved_low, path_integrated, low_importance, path_low_deviation) # %% -endpoint_high, improved_high = improved_path(jax.random.PRNGKey(0), motion_settings_high_deviation, observations_high_deviation) +key, sub_key = jax.random.split(key) +endpoint_high, improved_high = improved_path(sub_key, motion_settings_high_deviation, observations_high_deviation, "grid") path_comparison_plot(improved_high, path_integrated, high_importance, path_high_deviation) - - - -# %% -robot_inputs['controls'][-1] +# %% [markdown] +# To see how the grid search improves poses, we play back the grid-search path +# next to an importance sample path. You can see the grid search has a better fit +# of sensor data to wall position at a variety of time steps. # %% +Plot.Row( + animate_path_and_sensors(improved_high, observations_high_deviation, motion_settings_high_deviation, frame_key="frame"), + animate_path_and_sensors(high_importance, observations_high_deviation, motion_settings_high_deviation, frame_key="frame") +) | Plot.Slider("frame", 0, T, fps=2) From ebc205c02c2c666ef97445cef8a9104ebd78a3d7 Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Fri, 25 Oct 2024 10:20:58 -0700 Subject: [PATCH 37/86] wip --- .../probcomp-localization-tutorial.py | 307 +++++++++--------- 1 file changed, 159 insertions(+), 148 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 7816558..dcdd193 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -31,7 +31,7 @@ # # Dependencies are specified in pyproject.toml. 
# %% -# Global setup code +# Global setup codef import json import genstudio.plot as Plot @@ -482,16 +482,14 @@ def pose_plot(p, fill: str | Any = "black", **opts): # # We start with the two building blocks: the starting pose and individual steps of motion. # %% - -# TODO(colin,jay): Originally, we passed motion_settings['p_noise'] ** 2 to -# mv_normal_diag, but I think this squares the scale twice. TFP documenentation -# - https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalDiag -# states that: scale = diag(scale_diag); covariance = scale @ scale.T. The second -# equation will have the effect of squaring the individual diagonal scales. - +# @genjax.gen +# def start_pose_prior(motion_settings, start): +# p = genjax.mv_normal(start.p, motion_settings["p_noise"] ** 2 * jnp.eye(2)) @ "p" +# hd = genjax.normal(start.hd, motion_settings["hd_noise"]) @ "hd" +# return Pose(p, hd) @genjax.gen -def step_model(motion_settings, start, control): +def step_proposal(motion_settings, start, control): p = ( genjax.mv_normal_diag( start.p + control.ds * start.dp(), motion_settings["p_noise"] * jnp.ones(2) @@ -499,8 +497,12 @@ def step_model(motion_settings, start, control): @ "p" ) hd = genjax.normal(start.hd + control.dhd, motion_settings["hd_noise"]) @ "hd" + print(f"gonna return: {start.p, p, hd} -> {physical_step(start.p, p, hd)}") return physical_step(start.p, p, hd) +@genjax.gen +def start_pose_prior(motion_settings, start): + return step_proposal.inline(motion_settings, start, noop_control) # Set the motion settings default_motion_settings = {"p_noise": 0.5, "hd_noise": 2 * jnp.pi / 36.0} @@ -510,8 +512,8 @@ def step_model(motion_settings, start, control): # %% key = jax.random.PRNGKey(0) -step_model.simulate( - key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]) +start_pose_prior.simulate( + key, (default_motion_settings, robot_inputs["start"]) ).get_retval() # %% [markdown] @@ -535,20 +537,14 @@ def make_circle(p, r): # Generate N_samples of starting poses from the prior N_samples = 50 key, sub_key = jax.random.split(key) -pose_samples = jax.vmap(step_model.simulate, in_axes=(0, None))( +pose_samples = jax.vmap(step_proposal.simulate, in_axes=(0, None))( jax.random.split(sub_key, N_samples), (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) - -def pose_list_to_plural_pose(pl: list[Pose]) -> Pose: - return Pose(jnp.array([pose.p for pose in pl]), [pose.hd for pose in pl]) - - def poses_to_plots(poses: Iterable[Pose], **plot_opts): return [pose_plot(pose, **plot_opts) for pose in poses] - # Plot the world, starting pose samples, and 95% confidence region # Calculate the radius of the 95% confidence region def confidence_circle(pose: Pose, p_noise: float): @@ -558,7 +554,6 @@ def confidence_circle(pose: Pose, p_noise: float): r=2.5 * p_noise, ) + Plot.color_map({"95% confidence region": "rgba(255,0,0,0.25)"}) - ( world_plot + poses_to_plots([robot_inputs["start"]], fill=Plot.constantly("step from here")) @@ -580,9 +575,8 @@ def confidence_circle(pose: Pose, p_noise: float): # %% # `simulate` takes the GF plus a tuple of args to pass to it. 
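# The trace it returns bundles the sampled random choices, the model's return value,
# and the log-density of those choices; these are read back elsewhere in the notebook
# with `get_choices()`, `get_retval()`, and `get_score()` respectively.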
key, sub_key = jax.random.split(key) -trace = step_model.simulate( - sub_key, - (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), +trace = start_pose_prior.simulate( + sub_key, (default_motion_settings, robot_inputs["start"]) ) trace.get_choices() @@ -706,9 +700,36 @@ def confidence_circle(pose: Pose, p_noise: float): # (It is worth acknowledging two strange things in the code below: the use of the suffix `.accumulate()` in path_model and the use of that auxiliary function itself. # %% -path_model = ( - step_model.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan() +@genjax.gen +def path_model_start(robot_inputs, motion_settings): + return start_pose_prior(motion_settings, robot_inputs["start"]) @ ( + "initial", + "pose", + ) + +@genjax.gen +def path_model_step(motion_settings, previous_pose, control): + return step_proposal(motion_settings, previous_pose, control) @ ( + "steps", + "pose", + ) + +path_model = path_model_step.partial_apply(default_motion_settings).accumulate() + +# TODO(colin,huebert): talk about accumulate, what it does, and _why_ from the point of view of acceleration. This is the flow control modification we were hinting at above, and it constrains the step function to have the two-argument signature that it does, which is why we reached for `partial` in the first place. Emphasize that this small bit of preparation allows massively parallel execution on a GPU and so it's worth the hassle. + +key, sub_key1, sub_key2 = jax.random.split(key, 3) +initial_pose = path_model_start.simulate( + sub_key1, (robot_inputs, default_motion_settings) ) +step_proposal.simulate( + sub_key2, + ( + default_motion_settings, + initial_pose.get_retval(), + robot_inputs["controls"][0], + ), +).get_choices() # result[0] ~~ robot_inputs['start'] + control_step[0] (which is zero) + noise @@ -789,9 +810,8 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): # %% key, sub_key = jax.random.split(key) -trace = step_model.simulate( - sub_key, - (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), +trace = start_pose_prior.simulate( + sub_key, (default_motion_settings, robot_inputs["start"]) ) key, sub_key = jax.random.split(key) rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( @@ -998,30 +1018,27 @@ def noisy_sensor(pose): # # We fold the sensor model into the motion model to form a "full model", whose traces describe simulations of the entire robot situation as we have described it. 
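# %%
# As a plain-JAX caricature of that composition (no Gen tracing, illustrative names
# only): each scanned step consumes the previous state, applies a control, adds motion
# noise, and emits a noisy observation of the new state. The generative-function
# version used by the tutorial follows in the cells below.
import jax
import jax.numpy as jnp

def toy_kernel(state, inputs):
    control, key = inputs
    k_move, k_sense = jax.random.split(key)
    new_state = state + control + 0.1 * jax.random.normal(k_move, state.shape)
    observation = new_state + 0.05 * jax.random.normal(k_sense, state.shape)
    # scan carries the new state forward and stacks (state, observation) per step
    return new_state, (new_state, observation)

toy_key = jax.random.PRNGKey(0)
toy_controls = jnp.ones((5, 2))  # five unit displacements in the plane
_, (toy_states, toy_observations) = jax.lax.scan(
    toy_kernel, jnp.zeros(2), (toy_controls, jax.random.split(toy_key, 5))
)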
# %% - - @genjax.gen -def full_model_kernel(motion_settings, state, control): - pose = step_model(motion_settings, state, control) @ "pose" +def full_model_initial(motion_settings): + pose = start_pose_prior(motion_settings, robot_inputs["start"]) @ "pose" sensor_model(pose, sensor_angles) @ "sensor" return pose +@genjax.gen +def full_model_kernel(motion_settings, state, control): + pose = step_proposal(motion_settings, state, control) @ "pose" + sensor_model(pose, sensor_angles) @ "sensor" + return pose, pose @genjax.gen def full_model(motion_settings): - return ( - full_model_kernel.partial_apply(motion_settings) - .map(lambda r: (r, r)) - .scan()(robot_inputs["start"], robot_inputs["controls"]) - @ "steps" - ) - + initial = full_model_initial(motion_settings) @ "initial" + return full_model_kernel.partial_apply(motion_settings).scan(n=T)(initial, robot_inputs["controls"]) @ "steps" def get_path(trace): ps = trace.get_retval()[1] return ps - def get_sensors(trace): ch = trace.get_choices() return ch["steps", :, "sensor", :, "distance"] @@ -1064,10 +1081,7 @@ def animate_path_and_sensors(path, readings, motion_settings, frame_key=None): def animate_full_trace(trace, frame_key=None): path = get_path(trace) readings = get_sensors(trace) - # since we use make_full_model to curry motion_settings around the scan combinator, - # that object will not be in the outer trace's argument list; but we can be a little - # crafty and find it at a lower level. - motion_settings = trace.get_subtrace(('initial',)).get_subtrace(('pose',)).get_args()[1] + motion_settings = trace.get_args()[0] return animate_path_and_sensors(path, readings, motion_settings, frame_key=frame_key) @@ -1152,7 +1166,7 @@ def plt(readings): animate_bare_sensors(path_integrated, world_plot) # %% -world_plot + plot_sensors(robot_inputs["start"], observations_low_deviation[0]) +world_plot + plot_sensors(robot_inputs['start'], observations_low_deviation[0]) # %% [markdown] # It would seem that the fit is reasonable in low motion deviation, but really breaks down in high motion deviation. # @@ -1169,15 +1183,12 @@ def plt(readings): model_importance = jax.jit(full_model.importance) key, sub_key = jax.random.split(key) -sample, log_weight = model_importance( - sub_key, constraints_low_deviation, (motion_settings_low_deviation,) -) +sample, log_weight = model_importance(sub_key, constraints_low_deviation, (motion_settings_low_deviation,)) + animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% key, sub_key = jax.random.split(key) -sample, log_weight = model_importance( - sub_key, constraints_high_deviation, (motion_settings_high_deviation,) -) +sample, log_weight = model_importance(sub_key, constraints_high_deviation, (motion_settings_high_deviation,)) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% [markdown] # A trace resulting from a call to `importance` is structurally indistinguishable from one drawn from `simulate`. But there is a key situational difference: while `get_score` always returns the frequency with which `simulate` stochastically produces the trace, this value is **no longer equal to** the frequency with which the trace is stochastically produced by `importance`. This is both true in an obvious and less relevant sense, as well as true in a more subtle and extremely germane sense. @@ -1215,8 +1226,6 @@ def plt(readings): # In words, the data are incongruously unlikely for the integrated path. The (log) density of the measurement data, given the integrated path... 
# %% - - def constraint_from_path(path): c_ps = jax.vmap(lambda ix, p: C["steps", ix, "pose", "p"].set(p))( jnp.arange(T), path.p @@ -1237,17 +1246,9 @@ def constraint_from_path(path): ) key, sub_key = jax.random.split(key) -trace_path_integrated_observations_low_deviation, w_low = model_importance( - sub_key, - constraints_path_integrated_observations_low_deviation, - (motion_settings_low_deviation,), -) +trace_path_integrated_observations_low_deviation, w_low = model_importance(sub_key, constraints_path_integrated_observations_low_deviation, (motion_settings_low_deviation,)) key, sub_key = jax.random.split(key) -trace_path_integrated_observations_high_deviation, w_high = model_importance( - sub_key, - constraints_path_integrated_observations_high_deviation, - (motion_settings_high_deviation,), -) +trace_path_integrated_observations_high_deviation, w_high = model_importance(sub_key, constraints_path_integrated_observations_high_deviation, (motion_settings_high_deviation,)) w_low, w_high # TODO: Jay then does two projections to compare the log-weights of these two things, @@ -1279,21 +1280,9 @@ def constraint_from_path(path): key, sub_key = jax.random.split(key) -traces_generated_low_deviation, low_weights = jax.vmap( - model_importance, in_axes=(0, None, None) -)( - jax.random.split(sub_key, N_samples), - constraints_low_deviation, - (motion_settings_low_deviation,), -) +traces_generated_low_deviation, low_weights = jax.vmap(model_importance, in_axes=(0, None, None))(jax.random.split(sub_key, N_samples), constraints_low_deviation, (motion_settings_low_deviation,)) -traces_generated_high_deviation, high_weights = jax.vmap( - model_importance, in_axes=(0, None, None) -)( - jax.random.split(sub_key, N_samples), - constraints_high_deviation, - (motion_settings_high_deviation,), -) +traces_generated_high_deviation, high_weights = jax.vmap(model_importance, in_axes=(0, None, None))(jax.random.split(sub_key, N_samples), constraints_high_deviation, (motion_settings_high_deviation,)) # low_weights, high_weights # two histograms @@ -1399,20 +1388,10 @@ def constraint_from_path(path): # %% - -def importance_sample( - key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int -): - """Produce N importance samples of depth K from the model. That is, N times, we - generate K importance samples conditioned by the constraints, and categorically - select one of them.""" +def resample(key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int): key1, key2 = jax.random.split(key) - samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))( - jax.random.split(key1, N * K), constraints, (motion_settings,) - ) - winners = jax.vmap(genjax.categorical.sampler)( - jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) - ) + samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))(jax.random.split(key1, N*K), constraints, (motion_settings,)) + winners = jax.vmap(categorical_sampler)(jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))) # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. 
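    # Concretely, as a made-up illustration: with N = 3 and K = 2, log_weights is
    # reshaped to (2, 3) and each categorical draw returns an index in 0..2. If the
    # draws are [2, 0], adding jnp.arange(0, 6, 3) = [0, 3] turns them into [2, 3],
    # valid positions in the flat array of N * K = 6 importance samples.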
winners += jnp.arange(0, N * K, N) @@ -1423,13 +1402,9 @@ def importance_sample( jit_resample = jax.jit(importance_sample, static_argnums=(3, 4)) key, sub_key = jax.random.split(key) -low_posterior = jit_resample( - sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20 -) +low_posterior = resample(sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20) key, sub_key = jax.random.split(key) -high_posterior = jit_resample( - sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20 -) +high_posterior = resample(sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20) # %% @@ -1700,50 +1675,38 @@ def multi_drift(key, trace: genjax.Trace, scale, K: int, N: int): # %% [markdown] # We can see some improvement in the density of the paths selected. It's possible to imagine improving the search by repeating this drift process on all of the samples retured by the original importance sample. But we must face one important fact: we have used acceleration to improve what amounts to a brute-force search. The next inference step should take advantage of the information we have about the control steps, iteratively improving the path from the starting point, combining the control step and sensor data information to refine the selection of each step as it is made. -# %% -# Let's approach the problem step by step instead of trying to infer the whole path. -# For each given pose, we will use the sensor data to propose a refinement. - -@genjax.gen -def perturb_pose(pose: Pose, motion_settings): - d_p = jnp.array(( - genjax.normal(0.0, motion_settings['p_noise']) @ 'd_x', - genjax.normal(0.0, motion_settings['p_noise']) @ 'd_y' - )) - d_hd = genjax.normal(0.0, motion_settings['hd_noise']) @ 'd_hd' - return Pose(pose.p + d_p, pose.hd + d_hd) - -@genjax.gen -def perturb_model(pose: Pose, motion_settings): - p1 = perturb_pose(pose, motion_settings) @ 'pose' - _ = sensor_model(p1, sensor_angles) @ 'sensor' - return p1 - # %% [markdown] +# Let's approach the problem step by step instead of trying to infer the whole path. # To get started we'll work with the initial point, and then improve it. Once that's done, # we can chain together such improved moves to hopefully get a better inference of the # actual path. 
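# %% [markdown]
# Why weighted re-selection is a sensible improvement move, stated informally and with
# the same notation as the particle-filter discussion later on: if a candidate pose
# $\mathbf{z}'$ is proposed from the motion prior $p(\mathbf{z}' \mid \mathbf{z})$ while
# the sensor readings $\mathbf{y}$ are held fixed as constraints, the importance weight
# that comes back is $w = p(\mathbf{z}' \mid \mathbf{z})\, p(\mathbf{y} \mid \mathbf{z}') / p(\mathbf{z}' \mid \mathbf{z}) = p(\mathbf{y} \mid \mathbf{z}')$,
# so a categorical draw over these log-weights favors candidates whose predicted sensor
# readings best explain the observations.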
# %% key, sub_key = jax.random.split(key) -p0 = start_pose_prior.simulate(sub_key, (robot_inputs['start'], motion_settings_low_deviation)) +p0 = start_pose_prior.simulate(sub_key, (motion_settings_low_deviation, robot_inputs['start'])) key, sub_key = jax.random.split(key) -tr_p0 = jax.vmap(perturb_model.simulate, in_axes=(0, None))(jax.random.split(sub_key, 100), (p0.get_retval(), motion_settings_low_deviation)) +tr_p0 = jax.vmap(full_model_kernel.simulate, in_axes=(0, None))( + jax.random.split(sub_key, 100), + (motion_settings_low_deviation, p0.get_retval(), robot_inputs['controls'][0]) +) # %% [markdown] # Create a choicemap that will enforce the given sensor observation -def observation_to_choicemap(observation): - return C['sensor', jnp.arange(len(observation)), 'distance'].set(observation) +def observation_to_choicemap(observation, pose=None): + sensor_cm = C['sensor', jnp.arange(len(observation)), 'distance'].set(observation) + pose_cm = C['pose', 'p'].set(pose.p) + C['pose', 'hd'].set(pose.hd) if pose is not None else C.n() + return sensor_cm + pose_cm + # %% [markdown] # The first thing we'll try is a Boltzmann update: generate a cloud of nearby points # using the generative function we wrote, and weightedly select a replacement from that. # First, let's generate the cloud and visualize it. # %% -def boltzmann_sample(key: PRNGKey, N: int, pose: Pose, motion_settings, observations): - return jax.vmap(perturb_model.importance, in_axes=(0, None, None))( +def boltzmann_sample(key: PRNGKey, N: int, gf, observation): + return jax.vmap(gf.importance, in_axes=(0, None, None))( jax.random.split(key, N), - observation_to_choicemap(observations), - (pose, motion_settings) + observation_to_choicemap(observation), + () ) def small_pose_plot(p: Pose, **opts): @@ -1771,9 +1734,15 @@ def weighted_small_pose_plot(target, poses, ws): "width": 400, "aspectRatio": 1 }) +# %% [markdown] +# For the first step we use the full_model_initial generative function. Subsequent steps +# will use the full_model_kernel. In the case of the initial step, we have: +# - the true initial position of the robot in green +# - the robot's belief about its initial position in red +# - a cloud of possible updates conditioned on the sensor data in shades of purple key, sub_key = jax.random.split(key) -bs = boltzmann_sample(sub_key, 1000, p0.get_retval(), motion_settings_low_deviation, observations_low_deviation[0]) -weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1]) +bs = boltzmann_sample(sub_key, 1000, full_model_initial(motion_settings_low_deviation), observations_low_deviation[0]) +weighted_small_pose_plot(path_low_deviation[0], bs[0].get_retval(), bs[1]) # %% [markdown] # Develop a function which will produce a grid of evenly spaced nearby poses given # an initial pose. 
$n$ is the number of steps to take in each cardinal direction @@ -1793,16 +1762,15 @@ def grid_of_nearby_poses(p, n, motion_settings): return Pose(jnp.repeat(p.p + grid, n_indices, axis=0), jnp.tile(p.hd + headings, n_indices * n_indices)) # %% cube_step_size = 8 -pose_grid = grid_of_nearby_poses(p0.get_retval(), cube_step_size, motion_settings_low_deviation) -@genjax.gen -def assess_model(p): - sensor_model(p, sensor_angles) @ 'sensor' - return p +pose_grid = grid_of_nearby_poses(path_low_deviation[0], cube_step_size, motion_settings_low_deviation) # %% key, sub_key = jax.random.split(key) -model_assess = jax.jit(assess_model.assess) -assess_scores, assess_retvals = jax.vmap(lambda k, p: model_assess(k, (p,)), in_axes=(None, 0))(observation_to_choicemap(observations_low_deviation[0]), pose_grid) -#assess_scores, assess_retvals = jax.vmap(lambda p: sensor_model.assess(cm, (p, sensor_angles)))(pose_grid) +assess_scores, assess_retvals = jax.vmap( + lambda p: full_model_initial.assess( + observation_to_choicemap(observations_low_deviation[0], path_low_deviation[0]), + (motion_settings_low_deviation, p, robot_inputs['controls'][0]) + ))(pose_grid) + # %% # Our grid of nearby poses is actually a cube when we take into consideration the # heading deltas. In order to get a 2d density, we decide to flatten the cube by @@ -1821,21 +1789,13 @@ def flatten_pose_cube(n, poses, scores): heading_groups[jnp.arange(len(heading_groups)), best]), score_groups[jnp.arange(len(score_groups)), best]) -# %% -#sensor_model.assess(cm, (pose_grid[0], sensor_angles)) -#sensor_model.simulate(sub_key, (pose_grid[0], sensor_angles)) -#sensor_model.importance(sub_key, observations_to_choicemap(observations_low_deviation, 0), (pose_grid[0], sensor_angles)) - -# Since the above calls work... -# I think this ought to work, but doesn't! TODO: find a minimal repro and file an issue -#sensor_model.assess(cm, (pose_grid[0], sensor_angles)) # %% [markdown] # Prepare a plot showing the density of nearby improvements available using the grid # search and importance sampling techniques. # %% -assess_pose_plane, assess_score_plan = flatten_pose_cube(cube_step_size, assess_retvals, assess_scores) -(weighted_small_pose_plot(p0.get_retval(), assess_retvals, assess_scores) & - weighted_small_pose_plot(p0.get_retval(), bs[0].get_retval(), bs[1])) +assess_pose_plane, assess_score_plan = flatten_pose_cube(cube_step_size, assess_retvals[0], assess_scores) +(weighted_small_pose_plot(path_low_deviation[0], assess_retvals[0], assess_scores) & + weighted_small_pose_plot(path_low_deviation[0], bs[0].get_retval(), bs[1])) # %% [markdown] # Now let's try doing the whole path. We want to produce something that is ultimately # scan-compatible, so it should have the form state -> update -> new_state. The state @@ -1846,12 +1806,21 @@ def select_by_weight(key: PRNGKey, weights: FloatArray, things): chosen = jax.random.categorical(key, weights) return jax.tree.map(lambda v: v[chosen], things) +# Step 1. retire assess_model and use full_model_kernel in both bz and grid improvers. +# Step 2. add the [pose,weight] of `pose` to the vector sampled by select_by_weight in the bz case +# Step 3. How is the weight computed for `pose` ? 
+# what we have now + correction term +# pose.weight = full_model_kernel.assess(p, (cm,)) + def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray, mode: str): def boltzmann_improver(k: PRNGKey, pose, observation): k1, k2 = jax.random.split(k, 2) trs, ws = boltzmann_sample(k1, 1000, pose, motion_settings, observation) - return select_by_weight(k2, ws, trs.get_retval()) + return select_by_weight(k2, ws ++ [pose.weight], trs.get_retval() ++ [pose]) + # we need to have a possibility of rejecting the move and staying where we are + # that is proportional to the weight of the current position. + # def grid_search_improver(k: PRNGKey, pose, observation): choicemap = observation_to_choicemap(observation) @@ -1867,7 +1836,7 @@ def improve_pose_and_step(state, update): improver = {"grid": grid_search_improver, "boltzmann": boltzmann_improver}[mode] p1 = improver(k1, pose, observation) # run the step model to advance one step - p2 = step_model.simulate(k2, (p1, control, motion_settings)) + p2 = step_proposal.simulate(k2, (motion_settings, p1, control)) return (p2.get_retval(), p1) # We have one fewer control than step, since no step got us to the initial position. @@ -1877,10 +1846,10 @@ def improve_pose_and_step(state, update): controls = robot_inputs['controls'] + Control(jnp.array([0]), jnp.array([0])) n_steps = len(controls) sub_keys = jax.random.split(key, n_steps + 1) - p0 = start_pose_prior.simulate(sub_keys[0], (robot_inputs['start'], motion_settings)).get_retval() + p0 = start_pose_prior.simulate(sub_keys[0], (motion_settings, robot_inputs['start'])).get_retval() return jax.lax.scan(improve_pose_and_step, p0, ( - observations, - controls, + observations, # observation at time t + controls, # guides step from t to t+1 sub_keys[1:] )) # %% @@ -1915,3 +1884,45 @@ def path_comparison_plot(*plots): animate_path_and_sensors(improved_high, observations_high_deviation, motion_settings_high_deviation, frame_key="frame"), animate_path_and_sensors(high_importance, observations_high_deviation, motion_settings_high_deviation, frame_key="frame") ) | Plot.Slider("frame", 0, T, fps=2) +# %% +@genjax.gen +def f(params): + print(f'f({params})') + return genjax.normal(params['loc'], params['scale']) @ 'x' + +@genjax.gen +def g(params): + print(f'g({params})') + return f(params) @ 'f' + +tr = f.simulate(key, ({'loc': 5.0, 'scale': 0.01},)) +# %% +tr +# %% +key, sub_key = jax.random.split(key) +tr.update(key, C['f','y'].set(99.0) + C['f','x'].set(5.05)) +# %% + + + +@genjax.gen +def g(params, c, s): + dc = genjax.normal(c, params['s'] ** 2) @ 's' + return c + dc, None + +@genjax.gen +def f(params): + return g.partial_apply(params).scan(n=3)(0.0, jnp.array([0.1, 0.2, 0.3])) @ "steps" + +key, sub_key = jax.random.split(key) +args = ({'s': 10.0},) +tr = f.simulate(sub_key, args) # works fine +key, sub_key = jax.random.split(key) +f.importance(sub_key, C['steps',1,'s'].set(99.0), args) # works fine +tr.update(sub_key, C['steps',1,'s'].set(99.0), genjax.Diff.no_change(args)) + + + + + +# %% From a476b2ea20ecc2f02678752041a3481afd1acbf5 Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Wed, 6 Nov 2024 10:54:50 -0800 Subject: [PATCH 38/86] get grid improver working --- .../probcomp-localization-tutorial.py | 877 ++++++++---------- 1 file changed, 396 insertions(+), 481 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index dcdd193..b459680 100644 --- 
a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -21,9 +21,9 @@ # if "google.colab" in sys.modules: # from google.colab import auth # pyright: ignore [reportMissingImports] - auth.authenticate_user() - %pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa - %pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa +# auth.authenticate_user() +# %pip install --quiet keyring keyrings.google-artifactregistry-auth # type: ignore # noqa +# %pip install --quiet genjax==0.7.0 genstudio==2024.9.7 --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/ # type: ignore # noqa # %% [markdown] # # ProbComp Localization Tutorial # @@ -31,14 +31,10 @@ # # Dependencies are specified in pyproject.toml. # %% -# Global setup codef +# Global setup code import json import genstudio.plot as Plot - - - - import itertools import jax import jax.numpy as jnp @@ -239,6 +235,10 @@ def load_world(file_name): # %% + +noop_control = Control(jnp.array(0.0), jnp.array(0.0)) + + def integrate_controls_unphysical(robot_inputs): """ Integrates the controls to generate a path from the starting pose. @@ -258,7 +258,8 @@ def integrate_controls_unphysical(robot_inputs): pose.apply_control(control), ), robot_inputs["start"], - robot_inputs["controls"], + # Prepend a no-op control to include the first pose in the result + robot_inputs["controls"].prepend(noop_control), )[1] @@ -384,7 +385,7 @@ def integrate_controls_physical(robot_inputs): new_pose, ), robot_inputs["start"], - robot_inputs["controls"], + robot_inputs["controls"].prepend(noop_control), )[1] @@ -397,9 +398,9 @@ def integrate_controls_physical(robot_inputs): # ### Plot such data # %% def pose_plot(p, fill: str | Any = "black", **opts): - r = opts.get('r', 0.5) - wing_opacity = opts.get('opacity', 0.3) - WING_ANGLE, WING_LENGTH = jnp.pi/12, opts.get('wing_length', 0.6) + r = opts.get("r", 0.5) + wing_opacity = opts.get("opacity", 0.3) + WING_ANGLE, WING_LENGTH = jnp.pi / 12, opts.get("wing_length", 0.6) center = p.p angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1]) @@ -412,9 +413,9 @@ def pose_plot(p, fill: str | Any = "black", **opts): # Draw wings wings = Plot.line( [wing_ends[0], center, wing_ends[1]], - strokeWidth=opts.get('strokeWidth', 2), + strokeWidth=opts.get("strokeWidth", 2), stroke=fill, - opacity=wing_opacity + opacity=wing_opacity, ) # Draw center dot @@ -425,10 +426,10 @@ def pose_plot(p, fill: str | Any = "black", **opts): walls_plot = Plot.new( Plot.line( - world["wall_verts"], - strokeWidth=2, - stroke="#ccc", - ), + world["wall_verts"], + strokeWidth=2, + stroke="#ccc", + ), {"margin": 0, "inset": 50, "width": 500, "axis": None, "aspectRatio": 1}, Plot.domain([0, 20]), ) @@ -482,11 +483,13 @@ def pose_plot(p, fill: str | Any = "black", **opts): # # We start with the two building blocks: the starting pose and individual steps of motion. # %% -# @genjax.gen -# def start_pose_prior(motion_settings, start): -# p = genjax.mv_normal(start.p, motion_settings["p_noise"] ** 2 * jnp.eye(2)) @ "p" -# hd = genjax.normal(start.hd, motion_settings["hd_noise"]) @ "hd" -# return Pose(p, hd) + +# TODO(colin,jay): Originally, we passed motion_settings['p_noise'] ** 2 to +# mv_normal_diag, but I think this squares the scale twice. 
TFP documenentation +# - https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalDiag +# states that: scale = diag(scale_diag); covariance = scale @ scale.T. The second +# equation will have the effect of squaring the individual diagonal scales. + @genjax.gen def step_proposal(motion_settings, start, control): @@ -497,12 +500,12 @@ def step_proposal(motion_settings, start, control): @ "p" ) hd = genjax.normal(start.hd + control.dhd, motion_settings["hd_noise"]) @ "hd" - print(f"gonna return: {start.p, p, hd} -> {physical_step(start.p, p, hd)}") return physical_step(start.p, p, hd) -@genjax.gen -def start_pose_prior(motion_settings, start): - return step_proposal.inline(motion_settings, start, noop_control) + +# @genjax.gen +# def start_pose_prior(motion_settings, start): +# return step_proposal.inline(motion_settings, start, noop_control) # Set the motion settings default_motion_settings = {"p_noise": 0.5, "hd_noise": 2 * jnp.pi / 36.0} @@ -512,8 +515,8 @@ def start_pose_prior(motion_settings, start): # %% key = jax.random.PRNGKey(0) -start_pose_prior.simulate( - key, (default_motion_settings, robot_inputs["start"]) +step_proposal.simulate( + key, (default_motion_settings, robot_inputs["start"], noop_control) ).get_retval() # %% [markdown] @@ -542,6 +545,7 @@ def make_circle(p, r): (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) + def poses_to_plots(poses: Iterable[Pose], **plot_opts): return [pose_plot(pose, **plot_opts) for pose in poses] @@ -554,6 +558,7 @@ def confidence_circle(pose: Pose, p_noise: float): r=2.5 * p_noise, ) + Plot.color_map({"95% confidence region": "rgba(255,0,0,0.25)"}) + ( world_plot + poses_to_plots([robot_inputs["start"]], fill=Plot.constantly("step from here")) @@ -575,8 +580,8 @@ def confidence_circle(pose: Pose, p_noise: float): # %% # `simulate` takes the GF plus a tuple of args to pass to it. key, sub_key = jax.random.split(key) -trace = start_pose_prior.simulate( - sub_key, (default_motion_settings, robot_inputs["start"]) +trace = step_proposal.simulate( + sub_key, (default_motion_settings, robot_inputs["start"], noop_control) ) trace.get_choices() @@ -700,12 +705,13 @@ def confidence_circle(pose: Pose, p_noise: float): # (It is worth acknowledging two strange things in the code below: the use of the suffix `.accumulate()` in path_model and the use of that auxiliary function itself. # %% -@genjax.gen -def path_model_start(robot_inputs, motion_settings): - return start_pose_prior(motion_settings, robot_inputs["start"]) @ ( - "initial", - "pose", - ) +# @genjax.gen +# def path_model_start(robot_inputs, motion_settings): +# return start_pose_prior(motion_settingsqw, robot_inputs["start"]) @ ( +# "initial", +# "pose", +# ) + @genjax.gen def path_model_step(motion_settings, previous_pose, control): @@ -714,28 +720,17 @@ def path_model_step(motion_settings, previous_pose, control): "pose", ) -path_model = path_model_step.partial_apply(default_motion_settings).accumulate() -# TODO(colin,huebert): talk about accumulate, what it does, and _why_ from the point of view of acceleration. This is the flow control modification we were hinting at above, and it constrains the step function to have the two-argument signature that it does, which is why we reached for `partial` in the first place. Emphasize that this small bit of preparation allows massively parallel execution on a GPU and so it's worth the hassle. 
+path_model = path_model_step.partial_apply(default_motion_settings).accumulate() -key, sub_key1, sub_key2 = jax.random.split(key, 3) -initial_pose = path_model_start.simulate( - sub_key1, (robot_inputs, default_motion_settings) -) -step_proposal.simulate( - sub_key2, - ( - default_motion_settings, - initial_pose.get_retval(), - robot_inputs["controls"][0], - ), -).get_choices() # result[0] ~~ robot_inputs['start'] + control_step[0] (which is zero) + noise # %% def generate_path_trace(key: PRNGKey) -> genjax.Trace: - return path_model.simulate(key, (robot_inputs["start"], robot_inputs["controls"])) + return path_model.simulate( + key, (robot_inputs["start"], robot_inputs["controls"].prepend(noop_control)) + ) def path_from_trace(tr: genjax.Trace) -> Pose: @@ -755,7 +750,7 @@ def generate_path(key: PRNGKey) -> Pose: key, sub_key = jax.random.split(key) sample_paths_v = jax.vmap(generate_path)(jax.random.split(sub_key, N_samples)) -Plot.Grid([walls_plot + poses_to_plots(path) for path in sample_paths_v]) +Plot.Grid(*[walls_plot + poses_to_plots(path) for path in sample_paths_v]) # %% # Animation showing a single path with confidence circles @@ -810,8 +805,8 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): # %% key, sub_key = jax.random.split(key) -trace = start_pose_prior.simulate( - sub_key, (default_motion_settings, robot_inputs["start"]) +trace = step_proposal.simulate( + sub_key, (default_motion_settings, robot_inputs["start"], noop_control) ) key, sub_key = jax.random.split(key) rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( @@ -972,6 +967,7 @@ def sensor_model_one(pose, angle): @ "distance" ) + sensor_model = sensor_model_one.vmap(in_axes=(None, 0)) @@ -1018,11 +1014,12 @@ def noisy_sensor(pose): # # We fold the sensor model into the motion model to form a "full model", whose traces describe simulations of the entire robot situation as we have described it. 
# %% -@genjax.gen -def full_model_initial(motion_settings): - pose = start_pose_prior(motion_settings, robot_inputs["start"]) @ "pose" - sensor_model(pose, sensor_angles) @ "sensor" - return pose +# @genjax.gen +# def full_model_initial(motion_settings): +# pose = start_pose_prior(motion_settings, robot_inputs["start"]) @ "pose" +# sensor_model(pose, sensor_angles) @ "sensor" +# return pose + @genjax.gen def full_model_kernel(motion_settings, state, control): @@ -1030,18 +1027,31 @@ def full_model_kernel(motion_settings, state, control): sensor_model(pose, sensor_angles) @ "sensor" return pose, pose + @genjax.gen def full_model(motion_settings): - initial = full_model_initial(motion_settings) @ "initial" - return full_model_kernel.partial_apply(motion_settings).scan(n=T)(initial, robot_inputs["controls"]) @ "steps" + return ( + full_model_kernel.partial_apply(motion_settings).scan()( + robot_inputs["start"], robot_inputs["controls"].prepend(noop_control) + ) + @ "steps" + ) + def get_path(trace): + # p = trace.get_subtrace(("initial",)).get_retval() ps = trace.get_retval()[1] + # return ps.prepend(p) return ps + def get_sensors(trace): ch = trace.get_choices() - return ch["steps", :, "sensor", :, "distance"] + # return jnp.concatenate(( + # ch["initial", "sensor", ..., "distance"][jnp.newaxis], + # ch["steps", ..., "sensor", ..., "distance"] + # )) + return ch["steps", ..., "sensor", ..., "distance"] key, sub_key = jax.random.split(key) @@ -1082,7 +1092,9 @@ def animate_full_trace(trace, frame_key=None): path = get_path(trace) readings = get_sensors(trace) motion_settings = trace.get_args()[0] - return animate_path_and_sensors(path, readings, motion_settings, frame_key=frame_key) + return animate_path_and_sensors( + path, readings, motion_settings, frame_key=frame_key + ) animate_full_trace(tr) @@ -1123,11 +1135,13 @@ def animate_full_trace(trace, frame_key=None): # Encode sensor readings into choice map. -def constraint_from_sensors(readings, t: int = T): - return C["steps", jnp.arange(t + 1), "sensor", :, "distance"].set(readings[: t + 1]) - # return jax.vmap( - # lambda v: C["steps", :, "sensor", :, "distance"].set(v) - # )(readings[:t]) +def constraint_from_sensors(readings): + angle_indices = jnp.arange(len(sensor_angles)) + return jax.vmap( + lambda ix, v: C["steps", ix, "sensor", angle_indices, "distance"].set(v) + )(jnp.arange(T), readings[1:]) + C[ + "initial", "sensor", angle_indices, "distance" + ].set(readings[0]) constraints_low_deviation = constraint_from_sensors(observations_low_deviation) @@ -1166,7 +1180,7 @@ def plt(readings): animate_bare_sensors(path_integrated, world_plot) # %% -world_plot + plot_sensors(robot_inputs['start'], observations_low_deviation[0]) +world_plot + plot_sensors(robot_inputs["start"], observations_low_deviation[0]) # %% [markdown] # It would seem that the fit is reasonable in low motion deviation, but really breaks down in high motion deviation. 
# @@ -1183,12 +1197,15 @@ def plt(readings): model_importance = jax.jit(full_model.importance) key, sub_key = jax.random.split(key) -sample, log_weight = model_importance(sub_key, constraints_low_deviation, (motion_settings_low_deviation,)) - +sample, log_weight = model_importance( + sub_key, constraints_low_deviation, (motion_settings_low_deviation,) +) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% key, sub_key = jax.random.split(key) -sample, log_weight = model_importance(sub_key, constraints_high_deviation, (motion_settings_high_deviation,)) +sample, log_weight = model_importance( + sub_key, constraints_high_deviation, (motion_settings_high_deviation,) +) animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") # %% [markdown] # A trace resulting from a call to `importance` is structurally indistinguishable from one drawn from `simulate`. But there is a key situational difference: while `get_score` always returns the frequency with which `simulate` stochastically produces the trace, this value is **no longer equal to** the frequency with which the trace is stochastically produced by `importance`. This is both true in an obvious and less relevant sense, as well as true in a more subtle and extremely germane sense. @@ -1226,14 +1243,22 @@ def plt(readings): # In words, the data are incongruously unlikely for the integrated path. The (log) density of the measurement data, given the integrated path... # %% + + +# TODO(colin): if we prepended the noop-control once and for all, we could set T = len(controls) +# and get rid of this excess arithmetic def constraint_from_path(path): c_ps = jax.vmap(lambda ix, p: C["steps", ix, "pose", "p"].set(p))( - jnp.arange(T), path.p + jnp.arange(T + 1), path.p ) c_hds = jax.vmap(lambda ix, hd: C["steps", ix, "pose", "hd"].set(hd))( - jnp.arange(T), path.hd + jnp.arange(T + 1), path.hd ) + + # c_p = C["initial", "pose", "p"].set(path.p[0]) + # c_hd = C["initial", "pose", "hd"].set(path.hd[0]) + return c_ps + c_hds # + c_p + c_hd @@ -1246,9 +1271,17 @@ def constraint_from_path(path): ) key, sub_key = jax.random.split(key) -trace_path_integrated_observations_low_deviation, w_low = model_importance(sub_key, constraints_path_integrated_observations_low_deviation, (motion_settings_low_deviation,)) +trace_path_integrated_observations_low_deviation, w_low = model_importance( + sub_key, + constraints_path_integrated_observations_low_deviation, + (motion_settings_low_deviation,), +) key, sub_key = jax.random.split(key) -trace_path_integrated_observations_high_deviation, w_high = model_importance(sub_key, constraints_path_integrated_observations_high_deviation, (motion_settings_high_deviation,)) +trace_path_integrated_observations_high_deviation, w_high = model_importance( + sub_key, + constraints_path_integrated_observations_high_deviation, + (motion_settings_high_deviation,), +) w_low, w_high # TODO: Jay then does two projections to compare the log-weights of these two things, @@ -1258,18 +1291,28 @@ def constraint_from_path(path): # %% Plot.Row( - *[(html("div.f3.b.tc", title) - | animate_full_trace(trace, frame_key="frame") - | html("span.tc", f"score: {score:,.2f}")) - for (title, trace, motion_settings, score) in - [["Low deviation", - trace_path_integrated_observations_low_deviation, - motion_settings_low_deviation, - w_low], - ["High deviation", - trace_path_integrated_observations_high_deviation, - motion_settings_high_deviation, - w_high]]]) | Plot.Slider("frame", 0, T, fps=2) + *[ + ( + html("div.f3.b.tc", 
title) + | animate_full_trace(trace, frame_key="frame") + | html("span.tc", f"score: {score:,.2f}") + ) + for (title, trace, motion_settings, score) in [ + [ + "Low deviation", + trace_path_integrated_observations_low_deviation, + motion_settings_low_deviation, + w_low, + ], + [ + "High deviation", + trace_path_integrated_observations_high_deviation, + motion_settings_high_deviation, + w_high, + ], + ] + ] +) | Plot.Slider("frame", 0, T, fps=2) # %% [markdown] # ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model... @@ -1280,9 +1323,21 @@ def constraint_from_path(path): key, sub_key = jax.random.split(key) -traces_generated_low_deviation, low_weights = jax.vmap(model_importance, in_axes=(0, None, None))(jax.random.split(sub_key, N_samples), constraints_low_deviation, (motion_settings_low_deviation,)) +traces_generated_low_deviation, low_weights = jax.vmap( + model_importance, in_axes=(0, None, None) +)( + jax.random.split(sub_key, N_samples), + constraints_low_deviation, + (motion_settings_low_deviation,), +) -traces_generated_high_deviation, high_weights = jax.vmap(model_importance, in_axes=(0, None, None))(jax.random.split(sub_key, N_samples), constraints_high_deviation, (motion_settings_high_deviation,)) +traces_generated_high_deviation, high_weights = jax.vmap( + model_importance, in_axes=(0, None, None) +)( + jax.random.split(sub_key, N_samples), + constraints_high_deviation, + (motion_settings_high_deviation,), +) # low_weights, high_weights # two histograms @@ -1388,10 +1443,17 @@ def constraint_from_path(path): # %% -def resample(key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int): + +def resample( + key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int +): key1, key2 = jax.random.split(key) - samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))(jax.random.split(key1, N*K), constraints, (motion_settings,)) - winners = jax.vmap(categorical_sampler)(jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))) + samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))( + jax.random.split(key1, N * K), constraints, (motion_settings,) + ) + winners = jax.vmap(categorical_sampler)( + jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) + ) # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. 
winners += jnp.arange(0, N * K, N) @@ -1399,30 +1461,33 @@ def resample(key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: in return selected -jit_resample = jax.jit(importance_sample, static_argnums=(3, 4)) - key, sub_key = jax.random.split(key) -low_posterior = resample(sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20) +low_posterior = resample( + sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20 +) key, sub_key = jax.random.split(key) -high_posterior = resample(sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20) +high_posterior = resample( + sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20 +) # %% + def animate_path_as_line(path, **options): x_coords = path.p[:, 0] y_coords = path.p[:, 1] - return Plot.line({"x": x_coords, "y": y_coords}, - {"curve": "linear", - **options}) + return Plot.line({"x": x_coords, "y": y_coords}, {"curve": "linear", **options}) + + # ( world_plot + [ - path_to_polyline(path, opacity=0.2, strokeWidth=2, stroke="green") + animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="green") for path in jax.vmap(get_path)(low_posterior) ] + [ - path_to_polyline(path, opacity=0.2, strokeWidth=2, stroke="blue") + animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="blue") for path in jax.vmap(get_path)(high_posterior) ] + poses_to_plots( @@ -1446,257 +1511,42 @@ def animate_path_as_line(path, **options): # Let's pause a moment to examine this chart. If the robot had no sensors, it would have no alternative but to estimate its position by integrating the control inputs to produce the integrated path in gray. In the low deviation setting, Gen has helped the robot to see that about halfway through its journey, noise in the control-effector relationship has caused the robot to deviate to the south slightly, and *the sensor data combined with importance sampling is enough* to give accurate results in the low deviation setting. # But in the high deviation setting, the loose nature of the paths in the blue posterior indicate that the robot has not discovered its true position by using importance sampling with the noisy sensor data. In the high deviation setting, more refined inference technique will be required. # -# Let's approach the problem step by step instead of trying to infer the whole path at once. -# The technique we will use is called Sequential Importance Sampling or a -# [Particle Filter](https://en.wikipedia.org/wiki/Particle_filter). It works like this. -# -# When we designed the step model for the robot, we arranged things so that the model -# could be used with `scan`: the model takes a *state* and a *control input* to produce -# a new *state*. Imagine at some time step $t$ that we use importance sampling with this -# model at a pose $\mathbf{z}_t$ and control input $\mathbf{u}_t$, scored with respect to the -# sensor observations $\mathbf{y}_t$ observed at that time. We will get a weighted collection -# of possible updated poses $\mathbf{z}_t^N$ and weights $w^N$. -# -# The particle filter "winnows" this set by replacing it with $N$ weighted selections -# *with replacement* from this collection. This may select better candidates several -# times, and is likely to drop poor candidates from the collection. We can arrange to -# to this at each time step with a little preparation: we start by "cloning" our idea -# of the robot's initial position into an N vector and this becomes the initial particle -# collection. 
At each step, we generate an importance sample and winnow it. -# -# This can also be done as a scan. Our previous attempt used `scan` to produce candidate -# paths from start to end, and these were scored for importance using all of the sensor -# readings at once. The results were better than guesses, but not accurate, in the -# high deviation setting. -# -# The technique we will use here discards steps with low likelihood at each step, and -# reinforces steps with high likelihood, allowing better particles to proportionately -# search more of the probability space while discarding unpromising particles. -# -# The following class attempts to generatlize this idea: - -# %% -StateT = TypeVar("StateT") -ControlT = TypeVar("ControlT") - - -class SequentialImportanceSampling(Generic[StateT, ControlT]): - """ - Given: - - a functional wrapper for the importance method of a generative function - - an initial state of type StateT, which should be a PyTree $z_0$ - - a vector of control inputs, also a PyTree $u_i, of shape $(T, \ldots)$ - - an array of observations $y_i$, also of shape $(T, \ldots)$ - perform the inference technique known as Sequential Importance Sampling. - - The signature of the GFI importance method is - key -> constraint -> args -> (trace, weight) - For importance sampling, this is vmapped over key to get - [keys] -> constraint -> args -> ([trace], [weight]) - The functional wrapper's purpose is to maneuver the state and control - inputs into whatever argument shape the underlying model is expecting, - and to turn the observation at step $t$ into a choicemap asserting - that constraint. - - After the object is constructed, SIS can be performed at any importance - depth with the `run` method, which will perform the following steps: - - - inflate the initial value to a vector of size N of identical initial - values - - vmap over N keys generated from the supplied key - - each vmap cell will scan over the control inputs and observations - - Between each step, categorical sampling with replacement is formed to - create a particle filter. Favorable importance draws are likely to - be replicated, and unfavorable ones discarded. The resampled vector of - states is sent the the next step, while the values drawn from the - importance sample and the indices chosen are emitted from teh scan step, - where, at the end of the process, they will be available as matrices - of shape (N, T). - """ - - def __init__( - self, - importance: Callable[ - [PRNGKey, StateT, ControlT, Array], tuple[genjax.Trace[StateT], float] - ], - init: StateT, - controls: ControlT, - observations: Array, - ): - self.importance = jax.jit(importance) - self.init = init - self.controls = controls - self.observations = observations - - class Result(Generic[StateT]): - """This object contains all of the information generated by the SIS scan, - and offers some convenient methods to reconstruct the paths explored - (`flood_fill`) or ultimately chosen (`backtrack`). 
- """ - - def __init__( - self, N: int, end: StateT, samples: genjax.Trace[StateT], indices: IntArray - ): - self.N = N - self.end = end - self.samples = samples - self.indices = indices - - def flood_fill(self) -> list[list[StateT]]: - samples = self.samples.get_retval() - active_paths = [[p] for p in samples[0]] - complete_paths = [] - for i in range(1, len(samples)): - indices = self.indices[i - 1] - counts = jnp.bincount(indices, length=self.N) - new_active_paths = self.N * [None] - for j in range(self.N): - if counts[j] == 0: - complete_paths.append(active_paths[j]) - new_active_paths[j] = active_paths[indices[j]] + [samples[i][j]] - active_paths = new_active_paths - - return complete_paths + active_paths - - def backtrack(self) -> list[list[StateT]]: - paths = [[p] for p in self.end] - samples = self.samples.get_retval() - for i in reversed(range(len(samples))): - for j in range(len(paths)): - paths[j].append(samples[i][self.indices[i][j].item()]) - for p in paths: - p.reverse() - return paths - - def run(self, key: PRNGKey, N: int) -> dict: - def step(state, update): - key, control, observation = update - ks = jax.random.split(key, (2, N)) - sample, log_weights = jax.vmap(self.importance, in_axes=(0, 0, None, None))( - ks[0], state, control, observation - ) - indices = jax.vmap(genjax.categorical.sampler, in_axes=(0, None))( - ks[1], log_weights - ) - resample = jax.tree.map(lambda v: v[indices], sample) - return resample.get_retval(), (sample, indices) - - init_array = jax.tree.map( - lambda a: jnp.broadcast_to(a, (N,) + a.shape), self.init - ) - end, (samples, indices) = jax.lax.scan( - step, - init_array, - ( - jax.random.split(key, len(self.controls)), - self.controls, - self.observations, - ), - ) - return SequentialImportanceSampling.Result(N, end, samples, indices) - - -# %% -def localization_sis(motion_settings, observations): - return SequentialImportanceSampling( - lambda key, pose, control, observation: full_model_kernel.importance( - key, - C["sensor", :, "distance"].set(observation), - (motion_settings, pose, control), - ), - robot_inputs["start"], - robot_inputs["controls"], - observations, - ) - - -# %% - -key, sub_key = jax.random.split(key) -smc_result = localization_sis( - motion_settings_high_deviation, observations_high_deviation -).run(sub_key, 100) - -( - world_plot - + path_to_polyline(path_high_deviation, stroke="blue", strokeWidth=2) - + [ - path_to_polyline(pose_list_to_plural_pose(p), opacity=0.1, stroke="green") - for p in smc_result.flood_fill() - ] -) -# %% -# Try it in the low deviation setting -key, sub_key = jax.random.split(key) -N_updates = 1000 -drift_traces, log_weights, _, _ = jax.vmap(gaussian_drift, in_axes=(0, None, None))(jax.random.split(sub_key, 1000), t0, motion_settings_high_deviation) - -# %% [markdown] -# Let's weightedly-select 10 from among those and see if there's any improvement -# %% -key, sub_key = jax.random.split(key) -N_selection = 10 -selected_indices = jax.vmap(categorical_sampler, in_axes=(0, None))(jax.random.split(sub_key, N_selection), log_weights) -selected_indices -# %% [markdown] -# Do you notice that many (or all) the selected indices are repeats? This is because we are searching a probability space of high dimension: it's unlikely that there will be many traces producing a dramatic improvement. Even if there's only one, we'll write the plotting function for a selection of drifted traces: after that, we will fix the problem of repeated selections. 
-# %% - -selected_traces = jax.tree.map(lambda v: v[selected_indices], drift_traces) - -def plot_traces(traces): - return (world_plot - + [animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="green") for path in jax.vmap(get_path)(traces)] - + poses_to_plots(path_high_deviation, fill=Plot.constantly("high deviation path"), opacity=0.2) - + Plot.color_map({"low deviation path": "green", "high deviation path": "blue", "integrated path": "black"})) +# Let's approach the problem step by step instead of trying to infer the whole path. +# To get started we'll work with the initial point, and then improve it. Once that's done, +# we can chain together such improved moves to hopefully get a better inference of the +# actual path. -plot_traces(selected_traces) +# One thing we'll need is a path to improve. We can select one of the importance samples we generated +# earlier. -# %% [markdown] -# That looks promising, but there may only be one path in that output, since one of the drifted traces is probabilistically dominant. How can we get more candidate traces? We can use `vmap` *again*, to provide a fresh batch of drift samples for each desired trace. That will give us a weighted sample of potentially-improved traces to work with. -# %% -# Generate K drifted samples, by generating N importance samples for each K and making a weighted selection from each batch. -def multi_drift(key, trace: genjax.Trace, scale, K: int, N: int): - k1, k2 = jax.random.split(key) - kn_samples, log_weights, _, _ = jax.vmap(gaussian_drift, in_axes=(0, None, None))(jax.random.split(k1, N*K), trace, scale) - batched_weights = log_weights.reshape((K, N)) - winners = jax.vmap(categorical_sampler)(jax.random.split(k2, K), batched_weights) - # The winning indices are relative to the batch from which they were drawn. Reset the indices to linear form. - winners += jnp.arange(0, N*K, N) - return jax.tree.map(lambda v: v[winners], kn_samples) +def select_by_weight(key: PRNGKey, weights: FloatArray, things): + chosen = jax.random.categorical(key, weights) + return jax.tree.map(lambda v: v[chosen], things) -# %% -key, sub_key = jax.random.split(key) -drifted_traces = multi_drift(sub_key, t0, motion_settings_high_deviation, 20, 1000) -plot_traces(drifted_traces) # %% [markdown] -# We can see some improvement in the density of the paths selected. It's possible to imagine improving the search by repeating this drift process on all of the samples retured by the original importance sample. But we must face one important fact: we have used acceleration to improve what amounts to a brute-force search. The next inference step should take advantage of the information we have about the control steps, iteratively improving the path from the starting point, combining the control step and sensor data information to refine the selection of each step as it is made. +# Select an importance sample by weight in both the low and high deviation settings. It will be handy +# to have one path to work with to test our improvements. -# %% [markdown] -# Let's approach the problem step by step instead of trying to infer the whole path. -# To get started we'll work with the initial point, and then improve it. Once that's done, -# we can chain together such improved moves to hopefully get a better inference of the -# actual path. 
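An aside on `select_by_weight`: `jax.random.categorical` expects *unnormalized log*-probabilities, which is exactly what the importance weights already are, so they can be passed straight through without exponentiating or normalizing. A minimal self-contained sketch with toy numbers (separate from the model code):

import jax
import jax.numpy as jnp

toy_log_weights = jnp.log(jnp.array([0.1, 0.6, 0.3]))  # any additive constant is irrelevant
draws = jax.vmap(lambda k: jax.random.categorical(k, toy_log_weights))(
    jax.random.split(jax.random.PRNGKey(0), 1000)
)
jnp.bincount(draws, length=3)  # counts come out roughly proportional to 1 : 6 : 3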
+key, k1, k2 = jax.random.split(key, 3) +low_deviation_path = select_by_weight(k1, low_weights, low_deviation_paths) +high_deviation_path = select_by_weight(k2, high_weights, high_deviation_paths) -# %% -key, sub_key = jax.random.split(key) -p0 = start_pose_prior.simulate(sub_key, (motion_settings_low_deviation, robot_inputs['start'])) -key, sub_key = jax.random.split(key) -tr_p0 = jax.vmap(full_model_kernel.simulate, in_axes=(0, None))( - jax.random.split(sub_key, 100), - (motion_settings_low_deviation, p0.get_retval(), robot_inputs['controls'][0]) -) # %% [markdown] # Create a choicemap that will enforce the given sensor observation + def observation_to_choicemap(observation, pose=None): - sensor_cm = C['sensor', jnp.arange(len(observation)), 'distance'].set(observation) - pose_cm = C['pose', 'p'].set(pose.p) + C['pose', 'hd'].set(pose.hd) if pose is not None else C.n() + sensor_cm = C["sensor", jnp.arange(len(observation)), "distance"].set(observation) + pose_cm = ( + C["pose", "p"].set(pose.p) + C["pose", "hd"].set(pose.hd) + if pose is not None + else C.n() + ) return sensor_cm + pose_cm + # %% [markdown] # The first thing we'll try is a Boltzmann update: generate a cloud of nearby points # using the generative function we wrote, and weightedly select a replacement from that. @@ -1704,18 +1554,18 @@ def observation_to_choicemap(observation, pose=None): # %% def boltzmann_sample(key: PRNGKey, N: int, gf, observation): return jax.vmap(gf.importance, in_axes=(0, None, None))( - jax.random.split(key, N), - observation_to_choicemap(observation), - () + jax.random.split(key, N), observation_to_choicemap(observation), () ) + def small_pose_plot(p: Pose, **opts): """This variant of pose_plot will is better when we're zoomed in on the vicinity of one pose. TODO: consider scaling r and wing_length based on the size of the plot domain.""" - opts = {'r': 0.001} | opts + opts = {"r": 0.001} | opts return pose_plot(p, wing_length=0.006, **opts) -def weighted_small_pose_plot(target, poses, ws): + +def weighted_small_pose_plot(proposal, truth, poses, ws): lse_ws = jnp.log(jnp.sum(jnp.exp(ws))) scaled_ws = jnp.exp(ws - lse_ws) max_scaled_w: FloatArray = jnp.max(scaled_ws) @@ -1724,16 +1574,19 @@ def weighted_small_pose_plot(target, poses, ws): # the density of the nearby cloud. Aesthetically, I found too many points were # invisible without some adjustment, since the score distribution is concentrated # closely around 1.0 - scaled_ws = scaled_ws ** 0.3 - return (Plot.new([small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)] - + small_pose_plot(target, r = 0.003, fill='red') - + small_pose_plot(robot_inputs['start'], r=0.003,fill='green')) - + { - "color": {"type":"linear", "scheme":"Purples"}, - "height": 400, - "width": 400, - "aspectRatio": 1 - }) + scaled_ws = scaled_ws**0.3 + return Plot.new( + [small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)] + + small_pose_plot(proposal, r=0.003, fill="red") + + small_pose_plot(truth, r=0.003, fill="green") + ) + { + "color": {"type": "linear", "scheme": "OrRd"}, + "height": 400, + "width": 400, + "aspectRatio": 1, + } + + # %% [markdown] # For the first step we use the full_model_initial generative function. Subsequent steps # will use the full_model_kernel. 
In the case of the initial step, we have: @@ -1741,8 +1594,20 @@ def weighted_small_pose_plot(target, poses, ws): # - the robot's belief about its initial position in red # - a cloud of possible updates conditioned on the sensor data in shades of purple key, sub_key = jax.random.split(key) -bs = boltzmann_sample(sub_key, 1000, full_model_initial(motion_settings_low_deviation), observations_low_deviation[0]) -weighted_small_pose_plot(path_low_deviation[0], bs[0].get_retval(), bs[1]) +bs = boltzmann_sample( + sub_key, + 1000, + full_model_kernel( + motion_settings_low_deviation, robot_inputs["start"], noop_control + ), + observations_low_deviation[0], +) +# %% +weighted_small_pose_plot( + path_low_deviation[0], robot_inputs["start"], bs[0].get_retval()[0], bs[1] +) + + # %% [markdown] # Develop a function which will produce a grid of evenly spaced nearby poses given # an initial pose. $n$ is the number of steps to take in each cardinal direction @@ -1752,177 +1617,227 @@ def weighted_small_pose_plot(target, poses, ws): # $125 = 5^3$ alternate poses. # %% def grid_of_nearby_poses(p, n, motion_settings): - indices = jnp.arange(-n, n+1) + indices = jnp.arange(-n, n + 1) n_indices = len(indices) - grid_ax = indices * 2 * motion_settings['p_noise'] / n - grid = jnp.dstack(jnp.meshgrid(grid_ax, grid_ax)).reshape(n_indices * n_indices, -1) - # That's the position grid. We will now make a 1-d grid for the heading deltas, - # and then form the linear cartesian product. - headings = indices * 2 * motion_settings['hd_noise'] / n - return Pose(jnp.repeat(p.p + grid, n_indices, axis=0), jnp.tile(p.hd + headings, n_indices * n_indices)) -# %% -cube_step_size = 8 -pose_grid = grid_of_nearby_poses(path_low_deviation[0], cube_step_size, motion_settings_low_deviation) + point_deltas = indices * 2 * motion_settings["p_noise"] / n + hd_deltas = indices * 2 * motion_settings["hd_noise"] / n + xs = jnp.repeat(point_deltas, n_indices) + ys = jnp.tile(point_deltas, n_indices) + points = jnp.repeat(jnp.column_stack((xs, ys)), n_indices, axis=0) + headings = jnp.tile(hd_deltas, n_indices * n_indices) + return Pose(p.p + points, p.hd + headings) + + # %% -key, sub_key = jax.random.split(key) -assess_scores, assess_retvals = jax.vmap( - lambda p: full_model_initial.assess( - observation_to_choicemap(observations_low_deviation[0], path_low_deviation[0]), - (motion_settings_low_deviation, p, robot_inputs['controls'][0]) - ))(pose_grid) + + +def grid_sample(gf, pose_grid, observations): + scores, _retvals = jax.vmap( + lambda pose: gf.assess(observation_to_choicemap(observations, pose), ()) + )(pose_grid) + return scores + # %% # Our grid of nearby poses is actually a cube when we take into consideration the -# heading deltas. In order to get a 2d density, we decide to flatten the cube by -# taking the "best" of the headings by score at each point. -def flatten_pose_cube(n, poses, scores): - d = 2 * n + 1 - pose_groups = poses.p.reshape((d, d*d, 2)) - heading_groups = poses.hd.reshape((d, d*d)) - score_groups = scores.reshape((d, d*d)) - # find the best score in each group - best = jnp.argmax(score_groups, axis=1) - # We want to select the best column from every row, so we need to - # explicitly enumerate the rows we want (using : would not have the - # same effect) - return (Pose(pose_groups[jnp.arange(len(pose_groups)), best], - heading_groups[jnp.arange(len(heading_groups)), best]), - score_groups[jnp.arange(len(score_groups)), best]) +# heading deltas. 
In order to get a 2d density to visualize, we flatten the cube by +# taking the "best" of the headings by score at each point. (Note: for the inference +# that follows, we will work with the full cube). +def flatten_pose_cube(pose_grid, cube_step_size, scores): + n_indices = 2 * cube_step_size + 1 + best_heading_indices = jnp.argmax( + scores.reshape(n_indices * n_indices, n_indices), axis=1 + ) + # those were block relative; linearize them by adding back block indices + bs = best_heading_indices + jnp.arange(0, n_indices**3, n_indices) + return Pose(pose_grid.p[bs], pose_grid.hd[bs]), scores[bs] + # %% [markdown] # Prepare a plot showing the density of nearby improvements available using the grid # search and importance sampling techniques. # %% -assess_pose_plane, assess_score_plan = flatten_pose_cube(cube_step_size, assess_retvals[0], assess_scores) -(weighted_small_pose_plot(path_low_deviation[0], assess_retvals[0], assess_scores) & - weighted_small_pose_plot(path_low_deviation[0], bs[0].get_retval(), bs[1])) +# Test our code for visualizing the Boltzmann and grid searches at the initial pose. +def initial_pose_chart(key): + cube_step_size = 6 + pose_grid = grid_of_nearby_poses( + path_low_deviation[0], cube_step_size, motion_settings_low_deviation + ) + score_grid = grid_sample( + full_model_kernel( + motion_settings_low_deviation, robot_inputs["start"], noop_control + ), + pose_grid, + observations_low_deviation[0], + ) + pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid) + return weighted_small_pose_plot( + path_low_deviation[0], robot_inputs["start"], pose_plane, score_plane + ) & weighted_small_pose_plot( + path_low_deviation[0], robot_inputs["start"], bs[0].get_retval()[0], bs[1] + ) + + +key, sub_key = jax.random.split(key) +initial_pose_chart(sub_key) + + +# %% [markdown] +# See if this works for other points in the path +def improvements_at_step(key, path, k): + gf = full_model_kernel( + motion_settings_low_deviation, path[k - 1], robot_inputs["controls"][k - 1] + ) + cube_step_size = 6 + bs = boltzmann_sample(k1, 500, gf, observations_low_deviation[k]) + print( + f'from {path[k-1]}, step {robot_inputs['controls'][k-1]}, truth {path_low_deviation[k]}, ps {bs[0].get_retval()[0]}' + ) + p1 = weighted_small_pose_plot( + path[k], path_low_deviation[k], bs[0].get_retval()[0], bs[1] + ) + pose_grid = grid_of_nearby_poses( + path[k], cube_step_size, motion_settings_low_deviation + ) + score_grid = grid_sample(gf, pose_grid, observations_low_deviation[k]) + pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid) + print(f"score_plane {score_plane}") + p2 = weighted_small_pose_plot( + path[k], path_low_deviation[k], pose_plane, score_plane + ) + return p1 & p2 + + +key, sub_key = jax.random.split(key) +improvements_at_step(sub_key, low_deviation_path, 5) + +# %% +# Animation of the above +key, *sub_keys = jax.random.split(key, 5) +Plot.Frames( + [ + improvements_at_step(k, low_deviation_path, i + 1) + for i, k in enumerate(sub_keys) + ], + fps=2, +) + # %% [markdown] # Now let's try doing the whole path. We want to produce something that is ultimately # scan-compatible, so it should have the form state -> update -> new_state. The state # is obviously the pose; the update will include the sensor readings at the current # position and the control input for the next step. 
-def select_by_weight(key: PRNGKey, weights: FloatArray, things): - chosen = jax.random.categorical(key, weights) - return jax.tree.map(lambda v: v[chosen], things) - # Step 1. retire assess_model and use full_model_kernel in both bz and grid improvers. # Step 2. add the [pose,weight] of `pose` to the vector sampled by select_by_weight in the bz case # Step 3. How is the weight computed for `pose` ? # what we have now + correction term # pose.weight = full_model_kernel.assess(p, (cm,)) -def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray, mode: str): - def boltzmann_improver(k: PRNGKey, pose, observation): +def improved_path( + mode: str, key: PRNGKey, motion_settings: dict, observations: FloatArray +): + def boltzmann_step(k: PRNGKey, gf, _center_pose, observation): k1, k2 = jax.random.split(k, 2) - trs, ws = boltzmann_sample(k1, 1000, pose, motion_settings, observation) - return select_by_weight(k2, ws ++ [pose.weight], trs.get_retval() ++ [pose]) - # we need to have a possibility of rejecting the move and staying where we are - # that is proportional to the weight of the current position. - # - - def grid_search_improver(k: PRNGKey, pose, observation): - choicemap = observation_to_choicemap(observation) - nearby_poses = grid_of_nearby_poses(pose, 15, motion_settings) - ws, retvals = jax.vmap(lambda p: assess_model.assess(choicemap, (p,)))(nearby_poses) - return select_by_weight(k, ws, nearby_poses) - - def improve_pose_and_step(state, update): - pose = state + trs, ws = boltzmann_sample(k1, 1000, gf, observation) + return ws, trs.get_retval()[0] + + def grid_search_step(k: PRNGKey, gf, center_pose, observation): + pose_grid = grid_of_nearby_poses(center_pose, 15, motion_settings) + nearby_weights = grid_sample(gf, pose_grid, observation) + return nearby_weights, pose_grid + + def improved_step(state, update): observation, control, key = update - k1, k2 = jax.random.split(key) - # improve the step where we are - improver = {"grid": grid_search_improver, "boltzmann": boltzmann_improver}[mode] - p1 = improver(k1, pose, observation) - # run the step model to advance one step - p2 = step_proposal.simulate(k2, (motion_settings, p1, control)) - return (p2.get_retval(), p1) + gf = full_model_kernel(motion_settings, state, control) + k1, k2, k3 = jax.random.split(key, 3) + # First, just run the model. + tr = gf.simulate(k1, ()) + new_pose = tr.get_retval()[0] + improver = {"grid": grid_search_step, "boltzmann": boltzmann_step}[mode] + # Run the improver, and add the candidate point to the list of weights and + # return values, to create the possibility of accepting the initial proposal + # as well as any of the improvement candidates, as Bayesian inference requires + weights, poses = improver(k2, gf, new_pose, observation) + #weights = jnp.append(weights, tr.get_score()) + #poses = Pose(jnp.vstack(poses.p, new_pose.p), jnp.append(poses.hd, new_pose.hd)) + chosen_pose = select_by_weight(k3, weights, poses) + return chosen_pose, chosen_pose # We have one fewer control than step, since no step got us to the initial position. # Our scan step starts at the initial step and applies a control input each time. # To make things balance, we need to add a zero step to the end of the control input # array, so that when we arrive at the final step, no more control input is given. 
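The off-by-one described in that comment is easiest to see with toy lengths: the real controls drive the steps between poses, so one extra no-op control is needed for the scan to emit one pose per time step (the hunk below moves that no-op to the front via `prepend`). A small sketch, independent of the model code:

import jax
import jax.numpy as jnp

toy_controls = jnp.array([1.0, 2.0, 3.0])                # three "real" control inputs
padded = jnp.concatenate([jnp.zeros(1), toy_controls])   # plus one no-op control

def toy_step(pose, control):
    new_pose = pose + control
    return new_pose, new_pose

_, toy_path = jax.lax.scan(toy_step, 0.0, padded)
# toy_path.shape == (4,): one emitted pose per scan step, including the no-op step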
- controls = robot_inputs['controls'] + Control(jnp.array([0]), jnp.array([0])) + controls = robot_inputs["controls"].prepend(noop_control) n_steps = len(controls) sub_keys = jax.random.split(key, n_steps + 1) - p0 = start_pose_prior.simulate(sub_keys[0], (motion_settings, robot_inputs['start'])).get_retval() - return jax.lax.scan(improve_pose_and_step, p0, ( - observations, # observation at time t - controls, # guides step from t to t+1 - sub_keys[1:] - )) -# %% -# Select an importance sample via weight in both the low and high deviation settings. -key, k1, k2 = jax.random.split(key, 3) -low_importance = select_by_weight(k1, low_weights, low_deviation_paths) -high_importance = select_by_weight(k2, high_weights, high_deviation_paths) + return jax.lax.scan( + improved_step, + robot_inputs["start"], + ( + observations, # observation at time t + controls, # guides step from t to t+1 + sub_keys[1:], + ), + ) + +jit_improved_path = jax.jit(improved_path, static_argnums=0) + # %% key, sub_key = jax.random.split(key) -endpoint_low, improved_low = improved_path(sub_key, motion_settings_low_deviation, observations_low_deviation, "grid") +_, improved_low = jit_improved_path( + "grid", sub_key, motion_settings_low_deviation, observations_low_deviation +) # %% + def path_comparison_plot(*plots): types = ["improved", "integrated", "importance", "true"] plot = world_plot - plot += [animate_path_as_line(p, strokeWidth=2, stroke=Plot.constantly(t)) for p, t in zip(plots, types)] + plot += [ + animate_path_as_line(p, strokeWidth=2, stroke=Plot.constantly(t)) + for p, t in zip(plots, types) + ] plot += [poses_to_plots(p, fill=Plot.constantly(t)) for p, t in zip(plots, types)] - return plot + Plot.color_map({"integrated": "green", "improved": "blue", "true": "black", "importance": "red"}) + return plot + Plot.color_map( + { + "integrated": "green", + "improved": "blue", + "true": "black", + "importance": "red", + } + ) + # %% -path_comparison_plot(improved_low, path_integrated, low_importance, path_low_deviation) +path_comparison_plot( + improved_low, path_integrated, low_deviation_path, path_low_deviation +) # %% key, sub_key = jax.random.split(key) -endpoint_high, improved_high = improved_path(sub_key, motion_settings_high_deviation, observations_high_deviation, "grid") -path_comparison_plot(improved_high, path_integrated, high_importance, path_high_deviation) +_, improved_high = jit_improved_path( + "boltzmann", sub_key, motion_settings_high_deviation, observations_high_deviation +) +path_comparison_plot( + improved_high, path_integrated, high_deviation_path, path_high_deviation +) # %% [markdown] # To see how the grid search improves poses, we play back the grid-search path # next to an importance sample path. You can see the grid search has a better fit # of sensor data to wall position at a variety of time steps. 
# %% Plot.Row( - animate_path_and_sensors(improved_high, observations_high_deviation, motion_settings_high_deviation, frame_key="frame"), - animate_path_and_sensors(high_importance, observations_high_deviation, motion_settings_high_deviation, frame_key="frame") + animate_path_and_sensors( + improved_high, + observations_high_deviation, + motion_settings_high_deviation, + frame_key="frame", + ), + animate_path_and_sensors( + high_deviation_path, + observations_high_deviation, + motion_settings_high_deviation, + frame_key="frame", + ), ) | Plot.Slider("frame", 0, T, fps=2) -# %% -@genjax.gen -def f(params): - print(f'f({params})') - return genjax.normal(params['loc'], params['scale']) @ 'x' - -@genjax.gen -def g(params): - print(f'g({params})') - return f(params) @ 'f' - -tr = f.simulate(key, ({'loc': 5.0, 'scale': 0.01},)) -# %% -tr -# %% -key, sub_key = jax.random.split(key) -tr.update(key, C['f','y'].set(99.0) + C['f','x'].set(5.05)) -# %% - - - -@genjax.gen -def g(params, c, s): - dc = genjax.normal(c, params['s'] ** 2) @ 's' - return c + dc, None - -@genjax.gen -def f(params): - return g.partial_apply(params).scan(n=3)(0.0, jnp.array([0.1, 0.2, 0.3])) @ "steps" - -key, sub_key = jax.random.split(key) -args = ({'s': 10.0},) -tr = f.simulate(sub_key, args) # works fine -key, sub_key = jax.random.split(key) -f.importance(sub_key, C['steps',1,'s'].set(99.0), args) # works fine -tr.update(sub_key, C['steps',1,'s'].set(99.0), genjax.Diff.no_change(args)) - - - - - -# %% From 0b3cb4cd7f1c58dbeb00281de1e3294da43a22b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:36:16 +0000 Subject: [PATCH 39/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../probcomp-localization-tutorial.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index b459680..238522c 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -1760,8 +1760,8 @@ def improved_step(state, update): # return values, to create the possibility of accepting the initial proposal # as well as any of the improvement candidates, as Bayesian inference requires weights, poses = improver(k2, gf, new_pose, observation) - #weights = jnp.append(weights, tr.get_score()) - #poses = Pose(jnp.vstack(poses.p, new_pose.p), jnp.append(poses.hd, new_pose.hd)) + # weights = jnp.append(weights, tr.get_score()) + # poses = Pose(jnp.vstack(poses.p, new_pose.p), jnp.append(poses.hd, new_pose.hd)) chosen_pose = select_by_weight(k3, weights, poses) return chosen_pose, chosen_pose @@ -1782,6 +1782,7 @@ def improved_step(state, update): ), ) + jit_improved_path = jax.jit(improved_path, static_argnums=0) # %% From cafe8a80bdd95371b992343087af2bd9022a7e3b Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Wed, 6 Nov 2024 12:53:31 -0800 Subject: [PATCH 40/86] - prepend `noop_contro`l to `robot_inputs['controls']` once and for all - take care of the conseqeunces of this throughout - switch from accumulate to scan with dimap (a little more verbose but consistency is improved) --- .../probcomp-localization-tutorial.py | 80 ++++++------------- 1 file changed, 25 insertions(+), 55 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py 
b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 238522c..ccb70b7 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -235,10 +235,6 @@ def load_world(file_name): # %% - -noop_control = Control(jnp.array(0.0), jnp.array(0.0)) - - def integrate_controls_unphysical(robot_inputs): """ Integrates the controls to generate a path from the starting pose. @@ -258,8 +254,7 @@ def integrate_controls_unphysical(robot_inputs): pose.apply_control(control), ), robot_inputs["start"], - # Prepend a no-op control to include the first pose in the result - robot_inputs["controls"].prepend(noop_control), + robot_inputs["controls"], )[1] @@ -385,7 +380,7 @@ def integrate_controls_physical(robot_inputs): new_pose, ), robot_inputs["start"], - robot_inputs["controls"].prepend(noop_control), + robot_inputs["controls"], )[1] @@ -503,10 +498,6 @@ def step_proposal(motion_settings, start, control): return physical_step(start.p, p, hd) -# @genjax.gen -# def start_pose_prior(motion_settings, start): -# return step_proposal.inline(motion_settings, start, noop_control) - # Set the motion settings default_motion_settings = {"p_noise": 0.5, "hd_noise": 2 * jnp.pi / 36.0} @@ -516,7 +507,7 @@ def step_proposal(motion_settings, start, control): # %% key = jax.random.PRNGKey(0) step_proposal.simulate( - key, (default_motion_settings, robot_inputs["start"], noop_control) + key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]) ).get_retval() # %% [markdown] @@ -581,7 +572,8 @@ def confidence_circle(pose: Pose, p_noise: float): # `simulate` takes the GF plus a tuple of args to pass to it. key, sub_key = jax.random.split(key) trace = step_proposal.simulate( - sub_key, (default_motion_settings, robot_inputs["start"], noop_control) + sub_key, + (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) trace.get_choices() @@ -705,32 +697,16 @@ def confidence_circle(pose: Pose, p_noise: float): # (It is worth acknowledging two strange things in the code below: the use of the suffix `.accumulate()` in path_model and the use of that auxiliary function itself. 
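The `.map(lambda r: (r, r)).scan()` form introduced in the hunks below follows the same convention as `jax.lax.scan`: the step returns a `(carry, output)` pair, and duplicating the new pose makes it both the carry for the next step and the value recorded for that time step. Here is a plain-JAX sketch of just that idea, with toy floats standing in for poses (this is not the GenJAX combinator API itself):

import jax
import jax.numpy as jnp

def toy_step(pose, control):
    new_pose = pose + control   # stand-in for one application of the step model
    return new_pose, new_pose   # same value as the next carry and as the per-step output

last_pose, all_poses = jax.lax.scan(toy_step, 0.0, jnp.array([1.0, 2.0, 3.0]))
# last_pose == 6.0 and all_poses == [1.0, 3.0, 6.0]: every intermediate pose is kept,
# which is what the earlier accumulate-based version of path_model also produced.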
# %% -# @genjax.gen -# def path_model_start(robot_inputs, motion_settings): -# return start_pose_prior(motion_settingsqw, robot_inputs["start"]) @ ( -# "initial", -# "pose", -# ) - - -@genjax.gen -def path_model_step(motion_settings, previous_pose, control): - return step_proposal(motion_settings, previous_pose, control) @ ( - "steps", - "pose", - ) - - -path_model = path_model_step.partial_apply(default_motion_settings).accumulate() +path_model = ( + step_proposal.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan() +) # result[0] ~~ robot_inputs['start'] + control_step[0] (which is zero) + noise # %% def generate_path_trace(key: PRNGKey) -> genjax.Trace: - return path_model.simulate( - key, (robot_inputs["start"], robot_inputs["controls"].prepend(noop_control)) - ) + return path_model.simulate(key, (robot_inputs["start"], robot_inputs["controls"])) def path_from_trace(tr: genjax.Trace) -> Pose: @@ -806,7 +782,8 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): key, sub_key = jax.random.split(key) trace = step_proposal.simulate( - sub_key, (default_motion_settings, robot_inputs["start"], noop_control) + sub_key, + (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) key, sub_key = jax.random.split(key) rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( @@ -1014,11 +991,6 @@ def noisy_sensor(pose): # # We fold the sensor model into the motion model to form a "full model", whose traces describe simulations of the entire robot situation as we have described it. # %% -# @genjax.gen -# def full_model_initial(motion_settings): -# pose = start_pose_prior(motion_settings, robot_inputs["start"]) @ "pose" -# sensor_model(pose, sensor_angles) @ "sensor" -# return pose @genjax.gen @@ -1032,7 +1004,7 @@ def full_model_kernel(motion_settings, state, control): def full_model(motion_settings): return ( full_model_kernel.partial_apply(motion_settings).scan()( - robot_inputs["start"], robot_inputs["controls"].prepend(noop_control) + robot_inputs["start"], robot_inputs["controls"] ) @ "steps" ) @@ -1041,7 +1013,6 @@ def full_model(motion_settings): def get_path(trace): # p = trace.get_subtrace(("initial",)).get_retval() ps = trace.get_retval()[1] - # return ps.prepend(p) return ps @@ -1139,9 +1110,9 @@ def constraint_from_sensors(readings): angle_indices = jnp.arange(len(sensor_angles)) return jax.vmap( lambda ix, v: C["steps", ix, "sensor", angle_indices, "distance"].set(v) - )(jnp.arange(T), readings[1:]) + C[ - "initial", "sensor", angle_indices, "distance" - ].set(readings[0]) + )(jnp.arange(T), readings) + C["initial", "sensor", angle_indices, "distance"].set( + readings[0] + ) constraints_low_deviation = constraint_from_sensors(observations_low_deviation) @@ -1245,20 +1216,14 @@ def plt(readings): # %% -# TODO(colin): if we prepended the noop-control once and for all, we could set T = len(controls) -# and get rid of this excess arithmetic def constraint_from_path(path): c_ps = jax.vmap(lambda ix, p: C["steps", ix, "pose", "p"].set(p))( - jnp.arange(T + 1), path.p + jnp.arange(T), path.p ) c_hds = jax.vmap(lambda ix, hd: C["steps", ix, "pose", "hd"].set(hd))( - jnp.arange(T + 1), path.hd + jnp.arange(T), path.hd ) - - # c_p = C["initial", "pose", "p"].set(path.p[0]) - # c_hd = C["initial", "pose", "hd"].set(path.hd[0]) - return c_ps + c_hds # + c_p + c_hd @@ -1598,7 +1563,9 @@ def weighted_small_pose_plot(proposal, truth, poses, ws): sub_key, 1000, full_model_kernel( - motion_settings_low_deviation, robot_inputs["start"], 
noop_control + motion_settings_low_deviation, + robot_inputs["start"], + robot_inputs["controls"][0], ), observations_low_deviation[0], ) @@ -1665,7 +1632,9 @@ def initial_pose_chart(key): ) score_grid = grid_sample( full_model_kernel( - motion_settings_low_deviation, robot_inputs["start"], noop_control + motion_settings_low_deviation, + robot_inputs["start"], + robot_inputs["controls"][0], ), pose_grid, observations_low_deviation[0], @@ -1769,7 +1738,7 @@ def improved_step(state, update): # Our scan step starts at the initial step and applies a control input each time. # To make things balance, we need to add a zero step to the end of the control input # array, so that when we arrive at the final step, no more control input is given. - controls = robot_inputs["controls"].prepend(noop_control) + controls = robot_inputs["controls"] n_steps = len(controls) sub_keys = jax.random.split(key, n_steps + 1) return jax.lax.scan( @@ -1842,3 +1811,4 @@ def path_comparison_plot(*plots): frame_key="frame", ), ) | Plot.Slider("frame", 0, T, fps=2) +# %% From 37b986f9019a6c20273851a85db5c9a5770d3eae Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Thu, 7 Nov 2024 08:38:51 -0800 Subject: [PATCH 41/86] - add some missing cell dividers back - refresh jupyter version --- .../probcomp-localization-tutorial.ipynb | 334 ++++++++++-------- .../probcomp-localization-tutorial.py | 9 +- 2 files changed, 190 insertions(+), 153 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb index 476fc5f..ba86eb2 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb @@ -488,10 +488,9 @@ "outputs": [], "source": [ "def pose_plot(p, fill: str | Any = \"black\", **opts):\n", - " z = opts.get(\"zoom\", 1.0)\n", - " r = z * 0.15\n", + " r = opts.get(\"r\", 0.5)\n", " wing_opacity = opts.get(\"opacity\", 0.3)\n", - " WING_ANGLE, WING_LENGTH = jnp.pi / 12, z * opts.get(\"wing_length\", 0.6)\n", + " WING_ANGLE, WING_LENGTH = jnp.pi / 12, opts.get(\"wing_length\", 0.6)\n", " center = p.p\n", " angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1])\n", "\n", @@ -510,7 +509,7 @@ " )\n", "\n", " # Draw center dot\n", - " dot = Plot.ellipse([center], fill=fill, **({\"r\": r} | opts))\n", + " dot = Plot.ellipse([center], fill=fill, **opts)\n", "\n", " return wings + dot\n", "\n", @@ -705,6 +704,8 @@ "# Plot the world, starting pose samples, and 95% confidence region\n", "# Calculate the radius of the 95% confidence region\n", "def confidence_circle(pose: Pose, p_noise: float):\n", + " # TODO\n", + " # should this also take into account the hd_noise?\n", " return Plot.scaled_circle(\n", " *pose.p,\n", " fill=Plot.constantly(\"95% confidence region\"),\n", @@ -1473,13 +1474,18 @@ "\n", "\n", "def get_path(trace):\n", + " # p = trace.get_subtrace((\"initial\",)).get_retval()\n", " ps = trace.get_retval()[1]\n", " return ps\n", "\n", "\n", "def get_sensors(trace):\n", " ch = trace.get_choices()\n", - " return ch[\"steps\", :, \"sensor\", :, \"distance\"]\n", + " # return jnp.concatenate((\n", + " # ch[\"initial\", \"sensor\", ..., \"distance\"][jnp.newaxis],\n", + " # ch[\"steps\", ..., \"sensor\", ..., \"distance\"]\n", + " # ))\n", + " return ch[\"steps\", ..., \"sensor\", ..., \"distance\"]\n", "\n", "\n", "key, sub_key = jax.random.split(key)\n", @@ -2117,6 +2123,9 @@ "metadata": {}, "outputs": [], "source": [ + "categorical_sampler = 
jax.jit(genjax.categorical.sampler)\n", + "\n", + "\n", "def resample(\n", " key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int\n", "):\n", @@ -2124,7 +2133,7 @@ " samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))(\n", " jax.random.split(key1, N * K), constraints, (motion_settings,)\n", " )\n", - " winners = jax.vmap(genjax.categorical.sampler)(\n", + " winners = jax.vmap(categorical_sampler)(\n", " jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))\n", " )\n", " # indices returned are relative to the start of the K-segment from which they were drawn.\n", @@ -2134,14 +2143,12 @@ " return selected\n", "\n", "\n", - "jit_resample = jax.jit(resample, static_argnums=(3, 4))\n", - "\n", "key, sub_key = jax.random.split(key)\n", - "low_posterior = jit_resample(\n", + "low_posterior = resample(\n", " sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20\n", ")\n", "key, sub_key = jax.random.split(key)\n", - "high_posterior = jit_resample(\n", + "high_posterior = resample(\n", " sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20\n", ")" ] @@ -2218,19 +2225,14 @@ "outputs": [], "source": [ "def select_by_weight(key: PRNGKey, weights: FloatArray, things):\n", - " \"\"\"Makes a categorical selection from the vector object `things`\n", - " weighted by `weights`. The selected object is returned (with its\n", - " outermost axis removed) with its weight.\"\"\"\n", " chosen = jax.random.categorical(key, weights)\n", - " return jax.tree.map(lambda v: v[chosen], things), weights[chosen]" + " return jax.tree.map(lambda v: v[chosen], things)" ] }, { "cell_type": "markdown", "id": "105", - "metadata": { - "lines_to_next_cell": 0 - }, + "metadata": {}, "source": [ "Select an importance sample by weight in both the low and high deviation settings. It will be handy\n", "to have one path to work with to test our improvements." @@ -2240,21 +2242,19 @@ "cell_type": "code", "execution_count": null, "id": "106", - "metadata": { - "lines_to_next_cell": 0 - }, + "metadata": {}, "outputs": [], "source": [ "key, k1, k2 = jax.random.split(key, 3)\n", - "low_deviation_path, _ = select_by_weight(k1, low_weights, low_deviation_paths)\n", - "high_deviation_path, _ = select_by_weight(k2, high_weights, high_deviation_paths)" + "low_deviation_path = select_by_weight(k1, low_weights, low_deviation_paths)\n", + "high_deviation_path = select_by_weight(k2, high_weights, high_deviation_paths)" ] }, { "cell_type": "markdown", "id": "107", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 2 }, "source": [ "Create a choicemap that will enforce the given sensor observation" @@ -2268,7 +2268,7 @@ "outputs": [], "source": [ "def observation_to_choicemap(observation, pose=None):\n", - " sensor_cm = C[\"sensor\", :, \"distance\"].set(observation)\n", + " sensor_cm = C[\"sensor\", jnp.arange(len(observation)), \"distance\"].set(observation)\n", " pose_cm = (\n", " C[\"pose\", \"p\"].set(pose.p) + C[\"pose\", \"hd\"].set(pose.hd)\n", " if pose is not None\n", @@ -2284,8 +2284,9 @@ "lines_to_next_cell": 0 }, "source": [ - "Let's visualize a cloud of possible poses by coloring the elements proportional to their\n", - "plausibility under the sensor readingss." + "The first thing we'll try is a Boltzmann update: generate a cloud of nearby points\n", + "using the generative function we wrote, and weightedly select a replacement from that.\n", + "First, let's generate the cloud and visualize it." 
] }, { @@ -2295,17 +2296,22 @@ "metadata": {}, "outputs": [], "source": [ - "def step_sample(key: PRNGKey, N: int, gf, observation):\n", - " tr, ws = jax.vmap(gf.importance, in_axes=(0, None, None))(\n", + "def boltzmann_sample(key: PRNGKey, N: int, gf, observation):\n", + " return jax.vmap(gf.importance, in_axes=(0, None, None))(\n", " jax.random.split(key, N), observation_to_choicemap(observation), ()\n", " )\n", - " return tr.get_retval()[0], ws\n", "\n", "\n", - "def weighted_small_pose_plot(proposal, truth, weights, poses, zoom=1):\n", - " max_logw = jnp.max(weights)\n", - " lse_ws = max_logw + jnp.log(jnp.sum(jnp.exp(weights - max_logw)))\n", - " scaled_ws = jnp.exp(weights - lse_ws)\n", + "def small_pose_plot(p: Pose, **opts):\n", + " \"\"\"This variant of pose_plot will is better when we're zoomed in on the vicinity of one pose.\n", + " TODO: consider scaling r and wing_length based on the size of the plot domain.\"\"\"\n", + " opts = {\"r\": 0.001} | opts\n", + " return pose_plot(p, wing_length=0.006, **opts)\n", + "\n", + "\n", + "def weighted_small_pose_plot(proposal, truth, poses, ws):\n", + " lse_ws = jnp.log(jnp.sum(jnp.exp(ws)))\n", + " scaled_ws = jnp.exp(ws - lse_ws)\n", " max_scaled_w: FloatArray = jnp.max(scaled_ws)\n", " scaled_ws /= max_scaled_w\n", " # the following hack \"boosts\" lower scores a bit, to give us more visibility into\n", @@ -2313,11 +2319,10 @@ " # invisible without some adjustment, since the score distribution is concentrated\n", " # closely around 1.0\n", " scaled_ws = scaled_ws**0.3\n", - " z = 0.03 * zoom\n", " return Plot.new(\n", - " [pose_plot(p, fill=w, zoom=z) for p, w in zip(poses, scaled_ws)]\n", - " + pose_plot(proposal, fill=\"red\", zoom=z)\n", - " + pose_plot(truth, fill=\"green\", zoom=z)\n", + " [small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)]\n", + " + small_pose_plot(proposal, r=0.003, fill=\"red\")\n", + " + small_pose_plot(truth, r=0.003, fill=\"green\")\n", " ) + {\n", " \"color\": {\"type\": \"linear\", \"scheme\": \"OrRd\"},\n", " \"height\": 400,\n", @@ -2326,17 +2331,31 @@ " }" ] }, + { + "cell_type": "markdown", + "id": "111", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "For the first step we use the full_model_initial generative function. Subsequent steps\n", + "will use the full_model_kernel. 
In the case of the initial step, we have:\n", + "- the true initial position of the robot in green\n", + "- the robot's belief about its initial position in red\n", + "- a cloud of possible updates conditioned on the sensor data in shades of purple" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "111", + "id": "112", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "key, sub_key = jax.random.split(key)\n", - "step_poses, step_scores = step_sample(\n", + "bs = boltzmann_sample(\n", " sub_key,\n", " 1000,\n", " full_model_kernel(\n", @@ -2351,18 +2370,18 @@ { "cell_type": "code", "execution_count": null, - "id": "112", + "id": "113", "metadata": {}, "outputs": [], "source": [ "weighted_small_pose_plot(\n", - " path_low_deviation[0], robot_inputs[\"start\"], step_scores, step_poses\n", + " path_low_deviation[0], robot_inputs[\"start\"], bs[0].get_retval()[0], bs[1]\n", ")" ] }, { "cell_type": "markdown", - "id": "113", + "id": "114", "metadata": { "lines_to_next_cell": 0 }, @@ -2378,7 +2397,7 @@ { "cell_type": "code", "execution_count": null, - "id": "114", + "id": "115", "metadata": { "lines_to_next_cell": 2 }, @@ -2399,13 +2418,13 @@ { "cell_type": "code", "execution_count": null, - "id": "115", + "id": "116", "metadata": {}, "outputs": [], "source": [ - "def grid_sample(gf, pose_grid, observation):\n", + "def grid_sample(gf, pose_grid, observations):\n", " scores, _retvals = jax.vmap(\n", - " lambda pose: gf.assess(observation_to_choicemap(observation, pose), ())\n", + " lambda pose: gf.assess(observation_to_choicemap(observations, pose), ())\n", " )(pose_grid)\n", " return scores" ] @@ -2413,7 +2432,7 @@ { "cell_type": "code", "execution_count": null, - "id": "116", + "id": "117", "metadata": {}, "outputs": [], "source": [ @@ -2433,7 +2452,7 @@ }, { "cell_type": "markdown", - "id": "117", + "id": "118", "metadata": { "lines_to_next_cell": 0 }, @@ -2445,49 +2464,98 @@ { "cell_type": "code", "execution_count": null, - "id": "118", + "id": "119", "metadata": { - "lines_to_next_cell": 0 + "lines_to_next_cell": 2 }, "outputs": [], "source": [ "# Test our code for visualizing the Boltzmann and grid searches at the initial pose.\n", - "def first_step_chart(key):\n", + "def initial_pose_chart(key):\n", " cube_step_size = 6\n", " pose_grid = grid_of_nearby_poses(\n", " path_low_deviation[0], cube_step_size, motion_settings_low_deviation\n", " )\n", - " gf = full_model_kernel(\n", - " motion_settings_low_deviation,\n", - " robot_inputs[\"start\"],\n", - " robot_inputs[\"controls\"][0],\n", - " )\n", " score_grid = grid_sample(\n", - " gf,\n", + " full_model_kernel(\n", + " motion_settings_low_deviation,\n", + " robot_inputs[\"start\"],\n", + " robot_inputs[\"controls\"][0],\n", + " ),\n", " pose_grid,\n", " observations_low_deviation[0],\n", " )\n", - " step_poses, step_scores = step_sample(\n", - " key,\n", - " 1000,\n", - " gf,\n", - " observations_low_deviation[0],\n", - " )\n", " pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid)\n", " return weighted_small_pose_plot(\n", - " path_low_deviation[0], robot_inputs[\"start\"], score_plane, pose_plane\n", + " path_low_deviation[0], robot_inputs[\"start\"], pose_plane, score_plane\n", " ) & weighted_small_pose_plot(\n", - " path_low_deviation[0], robot_inputs[\"start\"], step_scores, step_poses\n", + " path_low_deviation[0], robot_inputs[\"start\"], bs[0].get_retval()[0], bs[1]\n", " )\n", "\n", "\n", "key, sub_key = jax.random.split(key)\n", - "first_step_chart(sub_key)" + 
"initial_pose_chart(sub_key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "120", + "metadata": {}, + "outputs": [], + "source": [ + "# See if this works for other points in the path\n", + "\n", + "\n", + "def improvements_at_step(key, path, k):\n", + " gf = full_model_kernel(\n", + " motion_settings_low_deviation, path[k - 1], robot_inputs[\"controls\"][k - 1]\n", + " )\n", + " cube_step_size = 6\n", + " bs = boltzmann_sample(k1, 500, gf, observations_low_deviation[k])\n", + " print(\n", + " f'from {path[k-1]}, step {robot_inputs['controls'][k-1]}, truth {path_low_deviation[k]}, ps {bs[0].get_retval()[0]}'\n", + " )\n", + " p1 = weighted_small_pose_plot(\n", + " path[k], path_low_deviation[k], bs[0].get_retval()[0], bs[1]\n", + " )\n", + " pose_grid = grid_of_nearby_poses(\n", + " path[k], cube_step_size, motion_settings_low_deviation\n", + " )\n", + " score_grid = grid_sample(gf, pose_grid, observations_low_deviation[k])\n", + " pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid)\n", + " print(f\"score_plane {score_plane}\")\n", + " p2 = weighted_small_pose_plot(\n", + " path[k], path_low_deviation[k], pose_plane, score_plane\n", + " )\n", + " return p1 & p2\n", + "\n", + "\n", + "key, sub_key = jax.random.split(key)\n", + "improvements_at_step(sub_key, low_deviation_path, 5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "121", + "metadata": {}, + "outputs": [], + "source": [ + "# Animation of the above\n", + "key, *sub_keys = jax.random.split(key, 5)\n", + "Plot.Frames(\n", + " [\n", + " improvements_at_step(k, low_deviation_path, i + 1)\n", + " for i, k in enumerate(sub_keys)\n", + " ],\n", + " fps=2,\n", + ")" ] }, { "cell_type": "markdown", - "id": "119", + "id": "122", "metadata": { "lines_to_next_cell": 2 }, @@ -2507,51 +2575,65 @@ { "cell_type": "code", "execution_count": null, - "id": "120", + "id": "123", "metadata": {}, "outputs": [], "source": [ - "def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray):\n", - " cube_step_size = 8\n", + "def improved_path(\n", + " mode: str, key: PRNGKey, motion_settings: dict, observations: FloatArray\n", + "):\n", + " def boltzmann_step(k: PRNGKey, gf, _center_pose, observation):\n", + " k1, k2 = jax.random.split(k, 2)\n", + " trs, ws = boltzmann_sample(k1, 1000, gf, observation)\n", + " return ws, trs.get_retval()[0]\n", "\n", " def grid_search_step(k: PRNGKey, gf, center_pose, observation):\n", - " pose_grid = grid_of_nearby_poses(center_pose, cube_step_size, motion_settings)\n", + " pose_grid = grid_of_nearby_poses(center_pose, 15, motion_settings)\n", " nearby_weights = grid_sample(gf, pose_grid, observation)\n", " return nearby_weights, pose_grid\n", "\n", " def improved_step(state, update):\n", " observation, control, key = update\n", " gf = full_model_kernel(motion_settings, state, control)\n", - " # Run a sample and pick an element by weight.\n", " k1, k2, k3 = jax.random.split(key, 3)\n", - " poses, scores = step_sample(k1, 1000, gf, observation)\n", - " new_pose, new_weight = select_by_weight(k2, scores, poses)\n", - " weights2, poses2 = grid_search_step(k2, gf, new_pose, observation)\n", - " # Note that `new_pose` will be among the poses considered by grid_search_step,\n", - " # so the possibility exists to remain stationary, as Bayesian inference requires\n", - " chosen_pose, _ = select_by_weight(k3, weights2, poses2)\n", - " flat_poses, flat_scores = flatten_pose_cube(poses2, cube_step_size, weights2)\n", - " return 
chosen_pose, (new_pose, chosen_pose, flat_scores, flat_poses, new_weight)\n", - "\n", - " sub_keys = jax.random.split(key, T + 1)\n", + " # First, just run the model.\n", + " tr = gf.simulate(k1, ())\n", + " new_pose = tr.get_retval()[0]\n", + " improver = {\"grid\": grid_search_step, \"boltzmann\": boltzmann_step}[mode]\n", + " # Run the improver, and add the candidate point to the list of weights and\n", + " # return values, to create the possibility of accepting the initial proposal\n", + " # as well as any of the improvement candidates, as Bayesian inference requires\n", + " weights, poses = improver(k2, gf, new_pose, observation)\n", + " # weights = jnp.append(weights, tr.get_score())\n", + " # poses = Pose(jnp.vstack(poses.p, new_pose.p), jnp.append(poses.hd, new_pose.hd))\n", + " chosen_pose = select_by_weight(k3, weights, poses)\n", + " return chosen_pose, chosen_pose\n", + "\n", + " # We have one fewer control than step, since no step got us to the initial position.\n", + " # Our scan step starts at the initial step and applies a control input each time.\n", + " # To make things balance, we need to add a zero step to the end of the control input\n", + " # array, so that when we arrive at the final step, no more control input is given.\n", + " controls = robot_inputs[\"controls\"]\n", + " n_steps = len(controls)\n", + " sub_keys = jax.random.split(key, n_steps + 1)\n", " return jax.lax.scan(\n", " improved_step,\n", " robot_inputs[\"start\"],\n", " (\n", " observations, # observation at time t\n", - " robot_inputs[\"controls\"], # guides step from t to t+1\n", + " controls, # guides step from t to t+1\n", " sub_keys[1:],\n", " ),\n", " )\n", "\n", "\n", - "jit_improved_path = jax.jit(improved_path)" + "jit_improved_path = jax.jit(improved_path, static_argnums=0)" ] }, { "cell_type": "code", "execution_count": null, - "id": "121", + "id": "124", "metadata": { "lines_to_next_cell": 0 }, @@ -2559,21 +2641,15 @@ "source": [ "key, sub_key = jax.random.split(key)\n", "_, improved_low = jit_improved_path(\n", - " sub_key, motion_settings_low_deviation, observations_low_deviation\n", - ")\n", - "key, sub_key = jax.random.split(key)\n", - "_, improved_high = jit_improved_path(\n", - " sub_key, motion_settings_high_deviation, observations_high_deviation\n", + " \"grid\", sub_key, motion_settings_low_deviation, observations_low_deviation\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "122", - "metadata": { - "lines_to_next_cell": 0 - }, + "id": "125", + "metadata": {}, "outputs": [], "source": [ "def path_comparison_plot(*plots):\n", @@ -2597,34 +2673,38 @@ { "cell_type": "code", "execution_count": null, - "id": "123", + "id": "126", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "path_comparison_plot(\n", - " improved_low[0], path_integrated, low_deviation_path, path_low_deviation\n", + " improved_low, path_integrated, low_deviation_path, path_low_deviation\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "124", + "id": "127", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ + "key, sub_key = jax.random.split(key)\n", + "_, improved_high = jit_improved_path(\n", + " \"boltzmann\", sub_key, motion_settings_high_deviation, observations_high_deviation\n", + ")\n", "path_comparison_plot(\n", - " improved_high[0], path_integrated, high_deviation_path, path_high_deviation\n", + " improved_high, path_integrated, high_deviation_path, path_high_deviation\n", ")" ] }, { "cell_type": "markdown", - "id": "125", + "id": 
"128", "metadata": { "lines_to_next_cell": 0 }, @@ -2637,7 +2717,7 @@ { "cell_type": "code", "execution_count": null, - "id": "126", + "id": "129", "metadata": { "lines_to_next_cell": 0 }, @@ -2645,7 +2725,7 @@ "source": [ "Plot.Row(\n", " animate_path_and_sensors(\n", - " improved_high[0],\n", + " improved_high,\n", " observations_high_deviation,\n", " motion_settings_high_deviation,\n", " frame_key=\"frame\",\n", @@ -2656,57 +2736,7 @@ " motion_settings_high_deviation,\n", " frame_key=\"frame\",\n", " ),\n", - ") | Plot.Slider(\"frame\", 0, T - 1, fps=2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "127", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "# Finishing touch: weave together the improved plot and the improvement steps\n", - "# into a slider animation\n", - "# Plot.Frames(\n", - "# [weighted_small_pose_plot(improved_high[0][k], path_high_deviation[k], improved_high[2][k], improved_high[1][k]) for k in range(T)],\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "128", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "def wsp_frame(k):\n", - " return path_comparison_plot(\n", - " improved_high[0][: k + 1],\n", - " path_integrated[: k + 1],\n", - " high_deviation_path[: k + 1],\n", - " path_high_deviation[: k + 1],\n", - " ) & weighted_small_pose_plot(\n", - " improved_high[1][k],\n", - " path_high_deviation[k],\n", - " improved_high[2][k],\n", - " improved_high[3][k],\n", - " zoom=4,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "129", - "metadata": {}, - "outputs": [], - "source": [ - "Plot.Frames([wsp_frame(k) for k in range(1, 6)])" + ") | Plot.Slider(\"frame\", 0, T, fps=2)" ] }, { diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index ccb70b7..7fd77a3 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -1485,6 +1485,7 @@ def animate_path_as_line(path, **options): # earlier. +# %% def select_by_weight(key: PRNGKey, weights: FloatArray, things): chosen = jax.random.categorical(key, weights) return jax.tree.map(lambda v: v[chosen], things) @@ -1494,6 +1495,7 @@ def select_by_weight(key: PRNGKey, weights: FloatArray, things): # Select an importance sample by weight in both the low and high deviation settings. It will be handy # to have one path to work with to test our improvements. 
+# %% key, k1, k2 = jax.random.split(key, 3) low_deviation_path = select_by_weight(k1, low_weights, low_deviation_paths) high_deviation_path = select_by_weight(k2, high_weights, high_deviation_paths) @@ -1502,6 +1504,7 @@ def select_by_weight(key: PRNGKey, weights: FloatArray, things): # Create a choicemap that will enforce the given sensor observation +# %% def observation_to_choicemap(observation, pose=None): sensor_cm = C["sensor", jnp.arange(len(observation)), "distance"].set(observation) pose_cm = ( @@ -1558,6 +1561,7 @@ def weighted_small_pose_plot(proposal, truth, poses, ws): # - the true initial position of the robot in green # - the robot's belief about its initial position in red # - a cloud of possible updates conditioned on the sensor data in shades of purple +# %% key, sub_key = jax.random.split(key) bs = boltzmann_sample( sub_key, @@ -1651,8 +1655,10 @@ def initial_pose_chart(key): initial_pose_chart(sub_key) -# %% [markdown] +# %% # See if this works for other points in the path + + def improvements_at_step(key, path, k): gf = full_model_kernel( motion_settings_low_deviation, path[k - 1], robot_inputs["controls"][k - 1] @@ -1704,6 +1710,7 @@ def improvements_at_step(key, path, k): # pose.weight = full_model_kernel.assess(p, (cm,)) +# %% def improved_path( mode: str, key: PRNGKey, motion_settings: dict, observations: FloatArray ): From 5141d4993891155999a9044e8bd94d8a60f3a694 Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Mon, 11 Nov 2024 19:10:24 -0800 Subject: [PATCH 42/86] more visualization, doc touchups --- .../probcomp-localization-tutorial.py | 240 +++++++----------- 1 file changed, 97 insertions(+), 143 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 7fd77a3..1881b24 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -393,9 +393,10 @@ def integrate_controls_physical(robot_inputs): # ### Plot such data # %% def pose_plot(p, fill: str | Any = "black", **opts): - r = opts.get("r", 0.5) + z = opts.get("zoom", 1.0) + r = z * 0.15 wing_opacity = opts.get("opacity", 0.3) - WING_ANGLE, WING_LENGTH = jnp.pi / 12, opts.get("wing_length", 0.6) + WING_ANGLE, WING_LENGTH = jnp.pi / 12, z * opts.get("wing_length", 0.6) center = p.p angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1]) @@ -414,7 +415,7 @@ def pose_plot(p, fill: str | Any = "black", **opts): ) # Draw center dot - dot = Plot.ellipse([center], fill=fill, **opts) + dot = Plot.ellipse([center], fill=fill, **({'r': r} | opts)) return wings + dot @@ -1011,18 +1012,13 @@ def full_model(motion_settings): def get_path(trace): - # p = trace.get_subtrace(("initial",)).get_retval() ps = trace.get_retval()[1] return ps def get_sensors(trace): ch = trace.get_choices() - # return jnp.concatenate(( - # ch["initial", "sensor", ..., "distance"][jnp.newaxis], - # ch["steps", ..., "sensor", ..., "distance"] - # )) - return ch["steps", ..., "sensor", ..., "distance"] + return ch["steps", :, "sensor", :, "distance"] key, sub_key = jax.random.split(key) @@ -1408,7 +1404,6 @@ def constraint_from_path(path): # %% - def resample( key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int ): @@ -1416,7 +1411,7 @@ def resample( samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))( jax.random.split(key1, N * K), constraints, (motion_settings,) ) - winners = jax.vmap(categorical_sampler)( + 
winners = jax.vmap(genjax.categorical.sampler)( jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) ) # indices returned are relative to the start of the K-segment from which they were drawn. @@ -1425,13 +1420,14 @@ def resample( selected = jax.tree.map(lambda x: x[winners], samples) return selected +jit_resample = jax.jit(resample, static_argnums=(3,4)) key, sub_key = jax.random.split(key) -low_posterior = resample( +low_posterior = jit_resample( sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20 ) key, sub_key = jax.random.split(key) -high_posterior = resample( +high_posterior = jit_resample( sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20 ) @@ -1487,26 +1483,25 @@ def animate_path_as_line(path, **options): # %% def select_by_weight(key: PRNGKey, weights: FloatArray, things): + """Makes a categorical selection from the vector object `things` + weighted by `weights`. The selected object is returned (with its + outermost axis removed) with its weight.""" chosen = jax.random.categorical(key, weights) - return jax.tree.map(lambda v: v[chosen], things) + return jax.tree.map(lambda v: v[chosen], things), weights[chosen] # %% [markdown] # Select an importance sample by weight in both the low and high deviation settings. It will be handy # to have one path to work with to test our improvements. - # %% key, k1, k2 = jax.random.split(key, 3) -low_deviation_path = select_by_weight(k1, low_weights, low_deviation_paths) -high_deviation_path = select_by_weight(k2, high_weights, high_deviation_paths) - +low_deviation_path, _ = select_by_weight(k1, low_weights, low_deviation_paths) +high_deviation_path, _ = select_by_weight(k2, high_weights, high_deviation_paths) # %% [markdown] # Create a choicemap that will enforce the given sensor observation - - # %% def observation_to_choicemap(observation, pose=None): - sensor_cm = C["sensor", jnp.arange(len(observation)), "distance"].set(observation) + sensor_cm = C["sensor", :, "distance"].set(observation) pose_cm = ( C["pose", "p"].set(pose.p) + C["pose", "hd"].set(pose.hd) if pose is not None @@ -1516,26 +1511,20 @@ def observation_to_choicemap(observation, pose=None): # %% [markdown] -# The first thing we'll try is a Boltzmann update: generate a cloud of nearby points -# using the generative function we wrote, and weightedly select a replacement from that. -# First, let's generate the cloud and visualize it. +# Let's visualize a cloud of possible poses by coloring the elements proportional to their +# plausibility under the sensor readingss. # %% -def boltzmann_sample(key: PRNGKey, N: int, gf, observation): - return jax.vmap(gf.importance, in_axes=(0, None, None))( +def step_sample(key: PRNGKey, N: int, gf, observation): + tr, ws = jax.vmap(gf.importance, in_axes=(0, None, None))( jax.random.split(key, N), observation_to_choicemap(observation), () ) + return tr.get_retval()[0], ws -def small_pose_plot(p: Pose, **opts): - """This variant of pose_plot will is better when we're zoomed in on the vicinity of one pose. 
- TODO: consider scaling r and wing_length based on the size of the plot domain.""" - opts = {"r": 0.001} | opts - return pose_plot(p, wing_length=0.006, **opts) - - -def weighted_small_pose_plot(proposal, truth, poses, ws): - lse_ws = jnp.log(jnp.sum(jnp.exp(ws))) - scaled_ws = jnp.exp(ws - lse_ws) +def weighted_small_pose_plot(proposal, truth, weights, poses, zoom=1): + max_logw = jnp.max(weights) + lse_ws = max_logw + jnp.log(jnp.sum(jnp.exp(weights - max_logw))) + scaled_ws = jnp.exp(weights - lse_ws) max_scaled_w: FloatArray = jnp.max(scaled_ws) scaled_ws /= max_scaled_w # the following hack "boosts" lower scores a bit, to give us more visibility into @@ -1543,10 +1532,11 @@ def weighted_small_pose_plot(proposal, truth, poses, ws): # invisible without some adjustment, since the score distribution is concentrated # closely around 1.0 scaled_ws = scaled_ws**0.3 + z = 0.03 * zoom return Plot.new( - [small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)] - + small_pose_plot(proposal, r=0.003, fill="red") - + small_pose_plot(truth, r=0.003, fill="green") + [pose_plot(p, fill=w, zoom=z) for p, w in zip(poses, scaled_ws)] + + pose_plot(proposal, fill="red", zoom=z) + + pose_plot(truth, fill="green", zoom=z) ) + { "color": {"type": "linear", "scheme": "OrRd"}, "height": 400, @@ -1555,15 +1545,9 @@ def weighted_small_pose_plot(proposal, truth, poses, ws): } -# %% [markdown] -# For the first step we use the full_model_initial generative function. Subsequent steps -# will use the full_model_kernel. In the case of the initial step, we have: -# - the true initial position of the robot in green -# - the robot's belief about its initial position in red -# - a cloud of possible updates conditioned on the sensor data in shades of purple # %% key, sub_key = jax.random.split(key) -bs = boltzmann_sample( +step_poses, step_scores = step_sample( sub_key, 1000, full_model_kernel( @@ -1575,7 +1559,7 @@ def weighted_small_pose_plot(proposal, truth, poses, ws): ) # %% weighted_small_pose_plot( - path_low_deviation[0], robot_inputs["start"], bs[0].get_retval()[0], bs[1] + path_low_deviation[0], robot_inputs["start"], step_scores, step_poses ) @@ -1602,9 +1586,9 @@ def grid_of_nearby_poses(p, n, motion_settings): # %% -def grid_sample(gf, pose_grid, observations): +def grid_sample(gf, pose_grid, observation): scores, _retvals = jax.vmap( - lambda pose: gf.assess(observation_to_choicemap(observations, pose), ()) + lambda pose: gf.assess(observation_to_choicemap(observation, pose), ()) )(pose_grid) return scores @@ -1629,74 +1613,37 @@ def flatten_pose_cube(pose_grid, cube_step_size, scores): # search and importance sampling techniques. # %% # Test our code for visualizing the Boltzmann and grid searches at the initial pose. 
-def initial_pose_chart(key): +def first_step_chart(key): cube_step_size = 6 pose_grid = grid_of_nearby_poses( path_low_deviation[0], cube_step_size, motion_settings_low_deviation ) + gf = full_model_kernel( + motion_settings_low_deviation, + robot_inputs["start"], + robot_inputs["controls"][0], + ) score_grid = grid_sample( - full_model_kernel( - motion_settings_low_deviation, - robot_inputs["start"], - robot_inputs["controls"][0], - ), + gf, pose_grid, observations_low_deviation[0], ) + step_poses, step_scores = step_sample( + key, + 1000, + gf, + observations_low_deviation[0], + ) pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid) return weighted_small_pose_plot( - path_low_deviation[0], robot_inputs["start"], pose_plane, score_plane + path_low_deviation[0], robot_inputs["start"], score_plane, pose_plane ) & weighted_small_pose_plot( - path_low_deviation[0], robot_inputs["start"], bs[0].get_retval()[0], bs[1] - ) - - -key, sub_key = jax.random.split(key) -initial_pose_chart(sub_key) - - -# %% -# See if this works for other points in the path - - -def improvements_at_step(key, path, k): - gf = full_model_kernel( - motion_settings_low_deviation, path[k - 1], robot_inputs["controls"][k - 1] - ) - cube_step_size = 6 - bs = boltzmann_sample(k1, 500, gf, observations_low_deviation[k]) - print( - f'from {path[k-1]}, step {robot_inputs['controls'][k-1]}, truth {path_low_deviation[k]}, ps {bs[0].get_retval()[0]}' + path_low_deviation[0], robot_inputs["start"], step_scores, step_poses ) - p1 = weighted_small_pose_plot( - path[k], path_low_deviation[k], bs[0].get_retval()[0], bs[1] - ) - pose_grid = grid_of_nearby_poses( - path[k], cube_step_size, motion_settings_low_deviation - ) - score_grid = grid_sample(gf, pose_grid, observations_low_deviation[k]) - pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid) - print(f"score_plane {score_plane}") - p2 = weighted_small_pose_plot( - path[k], path_low_deviation[k], pose_plane, score_plane - ) - return p1 & p2 key, sub_key = jax.random.split(key) -improvements_at_step(sub_key, low_deviation_path, 5) - -# %% -# Animation of the above -key, *sub_keys = jax.random.split(key, 5) -Plot.Frames( - [ - improvements_at_step(k, low_deviation_path, i + 1) - for i, k in enumerate(sub_keys) - ], - fps=2, -) - +first_step_chart(sub_key) # %% [markdown] # Now let's try doing the whole path. We want to produce something that is ultimately # scan-compatible, so it should have the form state -> update -> new_state. The state @@ -1712,63 +1659,53 @@ def improvements_at_step(key, path, k): # %% def improved_path( - mode: str, key: PRNGKey, motion_settings: dict, observations: FloatArray + key: PRNGKey, motion_settings: dict, observations: FloatArray ): - def boltzmann_step(k: PRNGKey, gf, _center_pose, observation): - k1, k2 = jax.random.split(k, 2) - trs, ws = boltzmann_sample(k1, 1000, gf, observation) - return ws, trs.get_retval()[0] + cube_step_size = 8 def grid_search_step(k: PRNGKey, gf, center_pose, observation): - pose_grid = grid_of_nearby_poses(center_pose, 15, motion_settings) + pose_grid = grid_of_nearby_poses(center_pose, cube_step_size, motion_settings) nearby_weights = grid_sample(gf, pose_grid, observation) return nearby_weights, pose_grid def improved_step(state, update): observation, control, key = update gf = full_model_kernel(motion_settings, state, control) + # Run a sample and pick an element by weight. k1, k2, k3 = jax.random.split(key, 3) - # First, just run the model. 
- tr = gf.simulate(k1, ()) - new_pose = tr.get_retval()[0] - improver = {"grid": grid_search_step, "boltzmann": boltzmann_step}[mode] - # Run the improver, and add the candidate point to the list of weights and - # return values, to create the possibility of accepting the initial proposal - # as well as any of the improvement candidates, as Bayesian inference requires - weights, poses = improver(k2, gf, new_pose, observation) - # weights = jnp.append(weights, tr.get_score()) - # poses = Pose(jnp.vstack(poses.p, new_pose.p), jnp.append(poses.hd, new_pose.hd)) - chosen_pose = select_by_weight(k3, weights, poses) - return chosen_pose, chosen_pose - - # We have one fewer control than step, since no step got us to the initial position. - # Our scan step starts at the initial step and applies a control input each time. - # To make things balance, we need to add a zero step to the end of the control input - # array, so that when we arrive at the final step, no more control input is given. - controls = robot_inputs["controls"] - n_steps = len(controls) - sub_keys = jax.random.split(key, n_steps + 1) + poses, scores = step_sample(k1, 1000, gf, observation) + new_pose, new_weight = select_by_weight(k2, scores, poses) + weights2, poses2 = grid_search_step(k2, gf, new_pose, observation) + # Note that `new_pose` will be among the poses considered by grid_search_step, + # so the possibility exists to remain stationary, as Bayesian inference requires + chosen_pose, _ = select_by_weight(k3, weights2, poses2) + flat_poses, flat_scores = flatten_pose_cube(poses2, cube_step_size, weights2) + return chosen_pose, (new_pose, chosen_pose, flat_scores, flat_poses, new_weight) + + sub_keys = jax.random.split(key, T + 1) return jax.lax.scan( improved_step, robot_inputs["start"], ( observations, # observation at time t - controls, # guides step from t to t+1 + robot_inputs['controls'], # guides step from t to t+1 sub_keys[1:], ), ) -jit_improved_path = jax.jit(improved_path, static_argnums=0) +jit_improved_path = jax.jit(improved_path) # %% key, sub_key = jax.random.split(key) _, improved_low = jit_improved_path( - "grid", sub_key, motion_settings_low_deviation, observations_low_deviation + sub_key, motion_settings_low_deviation, observations_low_deviation +) +key, sub_key = jax.random.split(key) +_, improved_high = jit_improved_path( + sub_key, motion_settings_high_deviation, observations_high_deviation ) # %% - - def path_comparison_plot(*plots): types = ["improved", "integrated", "importance", "true"] plot = world_plot @@ -1785,19 +1722,13 @@ def path_comparison_plot(*plots): "importance": "red", } ) - - # %% path_comparison_plot( - improved_low, path_integrated, low_deviation_path, path_low_deviation + improved_low[0], path_integrated, low_deviation_path, path_low_deviation ) # %% -key, sub_key = jax.random.split(key) -_, improved_high = jit_improved_path( - "boltzmann", sub_key, motion_settings_high_deviation, observations_high_deviation -) path_comparison_plot( - improved_high, path_integrated, high_deviation_path, path_high_deviation + improved_high[0], path_integrated, high_deviation_path, path_high_deviation ) # %% [markdown] # To see how the grid search improves poses, we play back the grid-search path @@ -1806,7 +1737,7 @@ def path_comparison_plot(*plots): # %% Plot.Row( animate_path_and_sensors( - improved_high, + improved_high[0], observations_high_deviation, motion_settings_high_deviation, frame_key="frame", @@ -1817,5 +1748,28 @@ def path_comparison_plot(*plots): motion_settings_high_deviation, 
frame_key="frame", ), -) | Plot.Slider("frame", 0, T, fps=2) +) | Plot.Slider("frame", 0, T-1, fps=2) +# %% +# Finishing touch: weave together the improved plot and the improvement steps +# into a slider animation +# Plot.Frames( +# [weighted_small_pose_plot(improved_high[0][k], path_high_deviation[k], improved_high[2][k], improved_high[1][k]) for k in range(T)], +# ) +# %% +def wsp_frame(k): + return path_comparison_plot( + improved_high[0][:k+1], + path_integrated[:k+1], + high_deviation_path[:k+1], + path_high_deviation[:k+1] + ) & weighted_small_pose_plot( + improved_high[1][k], + path_high_deviation[k], + improved_high[2][k], + improved_high[3][k], + zoom=4 + ) +# %% +Plot.Frames([wsp_frame(k) for k in range(1,6)]) + # %% From 638be108064aee14e7745d6ed8c1bc7a84fb2825 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Nov 2024 03:10:43 +0000 Subject: [PATCH 43/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../probcomp-localization-tutorial.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 1881b24..ba16efb 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -415,7 +415,7 @@ def pose_plot(p, fill: str | Any = "black", **opts): ) # Draw center dot - dot = Plot.ellipse([center], fill=fill, **({'r': r} | opts)) + dot = Plot.ellipse([center], fill=fill, **({"r": r} | opts)) return wings + dot @@ -1404,6 +1404,7 @@ def constraint_from_path(path): # %% + def resample( key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int ): @@ -1420,7 +1421,8 @@ def resample( selected = jax.tree.map(lambda x: x[winners], samples) return selected -jit_resample = jax.jit(resample, static_argnums=(3,4)) + +jit_resample = jax.jit(resample, static_argnums=(3, 4)) key, sub_key = jax.random.split(key) low_posterior = jit_resample( @@ -1497,6 +1499,8 @@ def select_by_weight(key: PRNGKey, weights: FloatArray, things): key, k1, k2 = jax.random.split(key, 3) low_deviation_path, _ = select_by_weight(k1, low_weights, low_deviation_paths) high_deviation_path, _ = select_by_weight(k2, high_weights, high_deviation_paths) + + # %% [markdown] # Create a choicemap that will enforce the given sensor observation # %% @@ -1658,9 +1662,7 @@ def first_step_chart(key): # %% -def improved_path( - key: PRNGKey, motion_settings: dict, observations: FloatArray -): +def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray): cube_step_size = 8 def grid_search_step(k: PRNGKey, gf, center_pose, observation): @@ -1688,7 +1690,7 @@ def improved_step(state, update): robot_inputs["start"], ( observations, # observation at time t - robot_inputs['controls'], # guides step from t to t+1 + robot_inputs["controls"], # guides step from t to t+1 sub_keys[1:], ), ) @@ -1705,6 +1707,8 @@ def improved_step(state, update): _, improved_high = jit_improved_path( sub_key, motion_settings_high_deviation, observations_high_deviation ) + + # %% def path_comparison_plot(*plots): types = ["improved", "integrated", "importance", "true"] @@ -1722,6 +1726,8 @@ def path_comparison_plot(*plots): "importance": "red", } ) + + # %% path_comparison_plot( improved_low[0], path_integrated, low_deviation_path, path_low_deviation @@ 
-1748,7 +1754,9 @@ def path_comparison_plot(*plots): motion_settings_high_deviation, frame_key="frame", ), -) | Plot.Slider("frame", 0, T-1, fps=2) +) | Plot.Slider("frame", 0, T - 1, fps=2) + + # %% # Finishing touch: weave together the improved plot and the improvement steps # into a slider animation @@ -1758,18 +1766,20 @@ def path_comparison_plot(*plots): # %% def wsp_frame(k): return path_comparison_plot( - improved_high[0][:k+1], - path_integrated[:k+1], - high_deviation_path[:k+1], - path_high_deviation[:k+1] + improved_high[0][: k + 1], + path_integrated[: k + 1], + high_deviation_path[: k + 1], + path_high_deviation[: k + 1], ) & weighted_small_pose_plot( improved_high[1][k], path_high_deviation[k], improved_high[2][k], improved_high[3][k], - zoom=4 + zoom=4, ) + + # %% -Plot.Frames([wsp_frame(k) for k in range(1,6)]) +Plot.Frames([wsp_frame(k) for k in range(1, 6)]) # %% From 7fca2afce93c79c806ffeab00bde8627b7c1063e Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Mon, 11 Nov 2024 19:30:42 -0800 Subject: [PATCH 44/86] update ipynb version --- .../probcomp-localization-tutorial.ipynb | 334 ++++++++---------- 1 file changed, 152 insertions(+), 182 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb index ba86eb2..476fc5f 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.ipynb @@ -488,9 +488,10 @@ "outputs": [], "source": [ "def pose_plot(p, fill: str | Any = \"black\", **opts):\n", - " r = opts.get(\"r\", 0.5)\n", + " z = opts.get(\"zoom\", 1.0)\n", + " r = z * 0.15\n", " wing_opacity = opts.get(\"opacity\", 0.3)\n", - " WING_ANGLE, WING_LENGTH = jnp.pi / 12, opts.get(\"wing_length\", 0.6)\n", + " WING_ANGLE, WING_LENGTH = jnp.pi / 12, z * opts.get(\"wing_length\", 0.6)\n", " center = p.p\n", " angle = jnp.arctan2(*(center - p.step_along(-r).p)[::-1])\n", "\n", @@ -509,7 +510,7 @@ " )\n", "\n", " # Draw center dot\n", - " dot = Plot.ellipse([center], fill=fill, **opts)\n", + " dot = Plot.ellipse([center], fill=fill, **({\"r\": r} | opts))\n", "\n", " return wings + dot\n", "\n", @@ -704,8 +705,6 @@ "# Plot the world, starting pose samples, and 95% confidence region\n", "# Calculate the radius of the 95% confidence region\n", "def confidence_circle(pose: Pose, p_noise: float):\n", - " # TODO\n", - " # should this also take into account the hd_noise?\n", " return Plot.scaled_circle(\n", " *pose.p,\n", " fill=Plot.constantly(\"95% confidence region\"),\n", @@ -1474,18 +1473,13 @@ "\n", "\n", "def get_path(trace):\n", - " # p = trace.get_subtrace((\"initial\",)).get_retval()\n", " ps = trace.get_retval()[1]\n", " return ps\n", "\n", "\n", "def get_sensors(trace):\n", " ch = trace.get_choices()\n", - " # return jnp.concatenate((\n", - " # ch[\"initial\", \"sensor\", ..., \"distance\"][jnp.newaxis],\n", - " # ch[\"steps\", ..., \"sensor\", ..., \"distance\"]\n", - " # ))\n", - " return ch[\"steps\", ..., \"sensor\", ..., \"distance\"]\n", + " return ch[\"steps\", :, \"sensor\", :, \"distance\"]\n", "\n", "\n", "key, sub_key = jax.random.split(key)\n", @@ -2123,9 +2117,6 @@ "metadata": {}, "outputs": [], "source": [ - "categorical_sampler = jax.jit(genjax.categorical.sampler)\n", - "\n", - "\n", "def resample(\n", " key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int\n", "):\n", @@ -2133,7 +2124,7 @@ " samples, log_weights = jax.vmap(model_importance, in_axes=(0, 
None, None))(\n", " jax.random.split(key1, N * K), constraints, (motion_settings,)\n", " )\n", - " winners = jax.vmap(categorical_sampler)(\n", + " winners = jax.vmap(genjax.categorical.sampler)(\n", " jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))\n", " )\n", " # indices returned are relative to the start of the K-segment from which they were drawn.\n", @@ -2143,12 +2134,14 @@ " return selected\n", "\n", "\n", + "jit_resample = jax.jit(resample, static_argnums=(3, 4))\n", + "\n", "key, sub_key = jax.random.split(key)\n", - "low_posterior = resample(\n", + "low_posterior = jit_resample(\n", " sub_key, constraints_low_deviation, motion_settings_low_deviation, 2000, 20\n", ")\n", "key, sub_key = jax.random.split(key)\n", - "high_posterior = resample(\n", + "high_posterior = jit_resample(\n", " sub_key, constraints_high_deviation, motion_settings_high_deviation, 2000, 20\n", ")" ] @@ -2225,14 +2218,19 @@ "outputs": [], "source": [ "def select_by_weight(key: PRNGKey, weights: FloatArray, things):\n", + " \"\"\"Makes a categorical selection from the vector object `things`\n", + " weighted by `weights`. The selected object is returned (with its\n", + " outermost axis removed) with its weight.\"\"\"\n", " chosen = jax.random.categorical(key, weights)\n", - " return jax.tree.map(lambda v: v[chosen], things)" + " return jax.tree.map(lambda v: v[chosen], things), weights[chosen]" ] }, { "cell_type": "markdown", "id": "105", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 0 + }, "source": [ "Select an importance sample by weight in both the low and high deviation settings. It will be handy\n", "to have one path to work with to test our improvements." @@ -2242,19 +2240,21 @@ "cell_type": "code", "execution_count": null, "id": "106", - "metadata": {}, + "metadata": { + "lines_to_next_cell": 0 + }, "outputs": [], "source": [ "key, k1, k2 = jax.random.split(key, 3)\n", - "low_deviation_path = select_by_weight(k1, low_weights, low_deviation_paths)\n", - "high_deviation_path = select_by_weight(k2, high_weights, high_deviation_paths)" + "low_deviation_path, _ = select_by_weight(k1, low_weights, low_deviation_paths)\n", + "high_deviation_path, _ = select_by_weight(k2, high_weights, high_deviation_paths)" ] }, { "cell_type": "markdown", "id": "107", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "source": [ "Create a choicemap that will enforce the given sensor observation" @@ -2268,7 +2268,7 @@ "outputs": [], "source": [ "def observation_to_choicemap(observation, pose=None):\n", - " sensor_cm = C[\"sensor\", jnp.arange(len(observation)), \"distance\"].set(observation)\n", + " sensor_cm = C[\"sensor\", :, \"distance\"].set(observation)\n", " pose_cm = (\n", " C[\"pose\", \"p\"].set(pose.p) + C[\"pose\", \"hd\"].set(pose.hd)\n", " if pose is not None\n", @@ -2284,9 +2284,8 @@ "lines_to_next_cell": 0 }, "source": [ - "The first thing we'll try is a Boltzmann update: generate a cloud of nearby points\n", - "using the generative function we wrote, and weightedly select a replacement from that.\n", - "First, let's generate the cloud and visualize it." + "Let's visualize a cloud of possible poses by coloring the elements proportional to their\n", + "plausibility under the sensor readingss." 
] }, { @@ -2296,22 +2295,17 @@ "metadata": {}, "outputs": [], "source": [ - "def boltzmann_sample(key: PRNGKey, N: int, gf, observation):\n", - " return jax.vmap(gf.importance, in_axes=(0, None, None))(\n", + "def step_sample(key: PRNGKey, N: int, gf, observation):\n", + " tr, ws = jax.vmap(gf.importance, in_axes=(0, None, None))(\n", " jax.random.split(key, N), observation_to_choicemap(observation), ()\n", " )\n", + " return tr.get_retval()[0], ws\n", "\n", "\n", - "def small_pose_plot(p: Pose, **opts):\n", - " \"\"\"This variant of pose_plot will is better when we're zoomed in on the vicinity of one pose.\n", - " TODO: consider scaling r and wing_length based on the size of the plot domain.\"\"\"\n", - " opts = {\"r\": 0.001} | opts\n", - " return pose_plot(p, wing_length=0.006, **opts)\n", - "\n", - "\n", - "def weighted_small_pose_plot(proposal, truth, poses, ws):\n", - " lse_ws = jnp.log(jnp.sum(jnp.exp(ws)))\n", - " scaled_ws = jnp.exp(ws - lse_ws)\n", + "def weighted_small_pose_plot(proposal, truth, weights, poses, zoom=1):\n", + " max_logw = jnp.max(weights)\n", + " lse_ws = max_logw + jnp.log(jnp.sum(jnp.exp(weights - max_logw)))\n", + " scaled_ws = jnp.exp(weights - lse_ws)\n", " max_scaled_w: FloatArray = jnp.max(scaled_ws)\n", " scaled_ws /= max_scaled_w\n", " # the following hack \"boosts\" lower scores a bit, to give us more visibility into\n", @@ -2319,10 +2313,11 @@ " # invisible without some adjustment, since the score distribution is concentrated\n", " # closely around 1.0\n", " scaled_ws = scaled_ws**0.3\n", + " z = 0.03 * zoom\n", " return Plot.new(\n", - " [small_pose_plot(p, fill=w) for p, w in zip(poses, scaled_ws)]\n", - " + small_pose_plot(proposal, r=0.003, fill=\"red\")\n", - " + small_pose_plot(truth, r=0.003, fill=\"green\")\n", + " [pose_plot(p, fill=w, zoom=z) for p, w in zip(poses, scaled_ws)]\n", + " + pose_plot(proposal, fill=\"red\", zoom=z)\n", + " + pose_plot(truth, fill=\"green\", zoom=z)\n", " ) + {\n", " \"color\": {\"type\": \"linear\", \"scheme\": \"OrRd\"},\n", " \"height\": 400,\n", @@ -2331,31 +2326,17 @@ " }" ] }, - { - "cell_type": "markdown", - "id": "111", - "metadata": { - "lines_to_next_cell": 0 - }, - "source": [ - "For the first step we use the full_model_initial generative function. Subsequent steps\n", - "will use the full_model_kernel. 
In the case of the initial step, we have:\n", - "- the true initial position of the robot in green\n", - "- the robot's belief about its initial position in red\n", - "- a cloud of possible updates conditioned on the sensor data in shades of purple" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "112", + "id": "111", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "key, sub_key = jax.random.split(key)\n", - "bs = boltzmann_sample(\n", + "step_poses, step_scores = step_sample(\n", " sub_key,\n", " 1000,\n", " full_model_kernel(\n", @@ -2370,18 +2351,18 @@ { "cell_type": "code", "execution_count": null, - "id": "113", + "id": "112", "metadata": {}, "outputs": [], "source": [ "weighted_small_pose_plot(\n", - " path_low_deviation[0], robot_inputs[\"start\"], bs[0].get_retval()[0], bs[1]\n", + " path_low_deviation[0], robot_inputs[\"start\"], step_scores, step_poses\n", ")" ] }, { "cell_type": "markdown", - "id": "114", + "id": "113", "metadata": { "lines_to_next_cell": 0 }, @@ -2397,7 +2378,7 @@ { "cell_type": "code", "execution_count": null, - "id": "115", + "id": "114", "metadata": { "lines_to_next_cell": 2 }, @@ -2418,13 +2399,13 @@ { "cell_type": "code", "execution_count": null, - "id": "116", + "id": "115", "metadata": {}, "outputs": [], "source": [ - "def grid_sample(gf, pose_grid, observations):\n", + "def grid_sample(gf, pose_grid, observation):\n", " scores, _retvals = jax.vmap(\n", - " lambda pose: gf.assess(observation_to_choicemap(observations, pose), ())\n", + " lambda pose: gf.assess(observation_to_choicemap(observation, pose), ())\n", " )(pose_grid)\n", " return scores" ] @@ -2432,7 +2413,7 @@ { "cell_type": "code", "execution_count": null, - "id": "117", + "id": "116", "metadata": {}, "outputs": [], "source": [ @@ -2452,7 +2433,7 @@ }, { "cell_type": "markdown", - "id": "118", + "id": "117", "metadata": { "lines_to_next_cell": 0 }, @@ -2464,98 +2445,49 @@ { "cell_type": "code", "execution_count": null, - "id": "119", + "id": "118", "metadata": { - "lines_to_next_cell": 2 + "lines_to_next_cell": 0 }, "outputs": [], "source": [ "# Test our code for visualizing the Boltzmann and grid searches at the initial pose.\n", - "def initial_pose_chart(key):\n", + "def first_step_chart(key):\n", " cube_step_size = 6\n", " pose_grid = grid_of_nearby_poses(\n", " path_low_deviation[0], cube_step_size, motion_settings_low_deviation\n", " )\n", + " gf = full_model_kernel(\n", + " motion_settings_low_deviation,\n", + " robot_inputs[\"start\"],\n", + " robot_inputs[\"controls\"][0],\n", + " )\n", " score_grid = grid_sample(\n", - " full_model_kernel(\n", - " motion_settings_low_deviation,\n", - " robot_inputs[\"start\"],\n", - " robot_inputs[\"controls\"][0],\n", - " ),\n", + " gf,\n", " pose_grid,\n", " observations_low_deviation[0],\n", " )\n", + " step_poses, step_scores = step_sample(\n", + " key,\n", + " 1000,\n", + " gf,\n", + " observations_low_deviation[0],\n", + " )\n", " pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid)\n", " return weighted_small_pose_plot(\n", - " path_low_deviation[0], robot_inputs[\"start\"], pose_plane, score_plane\n", + " path_low_deviation[0], robot_inputs[\"start\"], score_plane, pose_plane\n", " ) & weighted_small_pose_plot(\n", - " path_low_deviation[0], robot_inputs[\"start\"], bs[0].get_retval()[0], bs[1]\n", + " path_low_deviation[0], robot_inputs[\"start\"], step_scores, step_poses\n", " )\n", "\n", "\n", "key, sub_key = jax.random.split(key)\n", - "initial_pose_chart(sub_key)" - ] - }, 
- { - "cell_type": "code", - "execution_count": null, - "id": "120", - "metadata": {}, - "outputs": [], - "source": [ - "# See if this works for other points in the path\n", - "\n", - "\n", - "def improvements_at_step(key, path, k):\n", - " gf = full_model_kernel(\n", - " motion_settings_low_deviation, path[k - 1], robot_inputs[\"controls\"][k - 1]\n", - " )\n", - " cube_step_size = 6\n", - " bs = boltzmann_sample(k1, 500, gf, observations_low_deviation[k])\n", - " print(\n", - " f'from {path[k-1]}, step {robot_inputs['controls'][k-1]}, truth {path_low_deviation[k]}, ps {bs[0].get_retval()[0]}'\n", - " )\n", - " p1 = weighted_small_pose_plot(\n", - " path[k], path_low_deviation[k], bs[0].get_retval()[0], bs[1]\n", - " )\n", - " pose_grid = grid_of_nearby_poses(\n", - " path[k], cube_step_size, motion_settings_low_deviation\n", - " )\n", - " score_grid = grid_sample(gf, pose_grid, observations_low_deviation[k])\n", - " pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid)\n", - " print(f\"score_plane {score_plane}\")\n", - " p2 = weighted_small_pose_plot(\n", - " path[k], path_low_deviation[k], pose_plane, score_plane\n", - " )\n", - " return p1 & p2\n", - "\n", - "\n", - "key, sub_key = jax.random.split(key)\n", - "improvements_at_step(sub_key, low_deviation_path, 5)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "121", - "metadata": {}, - "outputs": [], - "source": [ - "# Animation of the above\n", - "key, *sub_keys = jax.random.split(key, 5)\n", - "Plot.Frames(\n", - " [\n", - " improvements_at_step(k, low_deviation_path, i + 1)\n", - " for i, k in enumerate(sub_keys)\n", - " ],\n", - " fps=2,\n", - ")" + "first_step_chart(sub_key)" ] }, { "cell_type": "markdown", - "id": "122", + "id": "119", "metadata": { "lines_to_next_cell": 2 }, @@ -2575,65 +2507,51 @@ { "cell_type": "code", "execution_count": null, - "id": "123", + "id": "120", "metadata": {}, "outputs": [], "source": [ - "def improved_path(\n", - " mode: str, key: PRNGKey, motion_settings: dict, observations: FloatArray\n", - "):\n", - " def boltzmann_step(k: PRNGKey, gf, _center_pose, observation):\n", - " k1, k2 = jax.random.split(k, 2)\n", - " trs, ws = boltzmann_sample(k1, 1000, gf, observation)\n", - " return ws, trs.get_retval()[0]\n", + "def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray):\n", + " cube_step_size = 8\n", "\n", " def grid_search_step(k: PRNGKey, gf, center_pose, observation):\n", - " pose_grid = grid_of_nearby_poses(center_pose, 15, motion_settings)\n", + " pose_grid = grid_of_nearby_poses(center_pose, cube_step_size, motion_settings)\n", " nearby_weights = grid_sample(gf, pose_grid, observation)\n", " return nearby_weights, pose_grid\n", "\n", " def improved_step(state, update):\n", " observation, control, key = update\n", " gf = full_model_kernel(motion_settings, state, control)\n", + " # Run a sample and pick an element by weight.\n", " k1, k2, k3 = jax.random.split(key, 3)\n", - " # First, just run the model.\n", - " tr = gf.simulate(k1, ())\n", - " new_pose = tr.get_retval()[0]\n", - " improver = {\"grid\": grid_search_step, \"boltzmann\": boltzmann_step}[mode]\n", - " # Run the improver, and add the candidate point to the list of weights and\n", - " # return values, to create the possibility of accepting the initial proposal\n", - " # as well as any of the improvement candidates, as Bayesian inference requires\n", - " weights, poses = improver(k2, gf, new_pose, observation)\n", - " # weights = jnp.append(weights, 
tr.get_score())\n", - " # poses = Pose(jnp.vstack(poses.p, new_pose.p), jnp.append(poses.hd, new_pose.hd))\n", - " chosen_pose = select_by_weight(k3, weights, poses)\n", - " return chosen_pose, chosen_pose\n", - "\n", - " # We have one fewer control than step, since no step got us to the initial position.\n", - " # Our scan step starts at the initial step and applies a control input each time.\n", - " # To make things balance, we need to add a zero step to the end of the control input\n", - " # array, so that when we arrive at the final step, no more control input is given.\n", - " controls = robot_inputs[\"controls\"]\n", - " n_steps = len(controls)\n", - " sub_keys = jax.random.split(key, n_steps + 1)\n", + " poses, scores = step_sample(k1, 1000, gf, observation)\n", + " new_pose, new_weight = select_by_weight(k2, scores, poses)\n", + " weights2, poses2 = grid_search_step(k2, gf, new_pose, observation)\n", + " # Note that `new_pose` will be among the poses considered by grid_search_step,\n", + " # so the possibility exists to remain stationary, as Bayesian inference requires\n", + " chosen_pose, _ = select_by_weight(k3, weights2, poses2)\n", + " flat_poses, flat_scores = flatten_pose_cube(poses2, cube_step_size, weights2)\n", + " return chosen_pose, (new_pose, chosen_pose, flat_scores, flat_poses, new_weight)\n", + "\n", + " sub_keys = jax.random.split(key, T + 1)\n", " return jax.lax.scan(\n", " improved_step,\n", " robot_inputs[\"start\"],\n", " (\n", " observations, # observation at time t\n", - " controls, # guides step from t to t+1\n", + " robot_inputs[\"controls\"], # guides step from t to t+1\n", " sub_keys[1:],\n", " ),\n", " )\n", "\n", "\n", - "jit_improved_path = jax.jit(improved_path, static_argnums=0)" + "jit_improved_path = jax.jit(improved_path)" ] }, { "cell_type": "code", "execution_count": null, - "id": "124", + "id": "121", "metadata": { "lines_to_next_cell": 0 }, @@ -2641,15 +2559,21 @@ "source": [ "key, sub_key = jax.random.split(key)\n", "_, improved_low = jit_improved_path(\n", - " \"grid\", sub_key, motion_settings_low_deviation, observations_low_deviation\n", + " sub_key, motion_settings_low_deviation, observations_low_deviation\n", + ")\n", + "key, sub_key = jax.random.split(key)\n", + "_, improved_high = jit_improved_path(\n", + " sub_key, motion_settings_high_deviation, observations_high_deviation\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "125", - "metadata": {}, + "id": "122", + "metadata": { + "lines_to_next_cell": 0 + }, "outputs": [], "source": [ "def path_comparison_plot(*plots):\n", @@ -2673,38 +2597,34 @@ { "cell_type": "code", "execution_count": null, - "id": "126", + "id": "123", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "path_comparison_plot(\n", - " improved_low, path_integrated, low_deviation_path, path_low_deviation\n", + " improved_low[0], path_integrated, low_deviation_path, path_low_deviation\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "127", + "id": "124", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ - "key, sub_key = jax.random.split(key)\n", - "_, improved_high = jit_improved_path(\n", - " \"boltzmann\", sub_key, motion_settings_high_deviation, observations_high_deviation\n", - ")\n", "path_comparison_plot(\n", - " improved_high, path_integrated, high_deviation_path, path_high_deviation\n", + " improved_high[0], path_integrated, high_deviation_path, path_high_deviation\n", ")" ] }, { "cell_type": "markdown", - "id": "128", + "id": "125", 
"metadata": { "lines_to_next_cell": 0 }, @@ -2717,7 +2637,7 @@ { "cell_type": "code", "execution_count": null, - "id": "129", + "id": "126", "metadata": { "lines_to_next_cell": 0 }, @@ -2725,7 +2645,7 @@ "source": [ "Plot.Row(\n", " animate_path_and_sensors(\n", - " improved_high,\n", + " improved_high[0],\n", " observations_high_deviation,\n", " motion_settings_high_deviation,\n", " frame_key=\"frame\",\n", @@ -2736,7 +2656,57 @@ " motion_settings_high_deviation,\n", " frame_key=\"frame\",\n", " ),\n", - ") | Plot.Slider(\"frame\", 0, T, fps=2)" + ") | Plot.Slider(\"frame\", 0, T - 1, fps=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "127", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "# Finishing touch: weave together the improved plot and the improvement steps\n", + "# into a slider animation\n", + "# Plot.Frames(\n", + "# [weighted_small_pose_plot(improved_high[0][k], path_high_deviation[k], improved_high[2][k], improved_high[1][k]) for k in range(T)],\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "128", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "def wsp_frame(k):\n", + " return path_comparison_plot(\n", + " improved_high[0][: k + 1],\n", + " path_integrated[: k + 1],\n", + " high_deviation_path[: k + 1],\n", + " path_high_deviation[: k + 1],\n", + " ) & weighted_small_pose_plot(\n", + " improved_high[1][k],\n", + " path_high_deviation[k],\n", + " improved_high[2][k],\n", + " improved_high[3][k],\n", + " zoom=4,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "129", + "metadata": {}, + "outputs": [], + "source": [ + "Plot.Frames([wsp_frame(k) for k in range(1, 6)])" ] }, { From 4d79ab90848b2e4feb977bf85e7d6b949e8d2bc0 Mon Sep 17 00:00:00 2001 From: Colin Smith Date: Mon, 18 Nov 2024 09:31:58 -0800 Subject: [PATCH 45/86] notes from pair with @huebert --- .../probcomp-localization-tutorial.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index ba16efb..89e3b43 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -498,7 +498,6 @@ def step_proposal(motion_settings, start, control): hd = genjax.normal(start.hd + control.dhd, motion_settings["hd_noise"]) @ "hd" return physical_step(start.p, p, hd) - # Set the motion settings default_motion_settings = {"p_noise": 0.5, "hd_noise": 2 * jnp.pi / 36.0} @@ -702,8 +701,6 @@ def confidence_circle(pose: Pose, p_noise: float): step_proposal.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan() ) - - # result[0] ~~ robot_inputs['start'] + control_step[0] (which is zero) + noise # %% def generate_path_trace(key: PRNGKey) -> genjax.Trace: @@ -731,7 +728,6 @@ def generate_path(key: PRNGKey) -> Pose: # %% # Animation showing a single path with confidence circles - # TODO: is there an off-by-one here possibly as a result of the zero initial step? # TODO: how about plot the control vector? 
def plot_path_with_confidence(path: Pose, step: int, p_noise: float): @@ -744,7 +740,7 @@ def plot_path_with_confidence(path: Pose, step: int, p_noise: float): plot += [ confidence_circle( # for a given index, step[index] is current pose, controls[index] is what was applied to prev pose - path[step].apply_control(robot_inputs["controls"][step + 1]), + path[step].apply_control(robot_inputs["controls"][step+1]), p_noise, ), pose_plot(path[step + 1], fill=Plot.constantly("next pose")), From 3b50801fa524e61c3f55856d596144206f88b3ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:32:13 +0000 Subject: [PATCH 46/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../probcomp-localization-tutorial.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 89e3b43..25d3cbc 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -498,6 +498,7 @@ def step_proposal(motion_settings, start, control): hd = genjax.normal(start.hd + control.dhd, motion_settings["hd_noise"]) @ "hd" return physical_step(start.p, p, hd) + # Set the motion settings default_motion_settings = {"p_noise": 0.5, "hd_noise": 2 * jnp.pi / 36.0} @@ -701,6 +702,7 @@ def confidence_circle(pose: Pose, p_noise: float): step_proposal.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan() ) + # result[0] ~~ robot_inputs['start'] + control_step[0] (which is zero) + noise # %% def generate_path_trace(key: PRNGKey) -> genjax.Trace: @@ -728,6 +730,7 @@ def generate_path(key: PRNGKey) -> Pose: # %% # Animation showing a single path with confidence circles + # TODO: is there an off-by-one here possibly as a result of the zero initial step? # TODO: how about plot the control vector? 
def plot_path_with_confidence(path: Pose, step: int, p_noise: float): @@ -740,7 +743,7 @@ def plot_path_with_confidence(path: Pose, step: int, p_noise: float): plot += [ confidence_circle( # for a given index, step[index] is current pose, controls[index] is what was applied to prev pose - path[step].apply_control(robot_inputs["controls"][step+1]), + path[step].apply_control(robot_inputs["controls"][step + 1]), p_noise, ), pose_plot(path[step + 1], fill=Plot.constantly("next pose")), From bdfe521fb86370af1acf246ddfec1c519d2ec2a3 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Fri, 8 Nov 2024 16:09:10 +0100 Subject: [PATCH 47/86] use genstudio v2024.11.006 --- .../probcomp-localization-tutorial.py | 16 ++++++++-------- poetry.lock | 17 +++++++++++------ pyproject.toml | 2 +- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 25d3cbc..c8398d1 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -802,7 +802,7 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): + Plot.color_map({"some pose": "green", "with heading modified": "red"}) + Plot.title("Modifying a heading") ) - | html("span.tc", f"score ratio: {rotated_trace_weight_diff}") + | html(["span.tc", f"score ratio: {rotated_trace_weight_diff}"]) ) # %% [markdown] @@ -829,7 +829,7 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): for pose in path_from_trace(trace) ] + Plot.color_map({"some path": "green", "with heading modified": "red"}) -) | html("span.tc", f"score ratio: {rotated_first_step_weight_diff}") +) | html(["span.tc", f"score ratio: {rotated_first_step_weight_diff}"]) # %% [markdown] # ### Ideal sensors @@ -1166,13 +1166,13 @@ def plt(readings): sample, log_weight = model_importance( sub_key, constraints_low_deviation, (motion_settings_low_deviation,) ) -animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") +animate_full_trace(sample) | html(["span.tc", f"log_weight: {log_weight}"]) # %% key, sub_key = jax.random.split(key) sample, log_weight = model_importance( sub_key, constraints_high_deviation, (motion_settings_high_deviation,) ) -animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") +animate_full_trace(sample) | html(["span.tc", f"log_weight: {log_weight}"]) # %% [markdown] # A trace resulting from a call to `importance` is structurally indistinguishable from one drawn from `simulate`. But there is a key situational difference: while `get_score` always returns the frequency with which `simulate` stochastically produces the trace, this value is **no longer equal to** the frequency with which the trace is stochastically produced by `importance`. This is both true in an obvious and less relevant sense, as well as true in a more subtle and extremely germane sense. 
# @@ -1253,9 +1253,9 @@ def constraint_from_path(path): Plot.Row( *[ ( - html("div.f3.b.tc", title) + html(["div.f3.b.tc", title]) | animate_full_trace(trace, frame_key="frame") - | html("span.tc", f"score: {score:,.2f}") + | html(["span.tc", f"score: {score:,.2f}"]) ) for (title, trace, motion_settings, score) in [ [ @@ -1272,7 +1272,7 @@ def constraint_from_path(path): ], ] ] -) | Plot.Slider("frame", 0, T, fps=2) +) | Plot.Slider("frame", 0, T-1, fps=2) # %% [markdown] # ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model... @@ -1753,7 +1753,7 @@ def path_comparison_plot(*plots): motion_settings_high_deviation, frame_key="frame", ), -) | Plot.Slider("frame", 0, T - 1, fps=2) +) | Plot.Slider("frame", 0, T, fps=2) # %% diff --git a/poetry.lock b/poetry.lock index cfcb83f..0f14072 100644 --- a/poetry.lock +++ b/poetry.lock @@ -546,6 +546,13 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -711,13 +718,13 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.10.1" +version = "2024.11.6" description = "" optional = false python-versions = ">=3.10,<3.13" files = [ - {file = "genstudio-2024.10.1-py3-none-any.whl", hash = "sha256:c95cffb1e3d9ca8d9424a535ba227c3e8ecbdc95673e907f9da78e89d6c77b3c"}, - {file = "genstudio-2024.10.1.tar.gz", hash = "sha256:279d461dbec2c6d58f27c99216d9199f40f233a7add506f1c909cf48e9aff8e7"}, + {file = "genstudio-2024.11.6-py3-none-any.whl", hash = "sha256:4b0157660022e69aadc408d6b5207ef37bd08cce7668b3092253071b7d8139a9"}, + {file = 
"genstudio-2024.11.6.tar.gz", hash = "sha256:b21dbdc2a4ba0fd7629fff10eac4f84c50b42f6bb74b599703d8b07335ef6456"}, ] [package.dependencies] @@ -1507,7 +1514,6 @@ files = [ {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbf3c20c6a7db69df58672a0d5815647ecf78c8e62a4d9bd284e8621c1fe5ccb"}, {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:75c38f5647e02d423807d252ce4528bf6a95bd776af999cb1fb48867ed01d1f6"}, {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23458d31fa50ec18e0ec4b0b4343730928296b11111df5f547c75913714116b2"}, - {file = "orjson-3.10.10-cp311-none-win32.whl", hash = "sha256:2787cd9dedc591c989f3facd7e3e86508eafdc9536a26ec277699c0aa63c685b"}, {file = "orjson-3.10.10-cp311-none-win_amd64.whl", hash = "sha256:6514449d2c202a75183f807bc755167713297c69f1db57a89a1ef4a0170ee269"}, {file = "orjson-3.10.10-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8564f48f3620861f5ef1e080ce7cd122ee89d7d6dacf25fcae675ff63b4d6e05"}, {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bf161a32b479034098c5b81f2608f09167ad2fa1c06abd4e527ea6bf4837a9"}, @@ -1544,7 +1550,6 @@ files = [ {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12f2003695b10817f0fa8b8fca982ed7f5761dcb0d93cff4f2f9f6709903fd7"}, {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:672f9874a8a8fb9bb1b771331d31ba27f57702c8106cdbadad8bda5d10bc1019"}, {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1dcbb0ca5fafb2b378b2c74419480ab2486326974826bbf6588f4dc62137570a"}, - {file = "orjson-3.10.10-cp39-none-win32.whl", hash = "sha256:d9bbd3a4b92256875cb058c3381b782649b9a3c68a4aa9a2fff020c2f9cfc1be"}, {file = "orjson-3.10.10-cp39-none-win_amd64.whl", hash = "sha256:766f21487a53aee8524b97ca9582d5c6541b03ab6210fbaf10142ae2f3ced2aa"}, {file = "orjson-3.10.10.tar.gz", hash = "sha256:37949383c4df7b4337ce82ee35b6d7471e55195efa7dcb45ab8226ceadb0fe3b"}, ] @@ -2665,4 +2670,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "42925b13162b17dff6f9ac5a701991d53a8de8a7f653847c1b18ae6d3a480d1b" +content-hash = "479f69b6629a460a005289181202b72c2bbecef3b2dce37a765c9c1f62be05f4" diff --git a/pyproject.toml b/pyproject.toml index d53ba76..da72b31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ package-mode = false python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } -genstudio = {version = "2024.10.1", source = "gcp"} +genstudio = {version = "2024.11.006", source = "gcp"} ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7" From cc4316e5febc7d54d5ccfac79fa88d9f0e00d0ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 15:13:22 +0000 Subject: [PATCH 48/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- genjax-localization-tutorial/probcomp-localization-tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index c8398d1..3c68a55 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ 
b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -1272,7 +1272,7 @@ def constraint_from_path(path): ], ] ] -) | Plot.Slider("frame", 0, T-1, fps=2) +) | Plot.Slider("frame", 0, T - 1, fps=2) # %% [markdown] # ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model... From e7f70dd3df767a4a4d7823e0ca55b08df1f724ac Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Tue, 12 Nov 2024 17:59:19 +0100 Subject: [PATCH 49/86] use genstudio v2024.11.012 --- .../probcomp-localization-tutorial.py | 2 +- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 3c68a55..98396ef 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -1272,7 +1272,7 @@ def constraint_from_path(path): ], ] ] -) | Plot.Slider("frame", 0, T - 1, fps=2) +) | Plot.Slider("frame", 0, T, fps=2) # %% [markdown] # ...more closely resembles the density of these data back-fitted onto any other typical (random) paths of the model... diff --git a/poetry.lock b/poetry.lock index 0f14072..2afa41e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -718,13 +718,13 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.11.6" +version = "2024.11.12" description = "" optional = false python-versions = ">=3.10,<3.13" files = [ - {file = "genstudio-2024.11.6-py3-none-any.whl", hash = "sha256:4b0157660022e69aadc408d6b5207ef37bd08cce7668b3092253071b7d8139a9"}, - {file = "genstudio-2024.11.6.tar.gz", hash = "sha256:b21dbdc2a4ba0fd7629fff10eac4f84c50b42f6bb74b599703d8b07335ef6456"}, + {file = "genstudio-2024.11.12-py3-none-any.whl", hash = "sha256:068ddddc29615b0919eeec5a6a0996875ed6e07ef3e653fecefa57ce34674874"}, + {file = "genstudio-2024.11.12.tar.gz", hash = "sha256:d750fa1b9f4c6ca42d12187472fc6eabffef4dd14f4f3215c4d499b65087a3f0"}, ] [package.dependencies] @@ -2670,4 +2670,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "479f69b6629a460a005289181202b72c2bbecef3b2dce37a765c9c1f62be05f4" +content-hash = "a2b82d8c7045142b1654fbd425ea02e2a46fe75c32865143eeb26a1c889cfee4" diff --git a/pyproject.toml b/pyproject.toml index da72b31..abf81d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ package-mode = false python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } -genstudio = {version = "2024.11.006", source = "gcp"} +genstudio = {version = "2024.11.012", source = "gcp"} ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7" From 48ac7a219bc8221716088569f2dfd2d5ab9043b0 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Tue, 19 Nov 2024 12:00:49 +0100 Subject: [PATCH 50/86] fix heading modified --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 2afa41e..69e4cde 100644 --- a/poetry.lock +++ b/poetry.lock @@ -718,13 +718,13 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.11.12" +version = "2024.11.13" description = "" optional = false python-versions = ">=3.10,<3.13" files = [ - {file = "genstudio-2024.11.12-py3-none-any.whl", hash = "sha256:068ddddc29615b0919eeec5a6a0996875ed6e07ef3e653fecefa57ce34674874"}, - {file = 
"genstudio-2024.11.12.tar.gz", hash = "sha256:d750fa1b9f4c6ca42d12187472fc6eabffef4dd14f4f3215c4d499b65087a3f0"}, + {file = "genstudio-2024.11.13-py3-none-any.whl", hash = "sha256:a69a3c2c2a5120d23bd886f9aaec4043bf3233f876cea3e09b60c49f9292c7ca"}, + {file = "genstudio-2024.11.13.tar.gz", hash = "sha256:61886af685f1c58cc08007fd71fe540e043216c69987f3754ab0cf79f6463d0e"}, ] [package.dependencies] @@ -2670,4 +2670,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "a2b82d8c7045142b1654fbd425ea02e2a46fe75c32865143eeb26a1c889cfee4" +content-hash = "cd6dead5b8982be73a6500b56424dc543d5640c30270489469aeaab04eb8aef4" diff --git a/pyproject.toml b/pyproject.toml index abf81d3..95175fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ package-mode = false python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } -genstudio = {version = "2024.11.012", source = "gcp"} +genstudio = {version = "2024.11.013", source = "gcp"} ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7" From 51699b3dfcf270b5d4b10973d285f6c35ffcc61e Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Thu, 28 Nov 2024 21:02:51 +0100 Subject: [PATCH 51/86] initial pose angle slider --- .../probcomp-localization-tutorial.py | 138 +++++++++++++----- poetry.lock | 8 +- pyproject.toml | 2 +- 3 files changed, 106 insertions(+), 42 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 98396ef..cac1032 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -13,8 +13,9 @@ # display_name: Python 3 # language: python # name: python3 -# --- # pyright: reportUnusedExpression=false +# pyright: reportUnusedCallResult=false +# --- # %% # import sys @@ -35,6 +36,7 @@ import json import genstudio.plot as Plot +from genstudio.widget import Widget import itertools import jax import jax.numpy as jnp @@ -572,7 +574,7 @@ def confidence_circle(pose: Pose, p_noise: float): # %% # `simulate` takes the GF plus a tuple of args to pass to it. key, sub_key = jax.random.split(key) -trace = step_proposal.simulate( +trace: genjax.Trace[Pose] = step_proposal.simulate( sub_key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) @@ -779,31 +781,60 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): # # One could, for instance, consider just the placement of the first step, and replace its stochastic choice of heading with an updated value. The original trace was typical under the pose prior model, whereas the modified one may be rather less likely. 
This plot is annotated with log of how much unlikelier, the score ratio: # %% - key, sub_key = jax.random.split(key) trace = step_proposal.simulate( sub_key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) -key, sub_key = jax.random.split(key) -rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( - sub_key, C["hd"].set(jnp.pi / 2.0) -) -# TODO(huebert): try using a slider to choose the heading we set (initial value is 0.0) - -( - Plot.new( +def rotate_trace(key: PRNGKey, trace: genjax.Trace[Pose], angle) -> tuple[genjax.Trace[Pose], genjax.Weight]: + """Returns a modified trace with the heading set to the given angle (in radians), along with the weight difference.""" + key, sub_key = jax.random.split(key) + rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( + sub_key, C["hd"].set(angle) + ) + return rotated_trace, rotated_trace_weight_diff + +rotated_trace, rotated_trace_weight_diff = rotate_trace(key, trace, jnp.pi / 2.0) + +def set_angle(widget: Widget, e): + angle = float(e["value"]) + rotated_trace, rotated_trace_weight_diff = rotate_trace(key, trace, angle) + widget.state.update({ + "rotated_poses": pose_plot(rotated_trace.get_retval(), + fill=Plot.constantly("with heading modified")), + "angle": angle, + "rotated_trace_weight_diff": rotated_trace_weight_diff}) + +( + Plot.initialState({ + "poses": pose_plot(trace.get_retval(), fill=Plot.constantly("some pose")), + "rotated_poses": pose_plot(rotated_trace.get_retval(), fill=Plot.constantly("with heading modified")), + "rotated_trace_weight_diff": rotated_trace_weight_diff, + "angle": jnp.pi / 2.0 + + }) + | Plot.new( world_plot - + pose_plot(trace.get_retval(), fill=Plot.constantly("some pose")) - + pose_plot( - rotated_trace.get_retval(), fill=Plot.constantly("with heading modified") - ) + + Plot.js("$state.poses") + + Plot.js("$state.rotated_poses") + Plot.color_map({"some pose": "green", "with heading modified": "red"}) + Plot.title("Modifying a heading") ) - | html(["span.tc", f"score ratio: {rotated_trace_weight_diff}"]) -) + | html(["span.tc", Plot.js("`score ratio: ${$state.rotated_trace_weight_diff.toFixed(2)}`")]) + | ( + Plot.js("`angle: ${$state.angle.toFixed(2)}`") + & ["input", {"type": "range", + "name": "angle", + "defaultValue": Plot.js("$state.angle"), + "min": -jnp.pi / 2, + "max": jnp.pi / 2, + "step": 0.1, + "onChange": set_angle + }] + & {"widths": ["80px", "200px"]} + ) +).widget() # %% [markdown] # It is worth carefully thinking through a trickier instance of this. Suppose instead, within the full path, we replaced the first step's stochastic choice of heading with some specific value. 
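A quick aside on what the reported "score ratio" means, since the next experiment reuses it: when a single existing choice such as "hd" is overwritten, the weight returned by `trace.update` is (up to how any freshly sampled choices are handled) the log-density ratio of the new value versus the originally sampled one under the step model. The sketch below is illustrative only — the noise level and heading values are made-up assumptions, and it deliberately stays outside the GenJAX API:

    import jax.numpy as jnp
    from jax.scipy.stats import norm

    hd_noise = 0.1               # assumed heading noise of the step model (not the tutorial's setting)
    hd_prior_mean = 0.0          # heading the control would yield with zero noise
    hd_sampled = 0.07            # a typical draw under that Gaussian
    hd_forced = jnp.pi / 2.0     # the value the update constrains "hd" to

    score_ratio = (norm.logpdf(hd_forced, hd_prior_mean, hd_noise)
                   - norm.logpdf(hd_sampled, hd_prior_mean, hd_noise))
    # Strongly negative: forcing the heading to pi/2 makes the trace far less probable
    # than the typical draw, which is exactly what the plot annotation reports.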
@@ -811,25 +842,61 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): key, sub_key = jax.random.split(key) trace = generate_path_trace(sub_key) -key, sub_key = jax.random.split(key) -rotated_first_step, rotated_first_step_weight_diff, _, _ = trace.update( - sub_key, C[0, "hd"].set(jnp.pi / 2.0) -) +def rotate_first_step(key: PRNGKey, trace: genjax.Trace[Pose], angle) -> tuple[genjax.Trace[Pose], genjax.Weight]: + """Returns a modified trace with the first step's heading set to the given angle (in radians), along with the weight difference.""" + key, sub_key = jax.random.split(key) + rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( + sub_key, C[0, "hd"].set(angle) + ) + return rotated_trace, rotated_trace_weight_diff + +def set_first_step_angle(widget: Widget, e): + angle = float(e["value"]) + rotated_trace, rotated_trace_weight_diff = rotate_first_step(key, trace, angle) + widget.state.update({ + "rotated_path": [ + pose_plot(pose, fill=Plot.constantly("with heading modified")) + for pose in path_from_trace(rotated_trace) + ], + "angle": angle, + "rotated_trace_weight_diff": rotated_trace_weight_diff + }) -# %% ( - world_plot - + [ - pose_plot(pose, fill=Plot.constantly("with heading modified")) - for pose in path_from_trace(rotated_first_step) - ] - + [ - pose_plot(pose, fill=Plot.constantly("some path")) - for pose in path_from_trace(trace) - ] - + Plot.color_map({"some path": "green", "with heading modified": "red"}) -) | html(["span.tc", f"score ratio: {rotated_first_step_weight_diff}"]) + Plot.initialState({ + "original_path": [ + pose_plot(pose, fill=Plot.constantly("some path")) + for pose in path_from_trace(trace) + ], + "rotated_path": [ + pose_plot(pose, fill=Plot.constantly("with heading modified")) + for pose in path_from_trace(rotate_first_step(key, trace, jnp.pi / 2.0)[0]) + ], + "rotated_trace_weight_diff": rotate_first_step(key, trace, jnp.pi / 2.0)[1], + "angle": jnp.pi / 2.0 + }) + | Plot.new( + world_plot + + Plot.js("$state.rotated_path") + + Plot.js("$state.original_path") + + Plot.color_map({"some path": "green", "with heading modified": "red"}) + + Plot.title("Modifying first step heading") + ) + | html(["span.tc", Plot.js("`score ratio: ${$state.rotated_trace_weight_diff.toFixed(2)}`")]) + | ( + Plot.js("`angle: ${$state.angle.toFixed(2)}`") + & ["input", {"type": "range", + "name": "angle", + "defaultValue": Plot.js("$state.angle"), + "min": -jnp.pi / 2, + "max": jnp.pi / 2, + "step": 0.1, + "onChange": set_first_step_angle + }] + & {"widths": ["80px", "200px"]} + ) +).widget() # %% [markdown] # ### Ideal sensors @@ -1005,8 +1072,7 @@ def full_model(motion_settings): return ( full_model_kernel.partial_apply(motion_settings).scan()( robot_inputs["start"], robot_inputs["controls"] - ) - @ "steps" + ) @ "steps" ) @@ -1413,7 +1479,6 @@ def resample( ) winners = jax.vmap(genjax.categorical.sampler)( jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) - ) # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. 
winners += jnp.arange(0, N * K, N) @@ -1605,7 +1670,6 @@ def flatten_pose_cube(pose_grid, cube_step_size, scores): n_indices = 2 * cube_step_size + 1 best_heading_indices = jnp.argmax( scores.reshape(n_indices * n_indices, n_indices), axis=1 - ) # those were block relative; linearize them by adding back block indices bs = best_heading_indices + jnp.arange(0, n_indices**3, n_indices) return Pose(pose_grid.p[bs], pose_grid.hd[bs]), scores[bs] diff --git a/poetry.lock b/poetry.lock index 69e4cde..4ee1067 100644 --- a/poetry.lock +++ b/poetry.lock @@ -718,13 +718,13 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.11.13" +version = "2024.11.21" description = "" optional = false python-versions = ">=3.10,<3.13" files = [ - {file = "genstudio-2024.11.13-py3-none-any.whl", hash = "sha256:a69a3c2c2a5120d23bd886f9aaec4043bf3233f876cea3e09b60c49f9292c7ca"}, - {file = "genstudio-2024.11.13.tar.gz", hash = "sha256:61886af685f1c58cc08007fd71fe540e043216c69987f3754ab0cf79f6463d0e"}, + {file = "genstudio-2024.11.21-py3-none-any.whl", hash = "sha256:ea605c426b4ea05eede61d23c353a10e97b2fe31c75891a727602da60ec2b058"}, + {file = "genstudio-2024.11.21.tar.gz", hash = "sha256:db49a1803dbd7b83b664847f2e5559c94aded11b9743cace405c4e2437f79e84"}, ] [package.dependencies] @@ -2670,4 +2670,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "cd6dead5b8982be73a6500b56424dc543d5640c30270489469aeaab04eb8aef4" +content-hash = "f42ce43d1cf6c2ef821809e15ed6a88b5a6e0e7ce88acdace00c68cd53bebb48" diff --git a/pyproject.toml b/pyproject.toml index 95175fd..f7eb008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ package-mode = false python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } -genstudio = {version = "2024.11.013", source = "gcp"} +genstudio = {version = "2024.11.021", source = "gcp"} ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7" From 8a0004b4e8d0d801f503b2f276b8207caf85fc95 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Fri, 29 Nov 2024 20:32:31 +0100 Subject: [PATCH 52/86] wip: drawing canvas works --- .../probcomp-localization-tutorial.py | 8 +- genjax-localization-tutorial/robot-2.py | 198 ++++++++++++++++++ poetry.lock | 25 +-- pyproject.toml | 3 +- 4 files changed, 217 insertions(+), 17 deletions(-) create mode 100644 genjax-localization-tutorial/robot-2.py diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index cac1032..ff8dba2 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -138,7 +138,11 @@ def create_segments(points): return jnp.stack([points, jnp.roll(points, shift=-1, axis=0)], axis=1) -def make_world(wall_verts, clutters_vec, start, controls): +def make_world(wall_verts, clutters_vec, start, controls) -> tuple[ + dict[str, FloatArray | tuple[float, float, float, float] | float | Pose], + dict[str, Control | Pose], + int +]: """ Constructs the world by creating segments for walls and clutters, calculates the bounding box, and prepares the simulation parameters. @@ -1479,7 +1483,7 @@ def resample( ) winners = jax.vmap(genjax.categorical.sampler)( jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) - # indices returned are relative to the start of the K-segment from which they were drawn. 
+ ) # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. winners += jnp.arange(0, N * K, N) selected = jax.tree.map(lambda x: x[winners], samples) diff --git a/genjax-localization-tutorial/robot-2.py b/genjax-localization-tutorial/robot-2.py new file mode 100644 index 0000000..2b86d3c --- /dev/null +++ b/genjax-localization-tutorial/robot-2.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: Python 3 +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Robot Localization: Why is it Hard? +# +# Imagine you have a robot in a building and want it to know where it is. You give it: +# 1. A map of walls and obstacles +# 2. Wheels that can move it around +# 3. Distance sensors that measure how far away walls are +# +# Seems simple, right? Just track the wheel movements and use sensors to confirm position. +# +# ## The Problem +# +# In the real world: +# - Wheels slip and drift +# - Sensors are noisy +# - Small errors compound over time +# +# Try it yourself: +# 1. Add some walls by clicking +# 2. Move the robot around +# 3. Adjust the noise levels to see how they affect: +# - Sensor readings (red lines) +# - Motion uncertainty (blue cloud) +# +# Notice how quickly uncertainty grows when noise is present! + +# %% +# pyright: reportUnusedExpression=false +# pyright: reportUnknownMemberType=false + +import genstudio.plot as Plot +from genstudio.plot import js +import numpy as np +import jax.numpy as jnp +from typing import TypedDict, List, Tuple, Any + +_gensym_counter = 0 + +WALL_WIDTH=6 +PATH_WIDTH=6 + +def gensym(prefix: str = "g") -> str: + """Generate a unique symbol with an optional prefix, similar to Clojure's gensym. + + Args: + prefix: Optional string prefix for the generated symbol. Defaults to "g". + + Returns: + A unique string combining the prefix and a counter. 
+ """ + global _gensym_counter + _gensym_counter += 1 + return f"{prefix}{_gensym_counter}" + +def drawing_system(on_complete): + key = gensym("current_line") + line = Plot.line( + js(f"$state.{key}"), + stroke="#ccc", + strokeWidth=4, + strokeDasharray="4") + + events = Plot.events({ + "_initialState": Plot.initialState({key: []}), + "onDrawStart": js(f"""(e) => {{ + $state.{key} = [[e.x, e.y, e.startTime]]; + }}"""), + "onDraw": js(f"""(e) => {{ + if ($state.{key}.length > 0) {{ + const last = $state.{key}[$state.{key}.length - 1]; + const dx = e.x - last[0]; + const dy = e.y - last[1]; + // Only add point if moved more than threshold distance + if (Math.sqrt(dx*dx + dy*dy) > 0.2) {{ + $state.update(['{key}', 'append', [e.x, e.y, e.startTime]]); + }} + }} + }}"""), + "onDrawEnd": js(f"""(e) => {{ + if ($state.{key}.length > 1) {{ + // Simplify line by keeping only every 3rd point + //const simplified = $state.{key}.filter((_, i) => i % 3 === 0); + %1($state.{key}) + }} + $state.{key} = []; + }}""", on_complete) + }) + return line + events + + + +sliders = ( + Plot.Slider( + "sensor_noise", + range=[0, 1], + step=0.05, + label="Sensor Noise" + ) + & Plot.Slider( + "motion_noise", + range=[0, 1], + step=0.05, + label="Motion Noise" + ) +) + +initial_state = Plot.initialState({ + "walls": [], + "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, + "sensor_noise": 0.1, + "motion_noise": 0.1, + "show_sensors": True, + "selected_tool": "walls", + "robot_path": [] + }) + +canvas = ( + # Draw completed walls + Plot.line( + js("$state.walls"), + stroke=Plot.constantly("Walls"), + strokeWidth=WALL_WIDTH, + z="2", + render=Plot.renderChildEvents({"onClick": js("""(e) => { + const z = $state.walls[e.index][2] + $state.walls = $state.walls.filter(([x, y, z]) => z === e.index) + }""")}) + ) + # Draw current line being drawn + + drawing_system(Plot.js("""(line) => { + if ($state.selected_tool === 'walls') { + $state.update(['walls', 'concat', line]); + } else if ($state.selected_tool === 'path') { + $state.update(['robot_path', 'reset', line]); + } + }""")) + # Draw robot path + + Plot.line( + js("$state.robot_path"), + stroke=Plot.constantly("Robot Path"), + strokeWidth=PATH_WIDTH + ) + # Draw robot + + Plot.dot( + js("[[$state.robot_pose.x, $state.robot_pose.y]]"), + r=10, + fill=Plot.constantly("Robot"), + title="Robot" + ) + + Plot.domain([0, 10], [0, 10]) + + Plot.grid() + + Plot.aspectRatio(1) + + Plot.colorMap({ + "Walls": "#666", + "Drawing": "#999", + "Robot Path": "green", + "Robot": "blue" + }) + + Plot.colorLegend() + ) + +toolbar = Plot.html("Select tool:") | ["div", {"class": "flex gap-2 h-10"}, + ["button", { + "class": js("$state.selected_tool === 'walls' ? 'px-3 py-1 rounded bg-gray-400 hover:bg-gray-500 active:bg-gray-600' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400'"), + "onClick": js("() => $state.selected_tool = 'walls'") + }, "Draw Walls"], + ["button", { + "class": js("$state.selected_tool === 'path' ? 'px-3 py-1 rounded bg-gray-400 hover:bg-gray-500 active:bg-gray-600' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400'"), + "onClick": js("() => $state.selected_tool = 'path'") + }, "Draw Robot Path"] + ] + +instructions = Plot.md(""" +1. Draw walls +2. Draw a robot path +3. 
Adjust noise levels to see how they affect: + - Sensor readings + - Motion uncertainty + """) + +canvas & (toolbar | instructions) | initial_state \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 4ee1067..c106a1c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -718,26 +718,23 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.11.21" +version = "2024.11.021" description = "" optional = false python-versions = ">=3.10,<3.13" -files = [ - {file = "genstudio-2024.11.21-py3-none-any.whl", hash = "sha256:ea605c426b4ea05eede61d23c353a10e97b2fe31c75891a727602da60ec2b058"}, - {file = "genstudio-2024.11.21.tar.gz", hash = "sha256:db49a1803dbd7b83b664847f2e5559c94aded11b9743cace405c4e2437f79e84"}, -] +files = [] +develop = true [package.dependencies] -anywidget = ">=0.9.10,<0.10.0" -html2image = ">=2.0.4.3,<3.0.0.0" -orjson = ">=3.10.6,<4.0.0" -pillow = ">=10.4.0,<11.0.0" -traitlets = ">=5.14.3,<6.0.0" +anywidget = "^0.9.10" +html2image = "^2.0.4.3" +orjson = "^3.10.6" +pillow = "^10.4.0" +traitlets = "^5.14.3" [package.source] -type = "legacy" -url = "https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple" -reference = "gcp" +type = "directory" +url = "../genstudio" [[package]] name = "html2image" @@ -2670,4 +2667,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "f42ce43d1cf6c2ef821809e15ed6a88b5a6e0e7ce88acdace00c68cd53bebb48" +content-hash = "d19e145cafa2c3ee01389b29b5d62dd1b24815544100caa27e6ffe5404af31cb" diff --git a/pyproject.toml b/pyproject.toml index f7eb008..9fb478b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ package-mode = false python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } -genstudio = {version = "2024.11.021", source = "gcp"} +# genstudio = {version = "2024.11.021", source = "gcp"} +genstudio = {path = "../genstudio", develop = true} ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7" From 488d60a21768780fcb8a7595efed333d666621e8 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Fri, 29 Nov 2024 20:43:30 +0100 Subject: [PATCH 53/86] new narrative approach fix wall removal --- genjax-localization-tutorial/robot-2.py | 62 +++++++++++++------------ 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/genjax-localization-tutorial/robot-2.py b/genjax-localization-tutorial/robot-2.py index 2b86d3c..45b43ba 100644 --- a/genjax-localization-tutorial/robot-2.py +++ b/genjax-localization-tutorial/robot-2.py @@ -16,30 +16,36 @@ # --- # %% [markdown] -# # Robot Localization: Why is it Hard? +# # Robot Localization: A Robot's Perspective # -# Imagine you have a robot in a building and want it to know where it is. You give it: -# 1. A map of walls and obstacles -# 2. Wheels that can move it around -# 3. Distance sensors that measure how far away walls are +# Imagine you're a robot. You have: +# 1. A map of walls in your environment +# 2. A plan of where to go ("move forward 1m, turn right 30°") +# 3. Distance sensors that measure how far walls are # -# Seems simple, right? Just track the wheel movements and use sensors to confirm position. +# **Your Challenge**: Figure out where you actually are! # -# ## The Problem +# ## Why is this Hard? 
# -# In the real world: +# You can't just follow your plan perfectly because: # - Wheels slip and drift -# - Sensors are noisy -# - Small errors compound over time +# - Sensors give noisy readings +# - Small errors add up over time # -# Try it yourself: -# 1. Add some walls by clicking -# 2. Move the robot around -# 3. Adjust the noise levels to see how they affect: -# - Sensor readings (red lines) -# - Motion uncertainty (blue cloud) +# ## Try It Yourself # -# Notice how quickly uncertainty grows when noise is present! +# 1. First, create the environment: +# - Draw some walls by clicking +# +# 2. Then, plan the robot's path: +# - Draw where you want the robot to go +# - This becomes a series of movement commands +# +# 3. Watch what happens: +# - Blue line: What the robot THINKS it's doing (following commands perfectly) +# - Red rays: What the robot actually SEES (sensor readings) +# - Blue cloud: Where the robot MIGHT be (uncertainty) +# - Green line: Where the robot figures it ACTUALLY is # %% # pyright: reportUnusedExpression=false @@ -57,14 +63,7 @@ PATH_WIDTH=6 def gensym(prefix: str = "g") -> str: - """Generate a unique symbol with an optional prefix, similar to Clojure's gensym. - - Args: - prefix: Optional string prefix for the generated symbol. Defaults to "g". - - Returns: - A unique string combining the prefix and a counter. - """ + """Generate a unique symbol with an optional prefix, similar to Clojure's gensym.""" global _gensym_counter _gensym_counter += 1 return f"{prefix}{_gensym_counter}" @@ -96,6 +95,7 @@ def drawing_system(on_complete): "onDrawEnd": js(f"""(e) => {{ if ($state.{key}.length > 1) {{ // Simplify line by keeping only every 3rd point + // keep this, we may re-enable later //const simplified = $state.{key}.filter((_, i) => i % 3 === 0); %1($state.{key}) }} @@ -104,8 +104,6 @@ def drawing_system(on_complete): }) return line + events - - sliders = ( Plot.Slider( "sensor_noise", @@ -128,7 +126,10 @@ def drawing_system(on_complete): "motion_noise": 0.1, "show_sensors": True, "selected_tool": "walls", - "robot_path": [] + "robot_path": [], # The planned path + "estimated_pose": None, # Robot's best guess of current position + "sensor_readings": [], # Current sensor readings + "show_uncertainty": True # Whether to show position uncertainty cloud }) canvas = ( @@ -139,8 +140,9 @@ def drawing_system(on_complete): strokeWidth=WALL_WIDTH, z="2", render=Plot.renderChildEvents({"onClick": js("""(e) => { - const z = $state.walls[e.index][2] - $state.walls = $state.walls.filter(([x, y, z]) => z === e.index) + const zs = new Set($state.walls.map(w => w[2])); + const targetZ = [...zs][e.index]; + $state.walls = $state.walls.filter(([x, y, z]) => z !== targetZ) }""")}) ) # Draw current line being drawn From 15725fbbacf19ac4b0db7f1b7a6835a161decf35 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Fri, 29 Nov 2024 21:06:53 +0100 Subject: [PATCH 54/86] flesh out reality.py --- genjax-localization-tutorial/__init__.py | 0 pyproject.toml | 6 +- robot_2/__init__.py | 0 robot_2/reality.py | 119 ++++++++++++++++++ .../robot-2.py => robot_2/where_am_i.py | 7 +- 5 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 genjax-localization-tutorial/__init__.py create mode 100644 robot_2/__init__.py create mode 100644 robot_2/reality.py rename genjax-localization-tutorial/robot-2.py => robot_2/where_am_i.py (94%) diff --git a/genjax-localization-tutorial/__init__.py b/genjax-localization-tutorial/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/pyproject.toml b/pyproject.toml index 9fb478b..ba8a637 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,13 @@ [tool.poetry] -name = "genjax-localization-tutorial" +name = "genjax_localization_tutorial" version = "0.1.0" description = "" authors = ["Matthew Huebert "] readme = "README.md" -package-mode = false +packages = [ + { include = "robot_2" } +] [tool.poetry.dependencies] python = ">=3.11,<3.13" diff --git a/robot_2/__init__.py b/robot_2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/robot_2/reality.py b/robot_2/reality.py new file mode 100644 index 0000000..99cfcc4 --- /dev/null +++ b/robot_2/reality.py @@ -0,0 +1,119 @@ +# pyright: reportUnknownMemberType=false +""" +Reality Simulation for Robot Localization Tutorial + +This module simulates the "true" state of the world that the robot can only +interact with through noisy sensors and imperfect motion. This separation helps +reinforce that localization is about the robot figuring out where it is using only: + +1. What it THINKS it did (control commands) +2. What it can SENSE (noisy sensor readings) +3. What it KNOWS about the world (wall map) + +The tutorial code should never peek at _true_pose - it must work only with +the information available through execute_control() and get_sensor_readings(). +""" + +import jax.numpy as jnp +import jax.lax as lax +from dataclasses import dataclass +from typing import List, Tuple, final, Optional +import jax.random +from jax.random import PRNGKey + +@dataclass +class Pose: + """Represents a robot's position (x,y) and heading (radians)""" + p: jnp.ndarray # position [x, y] + hd: float # heading in radians + + def step_along(self, s: float) -> "Pose": + """Move forward by distance s""" + dp = jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) + return Pose(self.p + s * dp, self.hd) + + def rotate(self, angle: float) -> "Pose": + """Rotate by angle (in radians)""" + return Pose(self.p, self.hd + angle) + +@final # Indicates class won't be subclassed +class Reality: + """ + Simulates the true state of the world, which the robot can only access through + noisy sensors and imperfect motion. 
+ + The robot: + - Can try to move (but motion will have noise) + - Can take sensor readings (but they will have noise) + - Cannot access _true_pose directly + """ + + def __init__(self, walls: List[Tuple[float, float]], motion_noise: float, sensor_noise: float): + self.walls = walls + self.motion_noise = motion_noise + self.sensor_noise = sensor_noise + self._true_pose = Pose(jnp.array([0.5, 0.5]), 0.0) + self._key = PRNGKey(0) + + def execute_control(self, control: Tuple[float, float]): + """ + Execute a control command with noise + control: (forward_dist, rotation_angle) + """ + dist, angle = control + # Add noise to motion + self._key, k1, k2 = jax.random.split(self._key, 3) + noisy_dist = dist + jax.random.normal(k1) * self.motion_noise + noisy_angle = angle + jax.random.normal(k2) * self.motion_noise + + # Update true pose + self._true_pose = self._true_pose.step_along(float(noisy_dist)).rotate(float(noisy_angle)) + + # Return only sensor readings + return self.get_sensor_readings() + + def get_sensor_readings(self) -> List[float]: + """Return noisy distance readings to walls""" + angles = jnp.linspace(0, 2*jnp.pi, 8) # 8 sensors around robot + keys = jax.random.split(self._key, 8) + self._key = keys[0] + + def get_reading(key, angle): + true_dist = self._compute_distance_to_wall(angle) + return true_dist + jax.random.normal(key) * self.sensor_noise + + readings = jax.vmap(get_reading)(keys[1:], angles) + return readings.tolist() + + def _compute_distance_to_wall(self, sensor_angle: float) -> float: + """Compute true distance to nearest wall along sensor ray""" + # Ray starts at robot position + ray_start = self._true_pose.p + # Ray direction based on robot heading plus sensor angle + ray_angle = self._true_pose.hd + sensor_angle + ray_dir = jnp.array([jnp.cos(ray_angle), jnp.sin(ray_angle)]) + + def check_wall_intersection(min_dist, i): + p1 = jnp.array(self.walls[i]) + p2 = jnp.array(self.walls[i+1]) + + # Wall vector + wall = p2 - p1 + # Vector from wall start to ray start + s = ray_start - p1 + + # Compute intersection using parametric equations + # Ray: ray_start + t*ray_dir + # Wall: p1 + u*wall + wall_norm = wall/jnp.linalg.norm(wall) + denom = jnp.cross(ray_dir, wall_norm) + + t = jnp.cross(wall_norm, s) / (denom + 1e-10) + u = jnp.cross(ray_dir, s) / (denom + 1e-10) + + # Check if intersection is valid (in front of ray and within wall segment) + is_valid = (jnp.abs(denom) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) + return jnp.where(is_valid, jnp.minimum(min_dist, t), min_dist) + + min_dist = lax.fori_loop(0, len(self.walls)-1, check_wall_intersection, jnp.inf) + return float(jnp.where(jnp.isinf(min_dist), 100.0, min_dist).item()) \ No newline at end of file diff --git a/genjax-localization-tutorial/robot-2.py b/robot_2/where_am_i.py similarity index 94% rename from genjax-localization-tutorial/robot-2.py rename to robot_2/where_am_i.py index 45b43ba..e204b69 100644 --- a/genjax-localization-tutorial/robot-2.py +++ b/robot_2/where_am_i.py @@ -57,6 +57,8 @@ import jax.numpy as jnp from typing import TypedDict, List, Tuple, Any +import robot_2.reality as reality + _gensym_counter = 0 WALL_WIDTH=6 @@ -180,15 +182,14 @@ def drawing_system(on_complete): toolbar = Plot.html("Select tool:") | ["div", {"class": "flex gap-2 h-10"}, ["button", { - "class": js("$state.selected_tool === 'walls' ? 'px-3 py-1 rounded bg-gray-400 hover:bg-gray-500 active:bg-gray-600' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400'"), + "class": js("$state.selected_tool === 'walls' ? 
'px-3 py-1 rounded bg-gray-300 hover:bg-gray-400 active:bg-gray-500 focus:outline-none' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400 focus:outline-none'"), "onClick": js("() => $state.selected_tool = 'walls'") }, "Draw Walls"], ["button", { - "class": js("$state.selected_tool === 'path' ? 'px-3 py-1 rounded bg-gray-400 hover:bg-gray-500 active:bg-gray-600' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400'"), + "class": js("$state.selected_tool === 'path' ? 'px-3 py-1 rounded bg-gray-300 hover:bg-gray-400 active:bg-gray-500 focus:outline-none' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400 focus:outline-none'"), "onClick": js("() => $state.selected_tool = 'path'") }, "Draw Robot Path"] ] - instructions = Plot.md(""" 1. Draw walls 2. Draw a robot path From 0ad66b705fa65c521ee290c4a7ec35f734bb96cc Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 00:22:12 +0100 Subject: [PATCH 55/86] wall intersection works --- poetry.lock | 48 ++++++++++- pyproject.toml | 3 +- robot_2/reality.py | 72 +++++++++-------- robot_2/test_reality.py | 39 +++++++++ robot_2/where_am_i.py | 172 +++++++++++++++++++++++++++++++++++++--- 5 files changed, 288 insertions(+), 46 deletions(-) create mode 100644 robot_2/test_reality.py diff --git a/poetry.lock b/poetry.lock index c106a1c..35d94d9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -765,6 +765,17 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "ipykernel" version = "6.29.5" @@ -1729,6 +1740,21 @@ docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-a test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] type = ["mypy (>=1.11.2)"] +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "prompt-toolkit" version = "3.0.48" @@ -1876,6 +1902,26 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pytest" +version = "8.3.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"}, + {file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.5,<2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", 
"mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2667,4 +2713,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "d19e145cafa2c3ee01389b29b5d62dd1b24815544100caa27e6ffe5404af31cb" +content-hash = "3d9a2ab9db39381cf0b6982d7546fd98231b1d83e37bac488f6d0dbc879aa7e8" diff --git a/pyproject.toml b/pyproject.toml index ba8a637..08d73af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ packages = [ [tool.poetry.dependencies] python = ">=3.11,<3.13" jupytext = "^1.16.1" -genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } +genjax = {version = "0.7.0.post4.dev0+eacb241e", source = "gcp" } # genstudio = {version = "2024.11.021", source = "gcp"} genstudio = {path = "../genstudio", develop = true} ipykernel = "^6.29.3" @@ -20,6 +20,7 @@ matplotlib = "^3.8.3" anywidget = "^0.9.7" watchfiles = "^0.21.0" markdown = "^3.6" +pytest = "^8.3.3" [[tool.poetry.source]] name = "gcp" diff --git a/robot_2/reality.py b/robot_2/reality.py index 99cfcc4..fa80786 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -48,11 +48,19 @@ class Reality: - Cannot access _true_pose directly """ - def __init__(self, walls: List[Tuple[float, float]], motion_noise: float, sensor_noise: float): + def __init__(self, walls: jnp.ndarray, motion_noise: float, sensor_noise: float, + initial_pose: Optional[Pose] = None): + """ + Args: + walls: JAX array of shape (N, 2, 2) containing wall segments + motion_noise: Standard deviation of noise added to motion + sensor_noise: Standard deviation of noise added to sensor readings + initial_pose: Optional starting pose, defaults to (0.5, 0.5, 0.0) + """ self.walls = walls self.motion_noise = motion_noise self.sensor_noise = sensor_noise - self._true_pose = Pose(jnp.array([0.5, 0.5]), 0.0) + self._true_pose = initial_pose if initial_pose is not None else Pose(jnp.array([0.5, 0.5]), 0.0) self._key = PRNGKey(0) def execute_control(self, control: Tuple[float, float]): @@ -72,9 +80,9 @@ def execute_control(self, control: Tuple[float, float]): # Return only sensor readings return self.get_sensor_readings() - def get_sensor_readings(self) -> List[float]: + def get_sensor_readings(self) -> jnp.ndarray: """Return noisy distance readings to walls""" - angles = jnp.linspace(0, 2*jnp.pi, 8) # 8 sensors around robot + angles = jnp.linspace(0, 2*jnp.pi, 8, endpoint=False) # 8 evenly spaced sensors keys = jax.random.split(self._key, 8) self._key = keys[0] @@ -82,38 +90,36 @@ def get_reading(key, angle): true_dist = self._compute_distance_to_wall(angle) return true_dist + jax.random.normal(key) * self.sensor_noise - readings = jax.vmap(get_reading)(keys[1:], angles) - return readings.tolist() + readings = jax.vmap(get_reading)(keys[1:], angles[:-1]) + return readings - def _compute_distance_to_wall(self, sensor_angle: float) -> float: - """Compute true distance to nearest wall along sensor ray""" - # Ray starts at robot position + def _compute_distance_to_wall(self, sensor_angle: float) -> jax.Array: + """Compute true distance to nearest wall along sensor ray using fast 2D ray-segment intersection""" ray_start = self._true_pose.p - # Ray direction based on robot heading plus sensor angle ray_angle = self._true_pose.hd + sensor_angle ray_dir = jnp.array([jnp.cos(ray_angle), jnp.sin(ray_angle)]) - def check_wall_intersection(min_dist, i): - p1 = jnp.array(self.walls[i]) - p2 = jnp.array(self.walls[i+1]) - - # Wall vector - wall = p2 - p1 - # 
Vector from wall start to ray start - s = ray_start - p1 - - # Compute intersection using parametric equations - # Ray: ray_start + t*ray_dir - # Wall: p1 + u*wall - wall_norm = wall/jnp.linalg.norm(wall) - denom = jnp.cross(ray_dir, wall_norm) - - t = jnp.cross(wall_norm, s) / (denom + 1e-10) - u = jnp.cross(ray_dir, s) / (denom + 1e-10) - - # Check if intersection is valid (in front of ray and within wall segment) - is_valid = (jnp.abs(denom) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) - return jnp.where(is_valid, jnp.minimum(min_dist, t), min_dist) + # Vectorized computation for all walls at once + p1 = self.walls[:, 0] # Shape: (N, 2) + p2 = self.walls[:, 1] # Shape: (N, 2) + + # Wall direction vectors + wall_vec = p2 - p1 # Shape: (N, 2) + + # Vector from wall start to ray start + to_start = ray_start - p1 # Shape: (N, 2) + + # Compute determinant (cross product in 2D) + # This tells us if ray and wall are parallel and their relative orientation + det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) + + # Compute intersection parameters + u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) + t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) + + # Valid intersections: not parallel, in front of ray, within wall segment + is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) - min_dist = lax.fori_loop(0, len(self.walls)-1, check_wall_intersection, jnp.inf) - return float(jnp.where(jnp.isinf(min_dist), 100.0, min_dist).item()) \ No newline at end of file + # Find minimum valid distance + min_dist = jnp.min(jnp.where(is_valid, t, jnp.inf)) + return jnp.where(jnp.isinf(min_dist), 10.0, min_dist) \ No newline at end of file diff --git a/robot_2/test_reality.py b/robot_2/test_reality.py new file mode 100644 index 0000000..f6605ab --- /dev/null +++ b/robot_2/test_reality.py @@ -0,0 +1,39 @@ +import jax.numpy as jnp +import pytest +from robot_2.reality import Reality, Pose + +def test_basic_motion(): + """Test that robot moves as expected without noise""" + # Convert walls to JAX array at creation - now in (N,2,2) shape + walls = jnp.array([ + [[0.0, 0.0], [1.0, 0.0]], # bottom wall + [[1.0, 0.0], [1.0, 1.0]], # right wall + [[1.0, 1.0], [0.0, 1.0]], # top wall + [[0.0, 1.0], [0.0, 0.0]] # left wall + ]) + world = Reality(walls, motion_noise=0.0, sensor_noise=0.0) + + # Move forward 1 unit - ignore readings since we're testing motion + _ = world.execute_control((1.0, 0.0)) + assert world._true_pose.p[0] == pytest.approx(1.5) # Started at 0.5, moved 1.0 + assert world._true_pose.p[1] == pytest.approx(0.5) # Y shouldn't change + + # Rotate 90 degrees (π/2 radians) + _ = world.execute_control((0.0, jnp.pi/2)) + assert world._true_pose.hd == pytest.approx(jnp.pi/2) + +def test_pose_methods(): + """Test Pose step_along and rotate methods""" + p = Pose(jnp.array([1.0, 1.0]), 0.0) + + # Step along heading 0 (right) + p2 = p.step_along(1.0) + assert p2.p[0] == pytest.approx(2.0) + assert p2.p[1] == pytest.approx(1.0) + + # Rotate 90 degrees and step + p3 = p.rotate(jnp.pi/2).step_along(1.0) + assert p3.p[0] == pytest.approx(1.0) + assert p3.p[1] == pytest.approx(2.0) + +pytest.main(["-v"]) # \ No newline at end of file diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index e204b69..346d06c 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -121,8 +121,17 @@ def drawing_system(on_complete): ) ) -initial_state = Plot.initialState({ - "walls": [], +def initial_walls(): + return [ + 
# Frame around domain (timestamp 0) + [0, 0, 0], [10, 0, 0], # Bottom + [10, 0, 0], [10, 10, 0], # Right + [10, 10, 0], [0, 10, 0], # Top + [0, 10, 0], [0, 0, 0], # Left + ] + +initial_state = { + "walls": initial_walls(), "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, "sensor_noise": 0.1, "motion_noise": 0.1, @@ -131,8 +140,11 @@ def drawing_system(on_complete): "robot_path": [], # The planned path "estimated_pose": None, # Robot's best guess of current position "sensor_readings": [], # Current sensor readings - "show_uncertainty": True # Whether to show position uncertainty cloud - }) + "show_uncertainty": True , # Whether to show position uncertainty cloud + "debug_message": "", + "show_debug": False + } + canvas = ( # Draw completed walls @@ -161,12 +173,20 @@ def drawing_system(on_complete): stroke=Plot.constantly("Robot Path"), strokeWidth=PATH_WIDTH ) + # Draw true path when in debug mode + + Plot.line( + js("$state.show_debug ? $state.true_path : []"), + stroke=Plot.constantly("True Path"), + strokeWidth=2 + ) # Draw robot - + Plot.dot( + + Plot.text( js("[[$state.robot_pose.x, $state.robot_pose.y]]"), - r=10, + text=Plot.constantly("🤖"), + fontSize=30, fill=Plot.constantly("Robot"), - title="Robot" + title="Robot", + rotate=js("$state.robot_pose.heading * 180 / Math.PI") # Convert radians to degrees ) + Plot.domain([0, 10], [0, 10]) + Plot.grid() @@ -174,12 +194,130 @@ def drawing_system(on_complete): + Plot.colorMap({ "Walls": "#666", "Drawing": "#999", - "Robot Path": "green", + "Robot Path": "blue", "Robot": "blue" }) + Plot.colorLegend() + # Add sensor rays when show_debug is true + + Plot.line( + js(""" + $state.show_debug && $state.sensor_readings ? + Array.from($state.sensor_readings).flatMap((r, i) => { + const heading = $state.robot_pose.heading || 0; + const angle = heading + (i * Math.PI * 2) / 8; + const x = $state.robot_pose.x; + const y = $state.robot_pose.y; + return [ + [x, y, i], + [x + r * Math.cos(angle), + y + r * Math.sin(angle), i] + ] + }) : [] + """), + z="2", + stroke="red", + strokeWidth=1 + ) + + Plot.clip() + + Plot.colorMap({ + "Walls": "#666", + "Robot": "blue", + "Sensor Rays": "red", + "True Path": "green" + }) ) + +def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: + """Convert wall vertices from UI format to JAX array of wall segments + Returns: array of shape (N, 2, 2) where: + N = number of walls + First 2 = start/end point + Second 2 = x,y coordinates + """ + # Convert to array and reshape to (N,3) where columns are x,y,timestamp + points = jnp.array(walls_list).reshape(-1, 3) + # Get consecutive pairs of points + p1 = points[:-1] # All points except last + p2 = points[1:] # All points except first + # Keep only pairs with matching timestamps + mask = p1[:, 2] == p2[:, 2] + # Stack the x,y coordinates into wall segments + segments = jnp.stack([p1[mask][:, :2], p2[mask][:, :2]], axis=1) + return segments + +def create_reality(walls_list: List[List[float]], motion_noise: float, sensor_noise: float) -> reality.Reality: + """Create Reality instance with proper JAX arrays""" + walls = convert_walls_to_jax(walls_list) + return reality.Reality(walls, motion_noise, sensor_noise) + +def path_to_controls(path_points: List[List[float]]) -> List[Tuple[float, float]]: + """Convert a series of points into (distance, angle) control pairs + Returns: List of (forward_dist, rotation_angle) controls + """ + controls = [] + for i in range(len(path_points) - 1): + p1 = jnp.array(path_points[i][:2]) # current point [x,y] + p2 = 
jnp.array(path_points[i+1][:2]) # next point [x,y] + + # Calculate distance and angle to next point + delta = p2 - p1 + distance = jnp.linalg.norm(delta) + target_angle = jnp.arctan2(delta[1], delta[0]) + + # If not first point, need to rotate from previous heading + if i > 0: + prev_delta = p1 - jnp.array(path_points[i-1][:2]) + prev_angle = jnp.arctan2(prev_delta[1], prev_delta[0]) + rotation = target_angle - prev_angle + else: + # For first point, rotate from initial heading (0) + rotation = target_angle + + controls.append((float(distance), float(rotation))) + + return controls + +def debug_reality(widget, e): + """Quick visual check of Reality class""" + if not widget.state.robot_path: + return # Need a path to get initial pose + + # Get initial pose from start of path + start_point = widget.state.robot_path[0] + initial_pose = reality.Pose(jnp.array([start_point[0], start_point[1]]), 0.0) + + walls = convert_walls_to_jax(widget.state.walls) + world = reality.Reality(walls, + motion_noise=widget.state.motion_noise, + sensor_noise=widget.state.sensor_noise, + initial_pose=initial_pose) + + # Convert path to controls and execute them + controls = path_to_controls(widget.state.robot_path) + readings = [] + true_path = [[float(world._true_pose.p[0]), float(world._true_pose.p[1])]] # Start position + + for control in controls: + reading = world.execute_control(control) + readings.append(reading) + # Record position after each control + true_path.append([float(world._true_pose.p[0]), float(world._true_pose.p[1])]) + + # Update state with final readings, pose, and full path + widget.state.update({ + "robot_pose": { + "x": float(world._true_pose.p[0]), + "y": float(world._true_pose.p[1]), + "heading": float(world._true_pose.hd) + }, + "sensor_readings": readings[-1] if readings else [], + "true_path": true_path, + "show_debug": True, + "debug_message": f"Executed {len(controls)} controls\nFinal readings: {readings[-1] if readings else []}" + }) + +# Add debug button to toolbar toolbar = Plot.html("Select tool:") | ["div", {"class": "flex gap-2 h-10"}, ["button", { "class": js("$state.selected_tool === 'walls' ? 'px-3 py-1 rounded bg-gray-300 hover:bg-gray-400 active:bg-gray-500 focus:outline-none' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400 focus:outline-none'"), @@ -188,7 +326,15 @@ def drawing_system(on_complete): ["button", { "class": js("$state.selected_tool === 'path' ? 'px-3 py-1 rounded bg-gray-300 hover:bg-gray-400 active:bg-gray-500 focus:outline-none' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400 focus:outline-none'"), "onClick": js("() => $state.selected_tool = 'path'") - }, "Draw Robot Path"] + }, "Draw Robot Path"], + ["button", { + "class": "px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400", + "onClick": debug_reality + }, "Debug Reality"], + ["button", { + "class": "px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400", + "onClick": lambda w, e: w.state.update(initial_state | {"walls": initial_walls()}) + }, "Clear"] ] instructions = Plot.md(""" 1. Draw walls @@ -196,6 +342,10 @@ def drawing_system(on_complete): 3. 
Adjust noise levels to see how they affect: - Sensor readings - Motion uncertainty - """) + """) | ["div", js("$state.debug_message")] -canvas & (toolbar | instructions) | initial_state \ No newline at end of file +( + canvas & + (toolbar | instructions | sliders) + | Plot.initialState(initial_state, sync=True) + | Plot.onChange({"robot_path": debug_reality})) From a698bd3072d8b8b9475fbe177f23054b0f621f8a Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 00:41:45 +0100 Subject: [PATCH 56/86] add wall collision --- pyproject.toml | 2 +- robot_2/reality.py | 30 +++++++++++++++++++----------- robot_2/where_am_i.py | 15 ++++++++++++--- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 08d73af..8031658 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ packages = [ python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e", source = "gcp" } -# genstudio = {version = "2024.11.021", source = "gcp"} +# genstudio = {version = "2024.11.022", source = "gcp"} genstudio = {path = "../genstudio", develop = true} ipykernel = "^6.29.3" matplotlib = "^3.8.3" diff --git a/robot_2/reality.py b/robot_2/reality.py index fa80786..014f032 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -63,21 +63,29 @@ def __init__(self, walls: jnp.ndarray, motion_noise: float, sensor_noise: float, self._true_pose = initial_pose if initial_pose is not None else Pose(jnp.array([0.5, 0.5]), 0.0) self._key = PRNGKey(0) - def execute_control(self, control: Tuple[float, float]): - """ - Execute a control command with noise - control: (forward_dist, rotation_angle) - """ + def execute_control(self, control: Tuple[float, float]) -> List[float]: + """Execute a control command with noise, stopping if we hit a wall""" dist, angle = control # Add noise to motion - self._key, k1, k2 = jax.random.split(self._key, 3) - noisy_dist = dist + jax.random.normal(k1) * self.motion_noise - noisy_angle = angle + jax.random.normal(k2) * self.motion_noise + noisy_dist = dist + jax.random.normal(self._key) * self.motion_noise + noisy_angle = angle + jax.random.normal(self._key) * self.motion_noise + + # First rotate (can always rotate) + self._true_pose = self._true_pose.rotate(noisy_angle) + + # Then try to move forward, checking for collisions + ray_dir = jnp.array([jnp.cos(self._true_pose.hd), jnp.sin(self._true_pose.hd)]) + + # Use our existing ray-casting to check distance to nearest wall + min_dist = self._compute_distance_to_wall(0.0) # 0 angle = forward + + # Only move as far as we can before hitting a wall (minus small safety margin) + safe_dist = jnp.minimum(noisy_dist, min_dist - 0.1) + safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards - # Update true pose - self._true_pose = self._true_pose.step_along(float(noisy_dist)).rotate(float(noisy_angle)) + self._true_pose = self._true_pose.step_along(safe_dist) - # Return only sensor readings + # Return sensor readings from new position return self.get_sensor_readings() def get_sensor_readings(self) -> jnp.ndarray: diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 346d06c..a7f0366 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -111,13 +111,17 @@ def drawing_system(on_complete): "sensor_noise", range=[0, 1], step=0.05, - label="Sensor Noise" + init=0.1, + label="Sensor Noise:", + showValue=True ) & Plot.Slider( "motion_noise", range=[0, 1], step=0.05, - label="Motion Noise" + init=0.1, + label="Motion Noise:", + 
showValue=True ) ) @@ -316,6 +320,11 @@ def debug_reality(widget, e): "show_debug": True, "debug_message": f"Executed {len(controls)} controls\nFinal readings: {readings[-1] if readings else []}" }) + +def clear_state(w, _): + w.state.update(initial_state) + + # Add debug button to toolbar toolbar = Plot.html("Select tool:") | ["div", {"class": "flex gap-2 h-10"}, @@ -333,7 +342,7 @@ def debug_reality(widget, e): }, "Debug Reality"], ["button", { "class": "px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400", - "onClick": lambda w, e: w.state.update(initial_state | {"walls": initial_walls()}) + "onClick": clear_state }, "Clear"] ] instructions = Plot.md(""" From f22618d597e9f1955a3b162abb305444623638db Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 01:00:54 +0100 Subject: [PATCH 57/86] noise sliders work --- robot_2/where_am_i.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index a7f0366..18e711f 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -111,7 +111,6 @@ def drawing_system(on_complete): "sensor_noise", range=[0, 1], step=0.05, - init=0.1, label="Sensor Noise:", showValue=True ) @@ -119,7 +118,6 @@ def drawing_system(on_complete): "motion_noise", range=[0, 1], step=0.05, - init=0.1, label="Motion Noise:", showValue=True ) @@ -317,8 +315,7 @@ def debug_reality(widget, e): }, "sensor_readings": readings[-1] if readings else [], "true_path": true_path, - "show_debug": True, - "debug_message": f"Executed {len(controls)} controls\nFinal readings: {readings[-1] if readings else []}" + "show_debug": True }) def clear_state(w, _): @@ -357,4 +354,8 @@ def clear_state(w, _): canvas & (toolbar | instructions | sliders) | Plot.initialState(initial_state, sync=True) - | Plot.onChange({"robot_path": debug_reality})) + | Plot.onChange({ + "robot_path": debug_reality, + "sensor_noise": debug_reality, + "motion_noise": debug_reality + })) From 85368a60f329891400133ef7a4bf4093e902730a Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 01:34:52 +0100 Subject: [PATCH 58/86] show possible paths, without collision implemented --- robot_2/where_am_i.py | 101 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 2 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 18e711f..851a2ea 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -58,6 +58,87 @@ from typing import TypedDict, List, Tuple, Any import robot_2.reality as reality +import jax +import jax.numpy as jnp +import genjax +from genjax import normal, mv_normal_diag +from penzai import pz # Import penzai for pytree dataclasses +from typing import List, Tuple + +@pz.pytree_dataclass +class MotionSettings(genjax.PythonicPytree): + """Settings for motion uncertainty""" + p_noise: float = 0.1 # Position noise + hd_noise: float = 0.1 # Heading noise + +@pz.pytree_dataclass +class Pose(genjax.PythonicPytree): + """Robot pose with position and heading""" + p: jax.Array # [x, y] + hd: float # heading in radians + + def dp(self): + """Get direction vector from heading""" + return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) + + def step_along(self, s: float) -> "Pose": + """Move forward by distance s""" + return Pose(self.p + s * self.dp(), self.hd) + + def rotate(self, angle: float) -> "Pose": + """Rotate by angle (in radians)""" + return Pose(self.p, self.hd + angle) + +@genjax.gen +def step_proposal(motion_settings: MotionSettings, start_pose: 
Pose, control: Tuple[float, float]) -> Pose: + """Generate a noisy next pose given current pose and control""" + dist, angle = control + + # Sample noisy position after moving + intended_p = start_pose.p + dist * start_pose.dp() + p = mv_normal_diag( + intended_p, + motion_settings.p_noise * jnp.ones(2) + ) @ "p" + + # Sample noisy heading + hd = normal( + start_pose.hd + angle, + motion_settings.hd_noise + ) @ "hd" + + return Pose(p, hd) + +@genjax.gen +def generate_path(motion_settings: MotionSettings, start_pose: Pose, controls: List[Tuple[float, float]]): + """Generate a complete path by sampling noisy steps""" + pose = start_pose + path = [pose] + + for control in controls: + pose = step_proposal(motion_settings, pose, control) @ f"step_{len(path)}" + path.append(pose) + + return path + +def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, robot_path: List[List[float]]): + """Generate n possible paths given the planned path""" + controls = path_to_controls(robot_path) + start_point = robot_path[0] + start_pose = Pose(jnp.array(start_point[:2]), 0.0) + settings = MotionSettings() + + # Split key for multiple samples + keys = jax.random.split(key, n_paths) + + # Sample paths using GenJAX + def sample_single_path(k): + trace = generate_path.simulate(k, (settings, start_pose, controls)) + poses = trace.get_retval() + return jnp.stack([pose.p for pose in poses]) + + paths = jax.vmap(sample_single_path)(keys) + return paths _gensym_counter = 0 @@ -138,7 +219,7 @@ def initial_walls(): "sensor_noise": 0.1, "motion_noise": 0.1, "show_sensors": True, - "selected_tool": "walls", + "selected_tool": "path", "robot_path": [], # The planned path "estimated_pose": None, # Robot's best guess of current position "sensor_readings": [], # Current sensor readings @@ -220,6 +301,16 @@ def initial_walls(): stroke="red", strokeWidth=1 ) + + Plot.line( + js(""" + if (!$state.show_debug || !$state.possible_paths) {return [];}; + return $state.possible_paths.flatMap((path, pathIdx) => + path.map(([x, y]) => [x, y, pathIdx]) + ) + """, expression=False), + stroke="blue", + strokeOpacity=0.2 + ) + Plot.clip() + Plot.colorMap({ "Walls": "#666", @@ -287,7 +378,7 @@ def debug_reality(widget, e): # Get initial pose from start of path start_point = widget.state.robot_path[0] - initial_pose = reality.Pose(jnp.array([start_point[0], start_point[1]]), 0.0) + initial_pose = Pose(jnp.array([start_point[0], start_point[1]]), 0.0) walls = convert_walls_to_jax(widget.state.walls) world = reality.Reality(walls, @@ -306,6 +397,11 @@ def debug_reality(widget, e): # Record position after each control true_path.append([float(world._true_pose.p[0]), float(world._true_pose.p[1])]) + # Generate possible paths + key = jax.random.PRNGKey(0) + possible_paths = sample_possible_paths(key, 20, widget.state.robot_path) + + # Update state with final readings, pose, and full path widget.state.update({ "robot_pose": { @@ -313,6 +409,7 @@ def debug_reality(widget, e): "y": float(world._true_pose.p[1]), "heading": float(world._true_pose.hd) }, + "possible_paths": possible_paths, "sensor_readings": readings[-1] if readings else [], "true_path": true_path, "show_debug": True From fc7eb214d5040b0335ac8de100cac7d6b1f3b88c Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 02:02:16 +0100 Subject: [PATCH 59/86] hide true path --- robot_2/where_am_i.py | 192 ++++++++++++++++++++++++------------------ 1 file changed, 108 insertions(+), 84 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 
851a2ea..1583f70 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -90,50 +90,77 @@ def rotate(self, angle: float) -> "Pose": return Pose(self.p, self.hd + angle) @genjax.gen -def step_proposal(motion_settings: MotionSettings, start_pose: Pose, control: Tuple[float, float]) -> Pose: - """Generate a noisy next pose given current pose and control""" +def step_proposal(motion_settings: MotionSettings, start_pose: Pose, control: Tuple[float, float], walls: jnp.ndarray) -> Pose: + """Generate a noisy next pose given current pose and control, respecting walls""" dist, angle = control - # Sample noisy position after moving - intended_p = start_pose.p + dist * start_pose.dp() - p = mv_normal_diag( - intended_p, - motion_settings.p_noise * jnp.ones(2) - ) @ "p" - - # Sample noisy heading + # Sample noisy heading first hd = normal( start_pose.hd + angle, motion_settings.hd_noise ) @ "hd" + # Calculate noisy forward distance + noisy_dist = normal( + dist, + motion_settings.p_noise + ) @ "dist" + + # Simplified collision check using just the wall segments + ray_start = start_pose.p + ray_dir = jnp.array([jnp.cos(hd), jnp.sin(hd)]) + + # Vectorized ray-wall intersection (copied from Reality class) + p1 = walls[:, 0] # Shape: (N, 2) + p2 = walls[:, 1] # Shape: (N, 2) + wall_vec = p2 - p1 + to_start = ray_start - p1 + + det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) + u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) + t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) + + is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) + min_dist = jnp.min(jnp.where(is_valid, t, jnp.inf)) + min_dist = jnp.where(jnp.isinf(min_dist), 10.0, min_dist) + + # Limit distance to avoid wall collision + safe_dist = jnp.minimum(noisy_dist, min_dist - 0.1) + safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards + + # Calculate final position + p = start_pose.p + safe_dist * ray_dir + return Pose(p, hd) @genjax.gen -def generate_path(motion_settings: MotionSettings, start_pose: Pose, controls: List[Tuple[float, float]]): - """Generate a complete path by sampling noisy steps""" +def generate_path(motion_settings: MotionSettings, start_pose: Pose, controls: List[Tuple[float, float]], walls: jnp.ndarray): + """Generate a complete path by sampling noisy steps, respecting walls""" pose = start_pose path = [pose] for control in controls: - pose = step_proposal(motion_settings, pose, control) @ f"step_{len(path)}" + pose = step_proposal(motion_settings, pose, control, walls) @ f"step_{len(path)}" path.append(pose) return path -def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, robot_path: List[List[float]]): - """Generate n possible paths given the planned path""" +def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, robot_path: List[List[float]], walls: List[List[float]]): + """Generate n possible paths given the planned path, respecting walls""" controls = path_to_controls(robot_path) start_point = robot_path[0] start_pose = Pose(jnp.array(start_point[:2]), 0.0) settings = MotionSettings() + # Convert walls to JAX array format + wall_segments = convert_walls_to_jax(walls) + # Split key for multiple samples keys = jax.random.split(key, n_paths) # Sample paths using GenJAX def sample_single_path(k): - trace = generate_path.simulate(k, (settings, start_pose, controls)) + trace = generate_path.simulate(k, (settings, start_pose, controls, wall_segments)) poses = trace.get_retval() 
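        # get_retval() is the list of Pose samples produced by generate_path (the start pose plus one pose per control); stacking their .p fields gives an (n_controls + 1, 2) array of xy positions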
return jnp.stack([pose.p for pose in poses]) @@ -220,14 +247,42 @@ def initial_walls(): "motion_noise": 0.1, "show_sensors": True, "selected_tool": "path", - "robot_path": [], # The planned path - "estimated_pose": None, # Robot's best guess of current position - "sensor_readings": [], # Current sensor readings - "show_uncertainty": True , # Whether to show position uncertainty cloud - "debug_message": "", - "show_debug": False + "robot_path": [], + "estimated_pose": None, + "sensor_readings": [], + "show_uncertainty": True, + "show_true_position": True # New flag for toggling true position visibility } +sensor_rays = Plot.line( + js(""" + Array.from($state.sensor_readings).flatMap((r, i) => { + const heading = $state.robot_pose.heading || 0; + const angle = heading + (i * Math.PI * 2) / 8; + const x = $state.robot_pose.x; + const y = $state.robot_pose.y; + return [ + [x, y, i], + [x + r * Math.cos(angle), + y + r * Math.sin(angle), i] + ] + }) + """), + z="2", + stroke="red", + strokeWidth=1 + ) + +true_path = Plot.line( + js("$state.true_path"), + stroke=Plot.constantly("True Path"), + strokeWidth=2 + ) + +planned_path = Plot.line( + js("$state.robot_path"), + stroke=Plot.constantly("Robot Path"), + strokeWidth=PATH_WIDTH), canvas = ( # Draw completed walls @@ -250,57 +305,32 @@ def initial_walls(): $state.update(['robot_path', 'reset', line]); } }""")) - # Draw robot path - + Plot.line( - js("$state.robot_path"), - stroke=Plot.constantly("Robot Path"), - strokeWidth=PATH_WIDTH - ) - # Draw true path when in debug mode - + Plot.line( - js("$state.show_debug ? $state.true_path : []"), - stroke=Plot.constantly("True Path"), - strokeWidth=2 - ) + + planned_path + # Draw robot - + Plot.text( - js("[[$state.robot_pose.x, $state.robot_pose.y]]"), - text=Plot.constantly("🤖"), - fontSize=30, - fill=Plot.constantly("Robot"), - title="Robot", - rotate=js("$state.robot_pose.heading * 180 / Math.PI") # Convert radians to degrees - ) + + Plot.cond( + js("$state.show_true_position"), + [Plot.text( + js("[[$state.robot_pose.x, $state.robot_pose.y]]"), + text=Plot.constantly("🤖"), + fontSize=30, + textAnchor="middle", + dy="-0.35em", + rotate=js("$state.robot_pose.heading * 180 / Math.PI")), + true_path, + sensor_rays + ] + ) + Plot.domain([0, 10], [0, 10]) + Plot.grid() + Plot.aspectRatio(1) + Plot.colorMap({ "Walls": "#666", - "Drawing": "#999", + "Sensor Rays": "red", + "True Path": "green", "Robot Path": "blue", - "Robot": "blue" }) + Plot.colorLegend() - # Add sensor rays when show_debug is true - + Plot.line( - js(""" - $state.show_debug && $state.sensor_readings ? 
- Array.from($state.sensor_readings).flatMap((r, i) => { - const heading = $state.robot_pose.heading || 0; - const angle = heading + (i * Math.PI * 2) / 8; - const x = $state.robot_pose.x; - const y = $state.robot_pose.y; - return [ - [x, y, i], - [x + r * Math.cos(angle), - y + r * Math.sin(angle), i] - ] - }) : [] - """), - z="2", - stroke="red", - strokeWidth=1 - ) + Plot.line( js(""" if (!$state.show_debug || !$state.possible_paths) {return [];}; @@ -312,12 +342,6 @@ def initial_walls(): strokeOpacity=0.2 ) + Plot.clip() - + Plot.colorMap({ - "Walls": "#666", - "Robot": "blue", - "Sensor Rays": "red", - "True Path": "green" - }) ) @@ -399,7 +423,7 @@ def debug_reality(widget, e): # Generate possible paths key = jax.random.PRNGKey(0) - possible_paths = sample_possible_paths(key, 20, widget.state.robot_path) + possible_paths = sample_possible_paths(key, 20, widget.state.robot_path, widget.state.walls) # Update state with final readings, pose, and full path @@ -430,29 +454,29 @@ def clear_state(w, _): "class": js("$state.selected_tool === 'path' ? 'px-3 py-1 rounded bg-gray-300 hover:bg-gray-400 active:bg-gray-500 focus:outline-none' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400 focus:outline-none'"), "onClick": js("() => $state.selected_tool = 'path'") }, "Draw Robot Path"], - ["button", { - "class": "px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400", - "onClick": debug_reality - }, "Debug Reality"], ["button", { "class": "px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400", "onClick": clear_state }, "Clear"] ] -instructions = Plot.md(""" -1. Draw walls -2. Draw a robot path -3. Adjust noise levels to see how they affect: - - Sensor readings - - Motion uncertainty - """) | ["div", js("$state.debug_message")] + +reality_toggle = Plot.html("") | ["label.flex.items-center.gap-2.p-2", + "Show true position:", + ["input", { + "type": "checkbox", + "checked": js("$state.show_true_position"), + "onChange": js("(e) => $state.show_true_position = e.target.checked") + }]] + +# Modify the onChange handlers at the bottom ( canvas & - (toolbar | instructions | sliders) + (toolbar | sliders | reality_toggle | sensor_rays + {"height": 200}) | Plot.initialState(initial_state, sync=True) | Plot.onChange({ "robot_path": debug_reality, "sensor_noise": debug_reality, - "motion_noise": debug_reality + "motion_noise": debug_reality, + "walls": debug_reality })) From b6ac6d3e3404f6544dab4a4030b49b2f53290740 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 14:03:37 +0100 Subject: [PATCH 60/86] simplify planned path fix possible paths line mark --- robot_2/where_am_i.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 1583f70..acc1bdf 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -197,13 +197,15 @@ def drawing_system(on_complete): const dx = e.x - last[0]; const dy = e.y - last[1]; // Only add point if moved more than threshold distance - if (Math.sqrt(dx*dx + dy*dy) > 0.2) {{ + if (Math.sqrt(dx*dx + dy*dy) >= 1) {{ $state.update(['{key}', 'append', [e.x, e.y, e.startTime]]); }} }} }}"""), "onDrawEnd": js(f"""(e) => {{ if ($state.{key}.length > 1) {{ + const points = [...$state.{key}, [e.x, e.y, e.startTime]] + // Simplify line by keeping only every 3rd point // keep this, we may re-enable later //const simplified = $state.{key}.filter((_, i) => i % 3 === 0); @@ -282,7 +284,9 @@ def initial_walls(): planned_path = Plot.line( 
js("$state.robot_path"), stroke=Plot.constantly("Robot Path"), - strokeWidth=PATH_WIDTH), + strokeWidth=2, + r=3, + marker="circle"), canvas = ( # Draw completed walls @@ -339,7 +343,8 @@ def initial_walls(): ) """, expression=False), stroke="blue", - strokeOpacity=0.2 + strokeOpacity=0.2, + z="2" ) + Plot.clip() ) @@ -365,8 +370,7 @@ def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: def create_reality(walls_list: List[List[float]], motion_noise: float, sensor_noise: float) -> reality.Reality: """Create Reality instance with proper JAX arrays""" - walls = convert_walls_to_jax(walls_list) - return reality.Reality(walls, motion_noise, sensor_noise) + return reality.Reality(convert_walls_to_jax(walls_list), motion_noise, sensor_noise) def path_to_controls(path_points: List[List[float]]) -> List[Tuple[float, float]]: """Convert a series of points into (distance, angle) control pairs From d4929c2f3b0b7e0f726ab767e116defeabab9f31 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 14:13:41 +0100 Subject: [PATCH 61/86] use motion settings when sampling possible paths --- robot_2/where_am_i.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index acc1bdf..2bc7bf4 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -145,12 +145,13 @@ def generate_path(motion_settings: MotionSettings, start_pose: Pose, controls: L return path -def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, robot_path: List[List[float]], walls: List[List[float]]): +def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, robot_path: List[List[float]], walls: List[List[float]], motion_noise: float): """Generate n possible paths given the planned path, respecting walls""" controls = path_to_controls(robot_path) start_point = robot_path[0] - start_pose = Pose(jnp.array(start_point[:2]), 0.0) - settings = MotionSettings() + start_pose = Pose(jnp.array(start_point[:2], dtype=jnp.float32), 0.0) + motion_noise = jnp.float32(motion_noise) + settings = MotionSettings(p_noise=motion_noise, hd_noise=motion_noise) # Convert walls to JAX array format wall_segments = convert_walls_to_jax(walls) @@ -272,7 +273,8 @@ def initial_walls(): """), z="2", stroke="red", - strokeWidth=1 + strokeWidth=1, + marker="circle" ) true_path = Plot.line( @@ -427,7 +429,7 @@ def debug_reality(widget, e): # Generate possible paths key = jax.random.PRNGKey(0) - possible_paths = sample_possible_paths(key, 20, widget.state.robot_path, widget.state.walls) + possible_paths = sample_possible_paths(key, 20, widget.state.robot_path, widget.state.walls, widget.state.motion_noise) # Update state with final readings, pose, and full path From 74cc342f6d47f362b688d4f571e7cdd9e7986302 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sat, 30 Nov 2024 17:33:15 +0100 Subject: [PATCH 62/86] functional approach fix heading rendering --- robot_2/reality.py | 172 ++++++++++++++++++++---------------------- robot_2/where_am_i.py | 70 ++++++++--------- 2 files changed, 116 insertions(+), 126 deletions(-) diff --git a/robot_2/reality.py b/robot_2/reality.py index 014f032..69b2fc5 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -36,98 +36,86 @@ def rotate(self, angle: float) -> "Pose": """Rotate by angle (in radians)""" return Pose(self.p, self.hd + angle) -@final # Indicates class won't be subclassed -class Reality: - """ - Simulates the true state of the world, which the robot can only access through - noisy sensors and 
imperfect motion. +def execute_control(walls: jnp.ndarray, motion_noise: float, sensor_noise: float, + current_pose: Pose, control: Tuple[float, float], + key: PRNGKey) -> Tuple[Pose, jnp.ndarray, PRNGKey]: + """Execute a control command with noise, stopping if we hit a wall - The robot: - - Can try to move (but motion will have noise) - - Can take sensor readings (but they will have noise) - - Cannot access _true_pose directly + Args: + walls: JAX array of shape (N, 2, 2) containing wall segments + motion_noise: Standard deviation of noise added to motion + sensor_noise: Standard deviation of noise added to sensor readings + current_pose: The pose to start from + control: (distance, angle) tuple of motion command + key: JAX random key for noise generation + + Returns: + (new_pose, sensor_readings, new_key) tuple """ + dist, angle = control + k1, k2, k3 = jax.random.split(key, 3) + + # Add noise to motion + noisy_dist = dist + jax.random.normal(k1) * motion_noise + noisy_angle = angle + jax.random.normal(k2) * motion_noise + + # First rotate (can always rotate) + new_pose = current_pose.rotate(noisy_angle) + + # Then try to move forward, checking for collisions + min_dist = compute_distance_to_wall(walls, new_pose, 0.0) # 0 angle = forward + + # Only move as far as we can before hitting a wall + safe_dist = jnp.minimum(noisy_dist, min_dist - 0.1) + safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards + + new_pose = new_pose.step_along(safe_dist) + + # Get sensor readings from new position + readings, k4 = get_sensor_readings(walls, sensor_noise, new_pose, k3) + + return new_pose, readings, k4 - def __init__(self, walls: jnp.ndarray, motion_noise: float, sensor_noise: float, - initial_pose: Optional[Pose] = None): - """ - Args: - walls: JAX array of shape (N, 2, 2) containing wall segments - motion_noise: Standard deviation of noise added to motion - sensor_noise: Standard deviation of noise added to sensor readings - initial_pose: Optional starting pose, defaults to (0.5, 0.5, 0.0) - """ - self.walls = walls - self.motion_noise = motion_noise - self.sensor_noise = sensor_noise - self._true_pose = initial_pose if initial_pose is not None else Pose(jnp.array([0.5, 0.5]), 0.0) - self._key = PRNGKey(0) - - def execute_control(self, control: Tuple[float, float]) -> List[float]: - """Execute a control command with noise, stopping if we hit a wall""" - dist, angle = control - # Add noise to motion - noisy_dist = dist + jax.random.normal(self._key) * self.motion_noise - noisy_angle = angle + jax.random.normal(self._key) * self.motion_noise - - # First rotate (can always rotate) - self._true_pose = self._true_pose.rotate(noisy_angle) - - # Then try to move forward, checking for collisions - ray_dir = jnp.array([jnp.cos(self._true_pose.hd), jnp.sin(self._true_pose.hd)]) - - # Use our existing ray-casting to check distance to nearest wall - min_dist = self._compute_distance_to_wall(0.0) # 0 angle = forward - - # Only move as far as we can before hitting a wall (minus small safety margin) - safe_dist = jnp.minimum(noisy_dist, min_dist - 0.1) - safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards - - self._true_pose = self._true_pose.step_along(safe_dist) - - # Return sensor readings from new position - return self.get_sensor_readings() - - def get_sensor_readings(self) -> jnp.ndarray: - """Return noisy distance readings to walls""" - angles = jnp.linspace(0, 2*jnp.pi, 8, endpoint=False) # 8 evenly spaced sensors - keys = jax.random.split(self._key, 8) - self._key = keys[0] - - def 
get_reading(key, angle): - true_dist = self._compute_distance_to_wall(angle) - return true_dist + jax.random.normal(key) * self.sensor_noise - - readings = jax.vmap(get_reading)(keys[1:], angles[:-1]) - return readings - - def _compute_distance_to_wall(self, sensor_angle: float) -> jax.Array: - """Compute true distance to nearest wall along sensor ray using fast 2D ray-segment intersection""" - ray_start = self._true_pose.p - ray_angle = self._true_pose.hd + sensor_angle - ray_dir = jnp.array([jnp.cos(ray_angle), jnp.sin(ray_angle)]) - - # Vectorized computation for all walls at once - p1 = self.walls[:, 0] # Shape: (N, 2) - p2 = self.walls[:, 1] # Shape: (N, 2) - - # Wall direction vectors - wall_vec = p2 - p1 # Shape: (N, 2) - - # Vector from wall start to ray start - to_start = ray_start - p1 # Shape: (N, 2) - - # Compute determinant (cross product in 2D) - # This tells us if ray and wall are parallel and their relative orientation - det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) - - # Compute intersection parameters - u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) - t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) - - # Valid intersections: not parallel, in front of ray, within wall segment - is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) - - # Find minimum valid distance - min_dist = jnp.min(jnp.where(is_valid, t, jnp.inf)) - return jnp.where(jnp.isinf(min_dist), 10.0, min_dist) \ No newline at end of file +def get_sensor_readings(walls: jnp.ndarray, sensor_noise: float, + pose: Pose, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: + """Return noisy distance readings to walls from given pose""" + angles = jnp.linspace(0, 2*jnp.pi, 8, endpoint=False) + keys = jax.random.split(key, 8) + + def get_reading(key, angle): + true_dist = compute_distance_to_wall(walls, pose, angle) + return true_dist + jax.random.normal(key) * sensor_noise + + readings = jax.vmap(get_reading)(keys[1:], angles[:-1]) + return readings, keys[0] + +def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, sensor_angle: float) -> float: + """Compute true distance to nearest wall along sensor ray""" + ray_start = pose.p + ray_angle = pose.hd + sensor_angle + ray_dir = jnp.array([jnp.cos(ray_angle), jnp.sin(ray_angle)]) + + # Vectorized computation for all walls at once + p1 = walls[:, 0] # Shape: (N, 2) + p2 = walls[:, 1] # Shape: (N, 2) + + # Wall direction vectors + wall_vec = p2 - p1 # Shape: (N, 2) + + # Vector from wall start to ray start + to_start = ray_start - p1 # Shape: (N, 2) + + # Compute determinant (cross product in 2D) + # This tells us if ray and wall are parallel and their relative orientation + det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) + + # Compute intersection parameters + u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) + t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) + + # Valid intersections: not parallel, in front of ray, within wall segment + is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) + + # Find minimum valid distance + min_dist = jnp.min(jnp.where(is_valid, t, jnp.inf)) + return jnp.where(jnp.isinf(min_dist), 10.0, min_dist) \ No newline at end of file diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 2bc7bf4..0d9f02f 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -64,6 +64,7 @@ from genjax import normal, 
mv_normal_diag from penzai import pz # Import penzai for pytree dataclasses from typing import List, Tuple +from robot_2.reality import Pose @pz.pytree_dataclass class MotionSettings(genjax.PythonicPytree): @@ -150,20 +151,28 @@ def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, robot_path: Lis controls = path_to_controls(robot_path) start_point = robot_path[0] start_pose = Pose(jnp.array(start_point[:2], dtype=jnp.float32), 0.0) - motion_noise = jnp.float32(motion_noise) - settings = MotionSettings(p_noise=motion_noise, hd_noise=motion_noise) - - # Convert walls to JAX array format - wall_segments = convert_walls_to_jax(walls) + walls = convert_walls_to_jax(walls) # Split key for multiple samples keys = jax.random.split(key, n_paths) - # Sample paths using GenJAX + # Sample paths using execute_control def sample_single_path(k): - trace = generate_path.simulate(k, (settings, start_pose, controls, wall_segments)) - poses = trace.get_retval() - return jnp.stack([pose.p for pose in poses]) + pose = start_pose + path = [pose.p] + + for control in controls: + pose, _, k = reality.execute_control( + walls=walls, + motion_noise=motion_noise, + sensor_noise=0.0, # Don't need sensor readings for path sampling + current_pose=pose, + control=control, + key=k + ) + path.append(pose.p) + + return jnp.stack(path) paths = jax.vmap(sample_single_path)(keys) return paths @@ -322,7 +331,7 @@ def initial_walls(): fontSize=30, textAnchor="middle", dy="-0.35em", - rotate=js("$state.robot_pose.heading * 180 / Math.PI")), + rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI")), true_path, sensor_rays ] @@ -370,10 +379,6 @@ def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: segments = jnp.stack([p1[mask][:, :2], p2[mask][:, :2]], axis=1) return segments -def create_reality(walls_list: List[List[float]], motion_noise: float, sensor_noise: float) -> reality.Reality: - """Create Reality instance with proper JAX arrays""" - return reality.Reality(convert_walls_to_jax(walls_list), motion_noise, sensor_noise) - def path_to_controls(path_points: List[List[float]]) -> List[Tuple[float, float]]: """Convert a series of points into (distance, angle) control pairs Returns: List of (forward_dist, rotation_angle) controls @@ -402,49 +407,46 @@ def path_to_controls(path_points: List[List[float]]) -> List[Tuple[float, float] return controls def debug_reality(widget, e): - """Quick visual check of Reality class""" if not widget.state.robot_path: - return # Need a path to get initial pose + return - # Get initial pose from start of path - start_point = widget.state.robot_path[0] - initial_pose = Pose(jnp.array([start_point[0], start_point[1]]), 0.0) - walls = convert_walls_to_jax(widget.state.walls) - world = reality.Reality(walls, - motion_noise=widget.state.motion_noise, - sensor_noise=widget.state.sensor_noise, - initial_pose=initial_pose) + key = jax.random.PRNGKey(0) + current_pose = Pose(jnp.array(widget.state.robot_path[0][:2]), 0.0) - # Convert path to controls and execute them controls = path_to_controls(widget.state.robot_path) readings = [] - true_path = [[float(world._true_pose.p[0]), float(world._true_pose.p[1])]] # Start position + true_path = [[float(current_pose.p[0]), float(current_pose.p[1])]] for control in controls: - reading = world.execute_control(control) + current_pose, reading, key = reality.execute_control( + walls=walls, + motion_noise=widget.state.motion_noise, + sensor_noise=widget.state.sensor_noise, + current_pose=current_pose, + control=control, + 
key=key + ) readings.append(reading) - # Record position after each control - true_path.append([float(world._true_pose.p[0]), float(world._true_pose.p[1])]) + true_path.append([float(current_pose.p[0]), float(current_pose.p[1])]) # Generate possible paths key = jax.random.PRNGKey(0) possible_paths = sample_possible_paths(key, 20, widget.state.robot_path, widget.state.walls, widget.state.motion_noise) - # Update state with final readings, pose, and full path widget.state.update({ "robot_pose": { - "x": float(world._true_pose.p[0]), - "y": float(world._true_pose.p[1]), - "heading": float(world._true_pose.hd) + "x": float(current_pose.p[0]), + "y": float(current_pose.p[1]), + "heading": float(current_pose.hd) }, "possible_paths": possible_paths, "sensor_readings": readings[-1] if readings else [], "true_path": true_path, "show_debug": True }) - + def clear_state(w, _): w.state.update(initial_state) From 50eccba103e7f4234c8e01d786c00679b6d07bae Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sun, 1 Dec 2024 07:11:59 +0100 Subject: [PATCH 63/86] allow no walls --- robot_2/reality.py | 9 +++++++-- robot_2/where_am_i.py | 5 ++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/robot_2/reality.py b/robot_2/reality.py index 69b2fc5..45b01be 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -91,9 +91,14 @@ def get_reading(key, angle): def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, sensor_angle: float) -> float: """Compute true distance to nearest wall along sensor ray""" + if walls.shape[0] == 0: # No walls + return 10.0 # Return max sensor range + ray_start = pose.p - ray_angle = pose.hd + sensor_angle - ray_dir = jnp.array([jnp.cos(ray_angle), jnp.sin(ray_angle)]) + ray_dir = jnp.array([ + jnp.cos(pose.hd + sensor_angle), + jnp.sin(pose.hd + sensor_angle) + ]) # Vectorized computation for all walls at once p1 = walls[:, 0] # Shape: (N, 2) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 0d9f02f..baeb08c 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -260,10 +260,11 @@ def initial_walls(): "show_sensors": True, "selected_tool": "path", "robot_path": [], + "possible_paths": [], "estimated_pose": None, "sensor_readings": [], "show_uncertainty": True, - "show_true_position": True # New flag for toggling true position visibility + "show_true_position": False } sensor_rays = Plot.line( @@ -368,6 +369,8 @@ def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: First 2 = start/end point Second 2 = x,y coordinates """ + if not walls_list: + return jnp.array([]).reshape((0, 2, 2)) # Empty array with shape (0,2,2) # Convert to array and reshape to (N,3) where columns are x,y,timestamp points = jnp.array(walls_list).reshape(-1, 3) # Get consecutive pairs of points From 219374d9d1f5950874ac0354eafabb8abae7cd68 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sun, 1 Dec 2024 08:43:06 +0100 Subject: [PATCH 64/86] layout --- robot_2/where_am_i.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index baeb08c..deb1cd7 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -234,7 +234,7 @@ def drawing_system(on_complete): label="Sensor Noise:", showValue=True ) - & Plot.Slider( + | Plot.Slider( "motion_noise", range=[0, 1], step=0.05, @@ -451,39 +451,39 @@ def debug_reality(widget, e): }) def clear_state(w, _): - w.state.update(initial_state) + w.state.update(initial_state | {"selected_tool": 
w.state.selected_tool}) - +selectable_button = "button.px-3.py-1.rounded.bg-gray-100.hover:bg-gray-300.data-[selected=true]:bg-gray-300" + # Add debug button to toolbar -toolbar = Plot.html("Select tool:") | ["div", {"class": "flex gap-2 h-10"}, - ["button", { - "class": js("$state.selected_tool === 'walls' ? 'px-3 py-1 rounded bg-gray-300 hover:bg-gray-400 active:bg-gray-500 focus:outline-none' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400 focus:outline-none'"), - "onClick": js("() => $state.selected_tool = 'walls'") - }, "Draw Walls"], - ["button", { - "class": js("$state.selected_tool === 'path' ? 'px-3 py-1 rounded bg-gray-300 hover:bg-gray-400 active:bg-gray-500 focus:outline-none' : 'px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400 focus:outline-none'"), +toolbar = Plot.html("Select tool:") | ["div.flex.gap-2", + [selectable_button, { + "data-selected": js("$state.selected_tool === 'path'"), "onClick": js("() => $state.selected_tool = 'path'") - }, "Draw Robot Path"], - ["button", { - "class": "px-3 py-1 rounded bg-gray-200 hover:bg-gray-300 active:bg-gray-400", + }, "🤖 Path"], + [selectable_button, { + "data-selected": js("$state.selected_tool === 'walls'"), + "onClick": js("() => $state.selected_tool = 'walls'") + }, "✏️ Walls"], + [selectable_button, { "onClick": clear_state }, "Clear"] ] -reality_toggle = Plot.html("") | ["label.flex.items-center.gap-2.p-2", - "Show true position:", +reality_toggle = Plot.html("") | ["label.flex.items-center.gap-2.p-2.bg-gray-100.rounded.hover:bg-gray-300", ["input", { "type": "checkbox", "checked": js("$state.show_true_position"), "onChange": js("(e) => $state.show_true_position = e.target.checked") - }]] + }], "Show true position"] # Modify the onChange handlers at the bottom ( canvas & - (toolbar | sliders | reality_toggle | sensor_rays + {"height": 200}) + (sliders | toolbar | reality_toggle | sensor_rays + {"height": 200}) + & {"widths": ["400px", 1]} | Plot.initialState(initial_state, sync=True) | Plot.onChange({ "robot_path": debug_reality, From 765b486ff8999150da2041c6763837221feb875a Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sun, 1 Dec 2024 13:54:56 +0100 Subject: [PATCH 65/86] jax jit --- robot_2/reality.py | 17 ++-- robot_2/where_am_i.py | 182 +++++++++++++++++++++++------------------- 2 files changed, 112 insertions(+), 87 deletions(-) diff --git a/robot_2/reality.py b/robot_2/reality.py index 45b01be..c498fb9 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -36,6 +36,7 @@ def rotate(self, angle: float) -> "Pose": """Rotate by angle (in radians)""" return Pose(self.p, self.hd + angle) +@jax.jit def execute_control(walls: jnp.ndarray, motion_noise: float, sensor_noise: float, current_pose: Pose, control: Tuple[float, float], key: PRNGKey) -> Tuple[Pose, jnp.ndarray, PRNGKey]: @@ -76,19 +77,21 @@ def execute_control(walls: jnp.ndarray, motion_noise: float, sensor_noise: float return new_pose, readings, k4 +@jax.jit def get_sensor_readings(walls: jnp.ndarray, sensor_noise: float, pose: Pose, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: """Return noisy distance readings to walls from given pose""" - angles = jnp.linspace(0, 2*jnp.pi, 8, endpoint=False) - keys = jax.random.split(key, 8) + n_sensors = 8 + key, subkey = jax.random.split(key) + angles = jnp.linspace(0, 2*jnp.pi, n_sensors, endpoint=False) + noise = jax.random.normal(subkey, (n_sensors,)) * sensor_noise - def get_reading(key, angle): - true_dist = compute_distance_to_wall(walls, pose, angle) - return true_dist + 
jax.random.normal(key) * sensor_noise + # Compute all distances at once + readings = jax.vmap(lambda a: compute_distance_to_wall(walls, pose, a))(angles) - readings = jax.vmap(get_reading)(keys[1:], angles[:-1]) - return readings, keys[0] + return readings + noise, key +@jax.jit def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, sensor_angle: float) -> float: """Compute true distance to nearest wall along sensor ray""" if walls.shape[0] == 0: # No walls diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index deb1cd7..53c3d56 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -65,6 +65,7 @@ from penzai import pz # Import penzai for pytree dataclasses from typing import List, Tuple from robot_2.reality import Pose +from functools import partial @pz.pytree_dataclass class MotionSettings(genjax.PythonicPytree): @@ -146,41 +147,52 @@ def generate_path(motion_settings: MotionSettings, start_pose: Pose, controls: L return path -def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, robot_path: List[List[float]], walls: List[List[float]], motion_noise: float): +def sample_single_path(carry, control, walls, motion_noise): + """Single step of path sampling that can be used with scan""" + pose, key = carry + dist, angle = control + pose, _, key = reality.execute_control( + walls=walls, + motion_noise=motion_noise, + sensor_noise=0.0, # Don't need sensor readings for path sampling + current_pose=pose, + control=(dist, angle), + key=key + ) + return (pose, key), pose.p + +@partial(jax.jit, static_argnums=(1,)) +def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, + robot_path: jnp.ndarray, walls: jnp.ndarray, + motion_noise: float): """Generate n possible paths given the planned path, respecting walls""" - controls = path_to_controls(robot_path) - start_point = robot_path[0] - start_pose = Pose(jnp.array(start_point[:2], dtype=jnp.float32), 0.0) - walls = convert_walls_to_jax(walls) + # Extract just x,y coordinates from path + path_points = robot_path[:, :2] # Shape: (N, 2) + controls = path_to_controls(path_points) + + start_point = path_points[0] + start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) # Split key for multiple samples keys = jax.random.split(key, n_paths) - # Sample paths using execute_control - def sample_single_path(k): - pose = start_pose - path = [pose.p] - - for control in controls: - pose, _, k = reality.execute_control( - walls=walls, - motion_noise=motion_noise, - sensor_noise=0.0, # Don't need sensor readings for path sampling - current_pose=pose, - control=control, - key=k - ) - path.append(pose.p) - - return jnp.stack(path) + def sample_path_scan(key): + init_carry = (start_pose, key) + (final_pose, final_key), path_points = jax.lax.scan( + lambda carry, control: sample_single_path(carry, control, walls, motion_noise), + init_carry, + controls + ) + return jnp.concatenate([start_pose.p[None, :], path_points], axis=0) - paths = jax.vmap(sample_single_path)(keys) + paths = jax.vmap(sample_path_scan)(keys) return paths _gensym_counter = 0 WALL_WIDTH=6 PATH_WIDTH=6 +SEGMENT_THRESHOLD=0.5 def gensym(prefix: str = "g") -> str: """Generate a unique symbol with an optional prefix, similar to Clojure's gensym.""" @@ -207,7 +219,7 @@ def drawing_system(on_complete): const dx = e.x - last[0]; const dy = e.y - last[1]; // Only add point if moved more than threshold distance - if (Math.sqrt(dx*dx + dy*dy) >= 1) {{ + if (Math.sqrt(dx*dx + dy*dy) >= {SEGMENT_THRESHOLD}) {{ $state.update(['{key}', 'append', [e.x, e.y, 
e.startTime]]); }} }} @@ -370,83 +382,93 @@ def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: Second 2 = x,y coordinates """ if not walls_list: - return jnp.array([]).reshape((0, 2, 2)) # Empty array with shape (0,2,2) - # Convert to array and reshape to (N,3) where columns are x,y,timestamp - points = jnp.array(walls_list).reshape(-1, 3) + return jnp.array([]).reshape((0, 2, 2)) + + # Convert everything to JAX at once, using float32 for timestamps + points = jnp.array(walls_list, dtype=jnp.float32) # Shape: (N, 3) + # Get consecutive pairs of points - p1 = points[:-1] # All points except last - p2 = points[1:] # All points except first - # Keep only pairs with matching timestamps - mask = p1[:, 2] == p2[:, 2] - # Stack the x,y coordinates into wall segments - segments = jnp.stack([p1[mask][:, :2], p2[mask][:, :2]], axis=1) - return segments + p1 = points[:-1] # Shape: (N-1, 3) + p2 = points[1:] # Shape: (N-1, 3) + + # Create wall segments array + segments = jnp.stack([ + p1[:, :2], # x,y coordinates of start points + p2[:, :2] # x,y coordinates of end points + ], axis=1) # Shape: (N-1, 2, 2) + + # Use timestamps to mask valid segments + valid_mask = p1[:, 2] == p2[:, 2] + + # Return masked segments + return segments * valid_mask[:, None, None] -def path_to_controls(path_points: List[List[float]]) -> List[Tuple[float, float]]: +def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: """Convert a series of points into (distance, angle) control pairs - Returns: List of (forward_dist, rotation_angle) controls + Returns: JAX array of shape (N,2) containing (forward_dist, rotation_angle) controls """ - controls = [] - for i in range(len(path_points) - 1): - p1 = jnp.array(path_points[i][:2]) # current point [x,y] - p2 = jnp.array(path_points[i+1][:2]) # next point [x,y] - - # Calculate distance and angle to next point - delta = p2 - p1 - distance = jnp.linalg.norm(delta) - target_angle = jnp.arctan2(delta[1], delta[0]) - - # If not first point, need to rotate from previous heading - if i > 0: - prev_delta = p1 - jnp.array(path_points[i-1][:2]) - prev_angle = jnp.arctan2(prev_delta[1], prev_delta[0]) - rotation = target_angle - prev_angle - else: - # For first point, rotate from initial heading (0) - rotation = target_angle - - controls.append((float(distance), float(rotation))) + points = jnp.array([p[:2] for p in path_points]) + deltas = points[1:] - points[:-1] + distances = jnp.linalg.norm(deltas, axis=1) + angles = jnp.arctan2(deltas[:, 1], deltas[:, 0]) + # Calculate angle changes + angle_changes = jnp.diff(angles, prepend=0.0) + return jnp.stack([distances, angle_changes], axis=1) + +@jax.jit +def simulate_robot_path(start_pose: Pose, controls: jnp.ndarray, walls: jnp.ndarray, + motion_noise: float, sensor_noise: float, key: jax.random.PRNGKey): + """Jitted pure function for simulating robot path""" + def step_fn(carry, control): + pose, k = carry + new_pose, readings, new_key = reality.execute_control( + walls=walls, + motion_noise=motion_noise, + sensor_noise=sensor_noise, + current_pose=pose, + control=control, + key=k + ) + return (new_pose, new_key), (new_pose, readings) - return controls + return jax.lax.scan(step_fn, (start_pose, key), controls) def debug_reality(widget, e): if not widget.state.robot_path: return + # Handle data conversion at the boundary + path = jnp.array(widget.state.robot_path, dtype=jnp.float32) walls = convert_walls_to_jax(widget.state.walls) + + start_pose = Pose(path[0, :2], 0.0) + controls = path_to_controls(path) key = 
jax.random.PRNGKey(0) - current_pose = Pose(jnp.array(widget.state.robot_path[0][:2]), 0.0) - controls = path_to_controls(widget.state.robot_path) - readings = [] - true_path = [[float(current_pose.p[0]), float(current_pose.p[1])]] + # Use jitted function for core computation + (final_pose, _), (poses, readings) = simulate_robot_path( + start_pose, controls, walls, + widget.state.motion_noise, widget.state.sensor_noise, key + ) - for control in controls: - current_pose, reading, key = reality.execute_control( - walls=walls, - motion_noise=widget.state.motion_noise, - sensor_noise=widget.state.sensor_noise, - current_pose=current_pose, - control=control, - key=key - ) - readings.append(reading) - true_path.append([float(current_pose.p[0]), float(current_pose.p[1])]) + # Convert poses to path + true_path = jnp.concatenate([start_pose.p[None, :], jax.vmap(lambda p: p.p)(poses)]) # Generate possible paths - key = jax.random.PRNGKey(0) - possible_paths = sample_possible_paths(key, 20, widget.state.robot_path, widget.state.walls, widget.state.motion_noise) + possible_paths = sample_possible_paths( + key, 20, path, walls, widget.state.motion_noise + ) - # Update state with final readings, pose, and full path + # Update widget state widget.state.update({ "robot_pose": { - "x": float(current_pose.p[0]), - "y": float(current_pose.p[1]), - "heading": float(current_pose.hd) + "x": float(final_pose.p[0]), + "y": float(final_pose.p[1]), + "heading": float(final_pose.hd) }, "possible_paths": possible_paths, - "sensor_readings": readings[-1] if readings else [], - "true_path": true_path, + "sensor_readings": readings[-1] if len(readings) > 0 else [], + "true_path": [[float(x), float(y)] for x, y in true_path], "show_debug": True }) From 7cd2f4b967cde6a147b3415487cdbb59ac1efe13 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sun, 1 Dec 2024 14:16:50 +0100 Subject: [PATCH 66/86] dynamic sensor count --- robot_2/reality.py | 45 +++++------- robot_2/where_am_i.py | 162 ++++++++++++++++-------------------------- 2 files changed, 79 insertions(+), 128 deletions(-) diff --git a/robot_2/reality.py b/robot_2/reality.py index c498fb9..cc4de27 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -20,6 +20,7 @@ from typing import List, Tuple, final, Optional import jax.random from jax.random import PRNGKey +from functools import partial @dataclass class Pose: @@ -36,35 +37,23 @@ def rotate(self, angle: float) -> "Pose": """Rotate by angle (in radians)""" return Pose(self.p, self.hd + angle) -@jax.jit -def execute_control(walls: jnp.ndarray, motion_noise: float, sensor_noise: float, +@partial(jax.jit, static_argnums=(1,)) +def execute_control(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", current_pose: Pose, control: Tuple[float, float], key: PRNGKey) -> Tuple[Pose, jnp.ndarray, PRNGKey]: - """Execute a control command with noise, stopping if we hit a wall - - Args: - walls: JAX array of shape (N, 2, 2) containing wall segments - motion_noise: Standard deviation of noise added to motion - sensor_noise: Standard deviation of noise added to sensor readings - current_pose: The pose to start from - control: (distance, angle) tuple of motion command - key: JAX random key for noise generation - - Returns: - (new_pose, sensor_readings, new_key) tuple - """ + """Execute a control command with noise, stopping if we hit a wall""" dist, angle = control k1, k2, k3 = jax.random.split(key, 3) # Add noise to motion - noisy_dist = dist + jax.random.normal(k1) * motion_noise - noisy_angle = angle + 
jax.random.normal(k2) * motion_noise + noisy_dist = dist + jax.random.normal(k1) * settings.p_noise + noisy_angle = angle + jax.random.normal(k2) * settings.hd_noise # First rotate (can always rotate) new_pose = current_pose.rotate(noisy_angle) # Then try to move forward, checking for collisions - min_dist = compute_distance_to_wall(walls, new_pose, 0.0) # 0 angle = forward + min_dist = compute_distance_to_wall(walls, new_pose, 0.0, settings.sensor_range) # Only move as far as we can before hitting a wall safe_dist = jnp.minimum(noisy_dist, min_dist - 0.1) @@ -73,29 +62,30 @@ def execute_control(walls: jnp.ndarray, motion_noise: float, sensor_noise: float new_pose = new_pose.step_along(safe_dist) # Get sensor readings from new position - readings, k4 = get_sensor_readings(walls, sensor_noise, new_pose, k3) + readings, k4 = get_sensor_readings(walls, n_sensors, settings, new_pose, k3) return new_pose, readings, k4 -@jax.jit -def get_sensor_readings(walls: jnp.ndarray, sensor_noise: float, +@partial(jax.jit, static_argnums=(1,)) +def get_sensor_readings(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", pose: Pose, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: """Return noisy distance readings to walls from given pose""" - n_sensors = 8 key, subkey = jax.random.split(key) angles = jnp.linspace(0, 2*jnp.pi, n_sensors, endpoint=False) - noise = jax.random.normal(subkey, (n_sensors,)) * sensor_noise + noise = jax.random.normal(subkey, (n_sensors,)) * settings.sensor_noise # Compute all distances at once - readings = jax.vmap(lambda a: compute_distance_to_wall(walls, pose, a))(angles) + readings = jax.vmap(lambda a: compute_distance_to_wall( + walls, pose, a, settings.sensor_range))(angles) return readings + noise, key @jax.jit -def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, sensor_angle: float) -> float: +def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, + sensor_angle: float, sensor_range: float) -> float: """Compute true distance to nearest wall along sensor ray""" if walls.shape[0] == 0: # No walls - return 10.0 # Return max sensor range + return sensor_range ray_start = pose.p ray_dir = jnp.array([ @@ -114,7 +104,6 @@ def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, sensor_angle: float to_start = ray_start - p1 # Shape: (N, 2) # Compute determinant (cross product in 2D) - # This tells us if ray and wall are parallel and their relative orientation det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) # Compute intersection parameters @@ -126,4 +115,4 @@ def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, sensor_angle: float # Find minimum valid distance min_dist = jnp.min(jnp.where(is_valid, t, jnp.inf)) - return jnp.where(jnp.isinf(min_dist), 10.0, min_dist) \ No newline at end of file + return jnp.where(jnp.isinf(min_dist), sensor_range, min_dist) \ No newline at end of file diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 53c3d56..9678e0c 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -67,11 +67,16 @@ from robot_2.reality import Pose from functools import partial +WALL_WIDTH=6 +PATH_WIDTH=6 +SEGMENT_THRESHOLD=0.25 @pz.pytree_dataclass -class MotionSettings(genjax.PythonicPytree): - """Settings for motion uncertainty""" - p_noise: float = 0.1 # Position noise - hd_noise: float = 0.1 # Heading noise +class RobotSettings(genjax.PythonicPytree): + """Robot configuration and uncertainty settings""" + p_noise: float = 0.1 # Position noise + hd_noise: float = 0.1 # Heading noise + 
sensor_noise: float = 0.1 # Sensor noise + sensor_range: float = 10.0 # Maximum sensor range @pz.pytree_dataclass class Pose(genjax.PythonicPytree): @@ -91,80 +96,23 @@ def rotate(self, angle: float) -> "Pose": """Rotate by angle (in radians)""" return Pose(self.p, self.hd + angle) -@genjax.gen -def step_proposal(motion_settings: MotionSettings, start_pose: Pose, control: Tuple[float, float], walls: jnp.ndarray) -> Pose: - """Generate a noisy next pose given current pose and control, respecting walls""" - dist, angle = control - - # Sample noisy heading first - hd = normal( - start_pose.hd + angle, - motion_settings.hd_noise - ) @ "hd" - - # Calculate noisy forward distance - noisy_dist = normal( - dist, - motion_settings.p_noise - ) @ "dist" - - # Simplified collision check using just the wall segments - ray_start = start_pose.p - ray_dir = jnp.array([jnp.cos(hd), jnp.sin(hd)]) - - # Vectorized ray-wall intersection (copied from Reality class) - p1 = walls[:, 0] # Shape: (N, 2) - p2 = walls[:, 1] # Shape: (N, 2) - wall_vec = p2 - p1 - to_start = ray_start - p1 - - det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) - u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) - t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) - - is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) - min_dist = jnp.min(jnp.where(is_valid, t, jnp.inf)) - min_dist = jnp.where(jnp.isinf(min_dist), 10.0, min_dist) - - # Limit distance to avoid wall collision - safe_dist = jnp.minimum(noisy_dist, min_dist - 0.1) - safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards - - # Calculate final position - p = start_pose.p + safe_dist * ray_dir - - return Pose(p, hd) - -@genjax.gen -def generate_path(motion_settings: MotionSettings, start_pose: Pose, controls: List[Tuple[float, float]], walls: jnp.ndarray): - """Generate a complete path by sampling noisy steps, respecting walls""" - pose = start_pose - path = [pose] - - for control in controls: - pose = step_proposal(motion_settings, pose, control, walls) @ f"step_{len(path)}" - path.append(pose) - - return path - -def sample_single_path(carry, control, walls, motion_noise): +def sample_single_path(carry, control, walls, n_sensors, settings): """Single step of path sampling that can be used with scan""" pose, key = carry - dist, angle = control pose, _, key = reality.execute_control( walls=walls, - motion_noise=motion_noise, - sensor_noise=0.0, # Don't need sensor readings for path sampling + n_sensors=n_sensors, + settings=settings, current_pose=pose, - control=(dist, angle), + control=control, key=key ) return (pose, key), pose.p -@partial(jax.jit, static_argnums=(1,)) -def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, +@partial(jax.jit, static_argnums=(1, 2)) # n_paths and n_sensors are static +def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, n_sensors: int, robot_path: jnp.ndarray, walls: jnp.ndarray, - motion_noise: float): + settings: RobotSettings): """Generate n possible paths given the planned path, respecting walls""" # Extract just x,y coordinates from path path_points = robot_path[:, :2] # Shape: (N, 2) @@ -179,7 +127,7 @@ def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, def sample_path_scan(key): init_carry = (start_pose, key) (final_pose, final_key), path_points = jax.lax.scan( - lambda carry, control: sample_single_path(carry, control, walls, motion_noise), + lambda carry, control: 
sample_single_path(carry, control, walls, n_sensors, settings), init_carry, controls ) @@ -190,9 +138,6 @@ def sample_path_scan(key): _gensym_counter = 0 -WALL_WIDTH=6 -PATH_WIDTH=6 -SEGMENT_THRESHOLD=0.5 def gensym(prefix: str = "g") -> str: """Generate a unique symbol with an optional prefix, similar to Clojure's gensym.""" @@ -242,17 +187,24 @@ def drawing_system(on_complete): Plot.Slider( "sensor_noise", range=[0, 1], - step=0.05, + step=0.02, label="Sensor Noise:", showValue=True ) | Plot.Slider( "motion_noise", range=[0, 1], - step=0.05, + step=0.02, label="Motion Noise:", showValue=True ) + | Plot.Slider( + "n_sensors", + range=[4, 32], + step=1, + label="Number of Sensors:", + showValue=True + ) ) def initial_walls(): @@ -269,6 +221,7 @@ def initial_walls(): "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, "sensor_noise": 0.1, "motion_noise": 0.1, + "n_sensors": 8, "show_sensors": True, "selected_tool": "path", "robot_path": [], @@ -280,24 +233,25 @@ def initial_walls(): } sensor_rays = Plot.line( - js(""" - Array.from($state.sensor_readings).flatMap((r, i) => { - const heading = $state.robot_pose.heading || 0; - const angle = heading + (i * Math.PI * 2) / 8; - const x = $state.robot_pose.x; - const y = $state.robot_pose.y; - return [ - [x, y, i], - [x + r * Math.cos(angle), - y + r * Math.sin(angle), i] - ] - }) - """), - z="2", - stroke="red", - strokeWidth=1, - marker="circle" - ) + js(""" + Array.from($state.sensor_readings).map((r, i) => { + const heading = $state.robot_pose.heading || 0; + const n_sensors = $state.n_sensors; + const angle = heading + (i * Math.PI * 2) / n_sensors; + const x = $state.robot_pose.x; + const y = $state.robot_pose.y; + return [ + [x, y, i], + [x + r * Math.cos(angle), + y + r * Math.sin(angle), i] + ] + }).flat() + """), + z="2", + stroke="red", + strokeWidth=1, + marker="circle" +) true_path = Plot.line( js("$state.true_path"), @@ -415,16 +369,16 @@ def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: angle_changes = jnp.diff(angles, prepend=0.0) return jnp.stack([distances, angle_changes], axis=1) -@jax.jit -def simulate_robot_path(start_pose: Pose, controls: jnp.ndarray, walls: jnp.ndarray, - motion_noise: float, sensor_noise: float, key: jax.random.PRNGKey): +@partial(jax.jit, static_argnums=(1,)) +def simulate_robot_path(start_pose: Pose, n_sensors: int, controls: jnp.ndarray, + walls: jnp.ndarray, settings: RobotSettings, key: jax.random.PRNGKey): """Jitted pure function for simulating robot path""" def step_fn(carry, control): pose, k = carry new_pose, readings, new_key = reality.execute_control( walls=walls, - motion_noise=motion_noise, - sensor_noise=sensor_noise, + n_sensors=n_sensors, + settings=settings, current_pose=pose, control=control, key=k @@ -437,9 +391,17 @@ def debug_reality(widget, e): if not widget.state.robot_path: return + # Create settings object + settings = RobotSettings( + p_noise=widget.state.motion_noise, + hd_noise=widget.state.motion_noise, + sensor_noise=widget.state.sensor_noise, + ) + # Handle data conversion at the boundary path = jnp.array(widget.state.robot_path, dtype=jnp.float32) walls = convert_walls_to_jax(widget.state.walls) + n_sensors = int(widget.state.n_sensors) # Convert to int explicitly start_pose = Pose(path[0, :2], 0.0) controls = path_to_controls(path) @@ -447,8 +409,7 @@ def debug_reality(widget, e): # Use jitted function for core computation (final_pose, _), (poses, readings) = simulate_robot_path( - start_pose, controls, walls, - widget.state.motion_noise, 
widget.state.sensor_noise, key + start_pose, n_sensors, controls, walls, settings, key ) # Convert poses to path @@ -456,7 +417,7 @@ def debug_reality(widget, e): # Generate possible paths possible_paths = sample_possible_paths( - key, 20, path, walls, widget.state.motion_noise + key, 20, n_sensors, path, walls, settings # Pass n_sensors separately ) # Update widget state @@ -511,5 +472,6 @@ def clear_state(w, _): "robot_path": debug_reality, "sensor_noise": debug_reality, "motion_noise": debug_reality, + "n_sensors": debug_reality, "walls": debug_reality })) From efd57a98e81975023d081ad73fde805dbb58c8b4 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sun, 1 Dec 2024 14:27:59 +0100 Subject: [PATCH 67/86] use masking for sensor readings --- robot_2/reality.py | 21 +++++++++++++++------ robot_2/where_am_i.py | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/robot_2/reality.py b/robot_2/reality.py index cc4de27..c347333 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -37,7 +37,7 @@ def rotate(self, angle: float) -> "Pose": """Rotate by angle (in radians)""" return Pose(self.p, self.hd + angle) -@partial(jax.jit, static_argnums=(1,)) +@jax.jit def execute_control(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", current_pose: Pose, control: Tuple[float, float], key: PRNGKey) -> Tuple[Pose, jnp.ndarray, PRNGKey]: @@ -66,19 +66,28 @@ def execute_control(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings return new_pose, readings, k4 -@partial(jax.jit, static_argnums=(1,)) +@jax.jit def get_sensor_readings(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", pose: Pose, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: """Return noisy distance readings to walls from given pose""" + MAX_SENSORS = 32 # Fixed maximum key, subkey = jax.random.split(key) - angles = jnp.linspace(0, 2*jnp.pi, n_sensors, endpoint=False) - noise = jax.random.normal(subkey, (n_sensors,)) * settings.sensor_noise - # Compute all distances at once + # Calculate angles based on n_sensors, but generate MAX_SENSORS of them + angle_step = 2 * jnp.pi / n_sensors + angles = jnp.arange(MAX_SENSORS) * angle_step + noise = jax.random.normal(subkey, (MAX_SENSORS,)) * settings.sensor_noise + readings = jax.vmap(lambda a: compute_distance_to_wall( walls, pose, a, settings.sensor_range))(angles) - return readings + noise, key + # Create a mask for the first n_sensors elements + mask = jnp.arange(MAX_SENSORS) < n_sensors + + # Apply mask and pad with zeros + readings = (readings + noise) * mask + + return readings, key @jax.jit def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 9678e0c..bf365f3 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -109,7 +109,7 @@ def sample_single_path(carry, control, walls, n_sensors, settings): ) return (pose, key), pose.p -@partial(jax.jit, static_argnums=(1, 2)) # n_paths and n_sensors are static +@partial(jax.jit, static_argnums=(1)) # n_paths is static def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, n_sensors: int, robot_path: jnp.ndarray, walls: jnp.ndarray, settings: RobotSettings): @@ -369,7 +369,7 @@ def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: angle_changes = jnp.diff(angles, prepend=0.0) return jnp.stack([distances, angle_changes], axis=1) -@partial(jax.jit, static_argnums=(1,)) +@jax.jit def simulate_robot_path(start_pose: Pose, n_sensors: int, controls: jnp.ndarray, walls: jnp.ndarray, settings: 
RobotSettings, key: jax.random.PRNGKey): """Jitted pure function for simulating robot path""" From 860955bc132729c9115c867cc892b0a5ac1dff94 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Sun, 1 Dec 2024 22:35:02 +0100 Subject: [PATCH 68/86] fixed wall threshold issue todo: wall segments are invisibly joined --- robot_2/emoji.py | 2 + robot_2/reality.py | 25 +-- robot_2/robot.py | 101 ++++++++++ robot_2/visualization.py | 236 ++++++++++++++++++++++ robot_2/where_am_i.py | 420 +++++---------------------------------- 5 files changed, 398 insertions(+), 386 deletions(-) create mode 100644 robot_2/emoji.py create mode 100644 robot_2/robot.py create mode 100644 robot_2/visualization.py diff --git a/robot_2/emoji.py b/robot_2/emoji.py new file mode 100644 index 0000000..e388ddb --- /dev/null +++ b/robot_2/emoji.py @@ -0,0 +1,2 @@ +robot = "🤖" +pencil = "✏️" \ No newline at end of file diff --git a/robot_2/reality.py b/robot_2/reality.py index c347333..8f6fa1f 100644 --- a/robot_2/reality.py +++ b/robot_2/reality.py @@ -22,25 +22,12 @@ from jax.random import PRNGKey from functools import partial -@dataclass -class Pose: - """Represents a robot's position (x,y) and heading (radians)""" - p: jnp.ndarray # position [x, y] - hd: float # heading in radians - - def step_along(self, s: float) -> "Pose": - """Move forward by distance s""" - dp = jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) - return Pose(self.p + s * dp, self.hd) - - def rotate(self, angle: float) -> "Pose": - """Rotate by angle (in radians)""" - return Pose(self.p, self.hd + angle) +WALL_COLLISION_THRESHOLD = 0.01 @jax.jit def execute_control(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", - current_pose: Pose, control: Tuple[float, float], - key: PRNGKey) -> Tuple[Pose, jnp.ndarray, PRNGKey]: + current_pose: "Pose", control: Tuple[float, float], + key: PRNGKey) -> Tuple["Pose", jnp.ndarray, PRNGKey]: """Execute a control command with noise, stopping if we hit a wall""" dist, angle = control k1, k2, k3 = jax.random.split(key, 3) @@ -56,7 +43,7 @@ def execute_control(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings min_dist = compute_distance_to_wall(walls, new_pose, 0.0, settings.sensor_range) # Only move as far as we can before hitting a wall - safe_dist = jnp.minimum(noisy_dist, min_dist - 0.1) + safe_dist = jnp.minimum(noisy_dist, min_dist - WALL_COLLISION_THRESHOLD) safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards new_pose = new_pose.step_along(safe_dist) @@ -68,7 +55,7 @@ def execute_control(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings @jax.jit def get_sensor_readings(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", - pose: Pose, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: + pose: "Pose", key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: """Return noisy distance readings to walls from given pose""" MAX_SENSORS = 32 # Fixed maximum key, subkey = jax.random.split(key) @@ -90,7 +77,7 @@ def get_sensor_readings(walls: jnp.ndarray, n_sensors: int, settings: "RobotSett return readings, key @jax.jit -def compute_distance_to_wall(walls: jnp.ndarray, pose: Pose, +def compute_distance_to_wall(walls: jnp.ndarray, pose: "Pose", sensor_angle: float, sensor_range: float) -> float: """Compute true distance to nearest wall along sensor ray""" if walls.shape[0] == 0: # No walls diff --git a/robot_2/robot.py b/robot_2/robot.py new file mode 100644 index 0000000..e8e28d1 --- /dev/null +++ b/robot_2/robot.py @@ -0,0 +1,101 @@ +import jax +import jax.numpy as jnp +from functools 
import partial +from penzai import pz +import genjax +from typing import List, Tuple +from robot_2.reality import execute_control +from jax.random import PRNGKey +from dataclasses import dataclass + + +@pz.pytree_dataclass +class Pose(genjax.PythonicPytree): + """Robot pose with position and heading""" + p: jax.Array # [x, y] + hd: float # heading in radians + + def dp(self): + """Get direction vector from heading""" + return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) + + def step_along(self, s: float) -> "Pose": + """Move forward by distance s""" + return Pose(self.p + s * self.dp(), self.hd) + + def rotate(self, angle: float) -> "Pose": + """Rotate by angle (in radians)""" + return Pose(self.p, self.hd + angle) + +@pz.pytree_dataclass +class RobotSettings(genjax.PythonicPytree): + """Robot configuration and uncertainty settings""" + p_noise: float = 0.1 # Position noise + hd_noise: float = 0.1 # Heading noise + sensor_noise: float = 0.1 # Sensor noise + sensor_range: float = 10.0 # Maximum sensor range + +def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: + """Convert a series of points into (distance, angle) control pairs""" + points = jnp.array([p[:2] for p in path_points]) + deltas = points[1:] - points[:-1] + distances = jnp.linalg.norm(deltas, axis=1) + angles = jnp.arctan2(deltas[:, 1], deltas[:, 0]) + angle_changes = jnp.diff(angles, prepend=0.0) + return jnp.stack([distances, angle_changes], axis=1) + +def sample_single_path(carry, control, walls, n_sensors, settings): + """Single step of path sampling that can be used with scan""" + pose, key = carry + pose, _, key = execute_control( + walls=walls, + n_sensors=n_sensors, + settings=settings, + current_pose=pose, + control=control, + key=key + ) + return (pose, key), pose.p + +@partial(jax.jit, static_argnums=(1)) +def sample_possible_paths(key: jnp.ndarray, n_paths: int, n_sensors: int, + robot_path: jnp.ndarray, walls: jnp.ndarray, + settings: RobotSettings): + """Generate n possible paths given the planned path, respecting walls""" + path_points = robot_path[:, :2] + controls = path_to_controls(path_points) + + start_point = path_points[0] + start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) + + keys = jax.random.split(key, n_paths) + + def sample_path_scan(key): + init_carry = (start_pose, key) + (final_pose, final_key), path_points = jax.lax.scan( + lambda carry, control: sample_single_path(carry, control, walls, n_sensors, settings), + init_carry, + controls + ) + return jnp.concatenate([start_pose.p[None, :], path_points], axis=0) + + paths = jax.vmap(sample_path_scan)(keys) + return paths + +@jax.jit +def simulate_robot_path(start_pose: Pose, n_sensors: int, controls: jnp.ndarray, + walls: jnp.ndarray, settings: RobotSettings, key: jnp.ndarray): + """Simulate robot path with noise and sensor readings""" + def step_fn(carry, control): + pose, k = carry + new_pose, readings, new_key = execute_control( + walls=walls, + n_sensors=n_sensors, + settings=settings, + current_pose=pose, + control=control, + key=k + ) + return (new_pose, new_key), (new_pose, readings) + + return jax.lax.scan(step_fn, (start_pose, key), controls) diff --git a/robot_2/visualization.py b/robot_2/visualization.py new file mode 100644 index 0000000..8a83258 --- /dev/null +++ b/robot_2/visualization.py @@ -0,0 +1,236 @@ +from genstudio.plot import js +import genstudio.plot as Plot +from robot_2.emoji import robot, pencil + +WALL_WIDTH = 6 +PATH_WIDTH = 6 + +def drawing_system(key, on_complete): + """Create drawing 
system for walls and paths""" + line = Plot.line( + js(f"$state.{key}"), + stroke="#ccc", + strokeWidth=4, + strokeDasharray="4" + ) + + events = Plot.events({ + "_initialState": Plot.initialState({key: []}), + "onDrawStart": js(f"""(e) => {{ + $state.{key} = [[e.x, e.y, e.startTime]]; + }}"""), + "onDraw": js(f"""(e) => {{ + if ($state.{key}.length > 0) {{ + const last = $state.{key}[$state.{key}.length - 1]; + const dx = e.x - last[0]; + const dy = e.y - last[1]; + $state.update(['{key}', 'append', [e.x, e.y, e.startTime]]); + }} + }}"""), + "onDrawEnd": js(f"""(e) => {{ + if ($state.{key}.length > 1) {{ + const points = [...$state.{key}, [e.x, e.y, e.startTime]]; + console.log("onDrawEnd", %1) + const ret = {{ + points: points, + simplify: (threshold=0) => {{ + const result = [points[0]]; + let lastKept = points[0]; + + for (let i = 1; i < points.length - 1; i++) {{ + const p = points[i]; + const dx = p[0] - lastKept[0]; + const dy = p[1] - lastKept[1]; + const dist = Math.sqrt(dx*dx + dy*dy); + + if (dist >= threshold) {{ + result.push(p); + lastKept = p; + }} + }} + + result.push(points[points.length - 1]); + return result; + }} + }} + %1(ret) + }} + $state.{key} = []; + }}""", on_complete) + }) + return line + events +def create_sliders(): + """Create control sliders for robot parameters""" + return ( + Plot.Slider( + "sensor_noise", + range=[0, 1], + step=0.02, + label="Sensor Noise:", + showValue=True + ) + | Plot.Slider( + "motion_noise", + range=[0, 1], + step=0.02, + label="Motion Noise:", + showValue=True + ) + | Plot.Slider( + "n_sensors", + range=[4, 32], + step=1, + label="Number of Sensors:", + showValue=True + ) + ) + +def clear_state(w, _): + """Reset visualization state""" + w.state.update(v.create_initial_state() | {"selected_tool": w.state.selected_tool}) + + +def create_toolbar(): + """Create toolbar with tool selection buttons""" + selectable_button = "button.px-3.py-1.rounded.bg-gray-100.hover:bg-gray-300.data-[selected=true]:bg-gray-300" + + return Plot.html("Select tool:") | ["div.flex.gap-2", + [selectable_button, { + "data-selected": js("$state.selected_tool === 'path'"), + "onClick": js("() => $state.selected_tool = 'path'") + }, f"{robot} Path"], + [selectable_button, { + "data-selected": js("$state.selected_tool === 'walls'"), + "onClick": js("() => $state.selected_tool = 'walls'") + }, f"{pencil} Walls"], + [selectable_button, { + "onClick": clear_state + }, "Clear"] + ] + +def create_reality_toggle(): + """Create toggle for showing true position""" + return Plot.html("") | ["label.flex.items-center.gap-2.p-2.bg-gray-100.rounded.hover:bg-gray-300", + ["input", { + "type": "checkbox", + "checked": js("$state.show_true_position"), + "onChange": js("(e) => $state.show_true_position = e.target.checked") + }], "Show true position"] + +def create_sensor_rays(): + """Create visualization for sensor rays""" + return Plot.line( + js(""" + Array.from($state.sensor_readings).map((r, i) => { + const heading = $state.robot_pose.heading || 0; + const n_sensors = $state.n_sensors; + const angle = heading + (i * Math.PI * 2) / n_sensors; + const x = $state.robot_pose.x; + const y = $state.robot_pose.y; + return [ + [x, y, i], + [x + r * Math.cos(angle), + y + r * Math.sin(angle), i] + ] + }).flat() + """), + z="2", + stroke="red", + strokeWidth=1, + marker="circle" + ) + +def create_robot_canvas(drawing_system_handler): + """Create main robot visualization canvas""" + return ( + # Draw completed walls + Plot.line( + js("$state.walls"), + stroke=Plot.constantly("Walls"), + 
strokeWidth=WALL_WIDTH, + z="2", + render=Plot.renderChildEvents({"onClick": js("""(e) => { + const zs = new Set($state.walls.map(w => w[2])); + const targetZ = [...zs][e.index]; + $state.walls = $state.walls.filter(([x, y, z]) => z !== targetZ) + }""")}) + ) + # Draw current line being drawn + + drawing_system("current_line", drawing_system_handler) + + # Draw planned path + + Plot.line( + js("$state.robot_path"), + stroke=Plot.constantly("Robot Path"), + strokeWidth=2, + r=3, + marker="circle" + ) + + # Draw robot and true path when enabled + + Plot.cond( + js("$state.show_true_position"), + [Plot.text( + js("[[$state.robot_pose.x, $state.robot_pose.y]]"), + text=Plot.constantly(robot), + fontSize=30, + textAnchor="middle", + dy="-0.35em", + rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI")), + Plot.line( + js("$state.true_path"), + stroke=Plot.constantly("True Path"), + strokeWidth=2 + ), + create_sensor_rays() + ] + ) + + Plot.domain([0, 10], [0, 10]) + + Plot.grid() + + Plot.aspectRatio(1) + + Plot.colorMap({ + "Walls": "#666", + "Sensor Rays": "red", + "True Path": "green", + "Robot Path": "blue", + }) + + Plot.colorLegend() + + Plot.line( + js(""" + if (!$state.show_debug || !$state.possible_paths) {return [];}; + return $state.possible_paths.flatMap((path, pathIdx) => + path.map(([x, y]) => [x, y, pathIdx]) + ) + """, expression=False), + stroke="blue", + strokeOpacity=0.2, + z="2" + ) + + Plot.clip() + ) + +def create_initial_state(): + """Create initial state for visualization""" + return { + "walls": [ + # Frame around domain (timestamp 0) + [0, 0, 0], [10, 0, 0], # Bottom + [10, 0, 0], [10, 10, 0], # Right + [10, 10, 0], [0, 10, 0], # Top + [0, 10, 0], [0, 0, 0], # Left + ], + "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, + "sensor_noise": 0.1, + "motion_noise": 0.1, + "n_sensors": 8, + "show_sensors": True, + "selected_tool": "path", + "robot_path": [], + "possible_paths": [], + "estimated_pose": None, + "sensor_readings": [], + "show_uncertainty": True, + "show_true_position": False, + "current_line": [] + } + diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index bf365f3..e3c4e35 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -52,375 +52,67 @@ # pyright: reportUnknownMemberType=false import genstudio.plot as Plot -from genstudio.plot import js import numpy as np import jax.numpy as jnp -from typing import TypedDict, List, Tuple, Any - -import robot_2.reality as reality -import jax -import jax.numpy as jnp -import genjax -from genjax import normal, mv_normal_diag -from penzai import pz # Import penzai for pytree dataclasses from typing import List, Tuple -from robot_2.reality import Pose -from functools import partial - -WALL_WIDTH=6 -PATH_WIDTH=6 -SEGMENT_THRESHOLD=0.25 -@pz.pytree_dataclass -class RobotSettings(genjax.PythonicPytree): - """Robot configuration and uncertainty settings""" - p_noise: float = 0.1 # Position noise - hd_noise: float = 0.1 # Heading noise - sensor_noise: float = 0.1 # Sensor noise - sensor_range: float = 10.0 # Maximum sensor range - -@pz.pytree_dataclass -class Pose(genjax.PythonicPytree): - """Robot pose with position and heading""" - p: jax.Array # [x, y] - hd: float # heading in radians - - def dp(self): - """Get direction vector from heading""" - return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) - - def step_along(self, s: float) -> "Pose": - """Move forward by distance s""" - return Pose(self.p + s * self.dp(), self.hd) - - def rotate(self, angle: float) -> "Pose": - """Rotate by 
angle (in radians)""" - return Pose(self.p, self.hd + angle) - -def sample_single_path(carry, control, walls, n_sensors, settings): - """Single step of path sampling that can be used with scan""" - pose, key = carry - pose, _, key = reality.execute_control( - walls=walls, - n_sensors=n_sensors, - settings=settings, - current_pose=pose, - control=control, - key=key - ) - return (pose, key), pose.p - -@partial(jax.jit, static_argnums=(1)) # n_paths is static -def sample_possible_paths(key: jax.random.PRNGKey, n_paths: int, n_sensors: int, - robot_path: jnp.ndarray, walls: jnp.ndarray, - settings: RobotSettings): - """Generate n possible paths given the planned path, respecting walls""" - # Extract just x,y coordinates from path - path_points = robot_path[:, :2] # Shape: (N, 2) - controls = path_to_controls(path_points) - - start_point = path_points[0] - start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) - - # Split key for multiple samples - keys = jax.random.split(key, n_paths) - - def sample_path_scan(key): - init_carry = (start_pose, key) - (final_pose, final_key), path_points = jax.lax.scan( - lambda carry, control: sample_single_path(carry, control, walls, n_sensors, settings), - init_carry, - controls - ) - return jnp.concatenate([start_pose.p[None, :], path_points], axis=0) - - paths = jax.vmap(sample_path_scan)(keys) - return paths - -_gensym_counter = 0 - - -def gensym(prefix: str = "g") -> str: - """Generate a unique symbol with an optional prefix, similar to Clojure's gensym.""" - global _gensym_counter - _gensym_counter += 1 - return f"{prefix}{_gensym_counter}" - -def drawing_system(on_complete): - key = gensym("current_line") - line = Plot.line( - js(f"$state.{key}"), - stroke="#ccc", - strokeWidth=4, - strokeDasharray="4") - - events = Plot.events({ - "_initialState": Plot.initialState({key: []}), - "onDrawStart": js(f"""(e) => {{ - $state.{key} = [[e.x, e.y, e.startTime]]; - }}"""), - "onDraw": js(f"""(e) => {{ - if ($state.{key}.length > 0) {{ - const last = $state.{key}[$state.{key}.length - 1]; - const dx = e.x - last[0]; - const dy = e.y - last[1]; - // Only add point if moved more than threshold distance - if (Math.sqrt(dx*dx + dy*dy) >= {SEGMENT_THRESHOLD}) {{ - $state.update(['{key}', 'append', [e.x, e.y, e.startTime]]); - }} - }} - }}"""), - "onDrawEnd": js(f"""(e) => {{ - if ($state.{key}.length > 1) {{ - const points = [...$state.{key}, [e.x, e.y, e.startTime]] - - // Simplify line by keeping only every 3rd point - // keep this, we may re-enable later - //const simplified = $state.{key}.filter((_, i) => i % 3 === 0); - %1($state.{key}) - }} - $state.{key} = []; - }}""", on_complete) - }) - return line + events - -sliders = ( - Plot.Slider( - "sensor_noise", - range=[0, 1], - step=0.02, - label="Sensor Noise:", - showValue=True - ) - | Plot.Slider( - "motion_noise", - range=[0, 1], - step=0.02, - label="Motion Noise:", - showValue=True - ) - | Plot.Slider( - "n_sensors", - range=[4, 32], - step=1, - label="Number of Sensors:", - showValue=True - ) -) - -def initial_walls(): - return [ - # Frame around domain (timestamp 0) - [0, 0, 0], [10, 0, 0], # Bottom - [10, 0, 0], [10, 10, 0], # Right - [10, 10, 0], [0, 10, 0], # Top - [0, 10, 0], [0, 0, 0], # Left - ] - -initial_state = { - "walls": initial_walls(), - "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, - "sensor_noise": 0.1, - "motion_noise": 0.1, - "n_sensors": 8, - "show_sensors": True, - "selected_tool": "path", - "robot_path": [], - "possible_paths": [], - "estimated_pose": None, - 
"sensor_readings": [], - "show_uncertainty": True, - "show_true_position": False - } - -sensor_rays = Plot.line( - js(""" - Array.from($state.sensor_readings).map((r, i) => { - const heading = $state.robot_pose.heading || 0; - const n_sensors = $state.n_sensors; - const angle = heading + (i * Math.PI * 2) / n_sensors; - const x = $state.robot_pose.x; - const y = $state.robot_pose.y; - return [ - [x, y, i], - [x + r * Math.cos(angle), - y + r * Math.sin(angle), i] - ] - }).flat() - """), - z="2", - stroke="red", - strokeWidth=1, - marker="circle" -) - -true_path = Plot.line( - js("$state.true_path"), - stroke=Plot.constantly("True Path"), - strokeWidth=2 - ) - -planned_path = Plot.line( - js("$state.robot_path"), - stroke=Plot.constantly("Robot Path"), - strokeWidth=2, - r=3, - marker="circle"), +import jax -canvas = ( - # Draw completed walls - Plot.line( - js("$state.walls"), - stroke=Plot.constantly("Walls"), - strokeWidth=WALL_WIDTH, - z="2", - render=Plot.renderChildEvents({"onClick": js("""(e) => { - const zs = new Set($state.walls.map(w => w[2])); - const targetZ = [...zs][e.index]; - $state.walls = $state.walls.filter(([x, y, z]) => z !== targetZ) - }""")}) - ) - # Draw current line being drawn - + drawing_system(Plot.js("""(line) => { - if ($state.selected_tool === 'walls') { - $state.update(['walls', 'concat', line]); - } else if ($state.selected_tool === 'path') { - $state.update(['robot_path', 'reset', line]); - } - }""")) - + planned_path - - # Draw robot - + Plot.cond( - js("$state.show_true_position"), - [Plot.text( - js("[[$state.robot_pose.x, $state.robot_pose.y]]"), - text=Plot.constantly("🤖"), - fontSize=30, - textAnchor="middle", - dy="-0.35em", - rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI")), - true_path, - sensor_rays - ] - ) - + Plot.domain([0, 10], [0, 10]) - + Plot.grid() - + Plot.aspectRatio(1) - + Plot.colorMap({ - "Walls": "#666", - "Sensor Rays": "red", - "True Path": "green", - "Robot Path": "blue", - }) - + Plot.colorLegend() - + Plot.line( - js(""" - if (!$state.show_debug || !$state.possible_paths) {return [];}; - return $state.possible_paths.flatMap((path, pathIdx) => - path.map(([x, y]) => [x, y, pathIdx]) - ) - """, expression=False), - stroke="blue", - strokeOpacity=0.2, - z="2" - ) - + Plot.clip() - ) +import robot_2.visualization as v +import robot_2.robot as robot +key = jax.random.PRNGKey(0) def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: - """Convert wall vertices from UI format to JAX array of wall segments - Returns: array of shape (N, 2, 2) where: - N = number of walls - First 2 = start/end point - Second 2 = x,y coordinates - """ + """Convert wall vertices from UI format to JAX array of wall segments""" if not walls_list: return jnp.array([]).reshape((0, 2, 2)) - # Convert everything to JAX at once, using float32 for timestamps - points = jnp.array(walls_list, dtype=jnp.float32) # Shape: (N, 3) - - # Get consecutive pairs of points - p1 = points[:-1] # Shape: (N-1, 3) - p2 = points[1:] # Shape: (N-1, 3) + points = jnp.array(walls_list, dtype=jnp.float32) + p1 = points[:-1] + p2 = points[1:] - # Create wall segments array segments = jnp.stack([ - p1[:, :2], # x,y coordinates of start points - p2[:, :2] # x,y coordinates of end points - ], axis=1) # Shape: (N-1, 2, 2) + p1[:, :2], + p2[:, :2] + ], axis=1) - # Use timestamps to mask valid segments valid_mask = p1[:, 2] == p2[:, 2] - - # Return masked segments return segments * valid_mask[:, None, None] -def path_to_controls(path_points: 
List[List[float]]) -> jnp.ndarray: - """Convert a series of points into (distance, angle) control pairs - Returns: JAX array of shape (N,2) containing (forward_dist, rotation_angle) controls - """ - points = jnp.array([p[:2] for p in path_points]) - deltas = points[1:] - points[:-1] - distances = jnp.linalg.norm(deltas, axis=1) - angles = jnp.arctan2(deltas[:, 1], deltas[:, 0]) - # Calculate angle changes - angle_changes = jnp.diff(angles, prepend=0.0) - return jnp.stack([distances, angle_changes], axis=1) - -@jax.jit -def simulate_robot_path(start_pose: Pose, n_sensors: int, controls: jnp.ndarray, - walls: jnp.ndarray, settings: RobotSettings, key: jax.random.PRNGKey): - """Jitted pure function for simulating robot path""" - def step_fn(carry, control): - pose, k = carry - new_pose, readings, new_key = reality.execute_control( - walls=walls, - n_sensors=n_sensors, - settings=settings, - current_pose=pose, - control=control, - key=k - ) - return (new_pose, new_key), (new_pose, readings) - - return jax.lax.scan(step_fn, (start_pose, key), controls) - -def debug_reality(widget, e): +def debug_reality(widget, e, refresh=False): + """Handle updates to robot simulation""" if not widget.state.robot_path: return + + + global key + if refresh: + key, subkey = jax.random.split(key) - # Create settings object - settings = RobotSettings( + settings = robot.RobotSettings( p_noise=widget.state.motion_noise, hd_noise=widget.state.motion_noise, sensor_noise=widget.state.sensor_noise, ) - # Handle data conversion at the boundary path = jnp.array(widget.state.robot_path, dtype=jnp.float32) walls = convert_walls_to_jax(widget.state.walls) - n_sensors = int(widget.state.n_sensors) # Convert to int explicitly + n_sensors = int(widget.state.n_sensors) + + start_pose = robot.Pose(path[0, :2], 0.0) + controls = robot.path_to_controls(path) - start_pose = Pose(path[0, :2], 0.0) - controls = path_to_controls(path) - key = jax.random.PRNGKey(0) + key_true, key_possible = jax.random.split(key) - # Use jitted function for core computation - (final_pose, _), (poses, readings) = simulate_robot_path( - start_pose, n_sensors, controls, walls, settings, key + (final_pose, _), (poses, readings) = robot.simulate_robot_path( + start_pose, n_sensors, controls, walls, settings, key_true ) - # Convert poses to path true_path = jnp.concatenate([start_pose.p[None, :], jax.vmap(lambda p: p.p)(poses)]) - - # Generate possible paths - possible_paths = sample_possible_paths( - key, 20, n_sensors, path, walls, settings # Pass n_sensors separately + possible_paths = robot.sample_possible_paths( + key_possible, 20, n_sensors, path, walls, settings ) - # Update widget state widget.state.update({ "robot_pose": { "x": float(final_pose.p[0]), @@ -433,45 +125,39 @@ def debug_reality(widget, e): "show_debug": True }) -def clear_state(w, _): - w.state.update(initial_state | {"selected_tool": w.state.selected_tool}) - -selectable_button = "button.px-3.py-1.rounded.bg-gray-100.hover:bg-gray-300.data-[selected=true]:bg-gray-300" - -# Add debug button to toolbar -toolbar = Plot.html("Select tool:") | ["div.flex.gap-2", - [selectable_button, { - "data-selected": js("$state.selected_tool === 'path'"), - "onClick": js("() => $state.selected_tool = 'path'") - }, "🤖 Path"], - [selectable_button, { - "data-selected": js("$state.selected_tool === 'walls'"), - "onClick": js("() => $state.selected_tool = 'walls'") - }, "✏️ Walls"], - [selectable_button, { - "onClick": clear_state - }, "Clear"] - ] - - -reality_toggle = Plot.html("") | 
["label.flex.items-center.gap-2.p-2.bg-gray-100.rounded.hover:bg-gray-300", - ["input", { - "type": "checkbox", - "checked": js("$state.show_true_position"), - "onChange": js("(e) => $state.show_true_position = e.target.checked") - }], "Show true position"] +# Create the visualization +canvas = v.create_robot_canvas(Plot.js("""({points, simplify}) => { + mode = $state.selected_tool + if (mode === 'walls') { + $state.update(['walls', 'concat', simplify(0.25)]) + } + if (mode === 'path') { + $state.robot_path = simplify(0.25) + } + }""")) +sliders = v.create_sliders() +toolbar = v.create_toolbar() +reality_toggle = v.create_reality_toggle() + + +key_refresh = ( + ["div.rounded-lg.p-2.bg-[repeating-linear-gradient(45deg,#86efac,#86efac_10px,#bbf7d0_10px,#bbf7d0_20px)]", + {"onMouseMove": lambda w, e: debug_reality(w, e, refresh=True)}, + "Key Refresh"] +) -# Modify the onChange handlers at the bottom +# Combine all components ( canvas & - (sliders | toolbar | reality_toggle | sensor_rays + {"height": 200}) + (sliders | toolbar | reality_toggle | key_refresh) & {"widths": ["400px", 1]} - | Plot.initialState(initial_state, sync=True) + | Plot.initialState(v.create_initial_state(), sync=True) | Plot.onChange({ "robot_path": debug_reality, "sensor_noise": debug_reality, "motion_noise": debug_reality, "n_sensors": debug_reality, "walls": debug_reality - })) + }) +) From 0f5c204852b5f47dcd485fc27a61571f861293fb Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 00:24:26 +0100 Subject: [PATCH 69/86] key scrubber --- robot_2/visualization.py | 16 +++++----- robot_2/where_am_i.py | 64 ++++++++++++++++++++++++++++++++-------- 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/robot_2/visualization.py b/robot_2/visualization.py index 8a83258..0a15033 100644 --- a/robot_2/visualization.py +++ b/robot_2/visualization.py @@ -1,6 +1,7 @@ from genstudio.plot import js import genstudio.plot as Plot from robot_2.emoji import robot, pencil +from typing import Dict, List, Union, Any WALL_WIDTH = 6 PATH_WIDTH = 6 @@ -17,20 +18,19 @@ def drawing_system(key, on_complete): events = Plot.events({ "_initialState": Plot.initialState({key: []}), "onDrawStart": js(f"""(e) => {{ - $state.{key} = [[e.x, e.y, e.startTime]]; + $state.{key} = [[e.x, e.y, e.key]]; }}"""), "onDraw": js(f"""(e) => {{ if ($state.{key}.length > 0) {{ const last = $state.{key}[$state.{key}.length - 1]; const dx = e.x - last[0]; const dy = e.y - last[1]; - $state.update(['{key}', 'append', [e.x, e.y, e.startTime]]); + $state.update(['{key}', 'append', [e.x, e.y, e.key]]); }} }}"""), "onDrawEnd": js(f"""(e) => {{ if ($state.{key}.length > 1) {{ - const points = [...$state.{key}, [e.x, e.y, e.startTime]]; - console.log("onDrawEnd", %1) + const points = [...$state.{key}, [e.x, e.y, e.key]]; const ret = {{ points: points, simplify: (threshold=0) => {{ @@ -59,6 +59,7 @@ def drawing_system(key, on_complete): }}""", on_complete) }) return line + events + def create_sliders(): """Create control sliders for robot parameters""" return ( @@ -87,7 +88,7 @@ def create_sliders(): def clear_state(w, _): """Reset visualization state""" - w.state.update(v.create_initial_state() | {"selected_tool": w.state.selected_tool}) + w.state.update(create_initial_state(w.state.current_key) | {"selected_tool": w.state.selected_tool}) def create_toolbar(): @@ -209,7 +210,7 @@ def create_robot_canvas(drawing_system_handler): + Plot.clip() ) -def create_initial_state(): +def create_initial_state(key) -> Dict[str, Any]: """Create initial state for 
visualization""" return { "walls": [ @@ -231,6 +232,7 @@ def create_initial_state(): "sensor_readings": [], "show_uncertainty": True, "show_true_position": False, - "current_line": [] + "current_line": [], + "current_key": key } diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index e3c4e35..c872ec9 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -79,15 +79,13 @@ def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: valid_mask = p1[:, 2] == p2[:, 2] return segments * valid_mask[:, None, None] -def debug_reality(widget, e, refresh=False): +def debug_reality(widget, e, subkey=None): """Handle updates to robot simulation""" if not widget.state.robot_path: return + current_key = subkey if subkey is not None else key - global key - if refresh: - key, subkey = jax.random.split(key) settings = robot.RobotSettings( p_noise=widget.state.motion_noise, @@ -102,7 +100,7 @@ def debug_reality(widget, e, refresh=False): start_pose = robot.Pose(path[0, :2], 0.0) controls = robot.path_to_controls(path) - key_true, key_possible = jax.random.split(key) + key_true, key_possible = jax.random.split(current_key) (final_pose, _), (poses, readings) = robot.simulate_robot_path( start_pose, n_sensors, controls, walls, settings, key_true @@ -122,7 +120,8 @@ def debug_reality(widget, e, refresh=False): "possible_paths": possible_paths, "sensor_readings": readings[-1] if len(readings) > 0 else [], "true_path": [[float(x), float(y)] for x, y in true_path], - "show_debug": True + "show_debug": True, + "current_key": current_key[0] # Send current key to frontend }) @@ -139,20 +138,61 @@ def debug_reality(widget, e, refresh=False): sliders = v.create_sliders() toolbar = v.create_toolbar() reality_toggle = v.create_reality_toggle() - + key_refresh = ( - ["div.rounded-lg.p-2.bg-[repeating-linear-gradient(45deg,#86efac,#86efac_10px,#bbf7d0_10px,#bbf7d0_20px)]", - {"onMouseMove": lambda w, e: debug_reality(w, e, refresh=True)}, - "Key Refresh"] + [Plot.js(""" + ({children}) => { + const [inside, setInside] = React.useState(false) + const [waiting, setWaiting] = React.useState(false) + const [paused, setPaused] = React.useState(false) + + const text = paused + ? 'Click to Start' + : inside + ? 'Click to Pause' + : 'Fresh Keys' + + const onMouseMove = React.useCallback(async (e) => { + if (paused || waiting) return null; + const rect = e.currentTarget.getBoundingClientRect(); + const x = e.clientX - rect.left; + const stripeIndex = Math.floor(x / stripeWidth); + setWaiting(true) + await %1({key: $state.current_key, index: stripeIndex}); + setWaiting(false) + }) + + const stripeWidth = 4; // Width of each stripe in pixels + + return html(["div.rounded-lg.p-2.delay-100", { + "style": { + background: paused + ? 'repeating-linear-gradient(90deg,#aaa,#aaa 4px,#ddd 4px,#ddd 8px)' + : 'repeating-linear-gradient(90deg,#86efac,#86efac 4px,#bbf7d0 4px,#bbf7d0 8px)', + position: 'relative', + opacity: waiting ? 
0.5 : 1, + transition: 'opacity 0.3s ease' + }, + "onMouseEnter": () => !paused && setInside(true), + "onMouseLeave": () => setInside(false), + "onClick": () => setPaused(!paused), + "onMouseMove": onMouseMove + }, text]) + } + """, lambda w, e: debug_reality(w, e, subkey=jax.random.split(jax.random.PRNGKey(e.key), e.index + 1)[e.index]) + + )] ) +# + # Combine all components ( canvas & - (sliders | toolbar | reality_toggle | key_refresh) + (sliders | toolbar | reality_toggle | key_refresh | Plot.js("$state.current_key")) & {"widths": ["400px", 1]} - | Plot.initialState(v.create_initial_state(), sync=True) + | Plot.initialState(v.create_initial_state(key[0]), sync=True) | Plot.onChange({ "robot_path": debug_reality, "sensor_noise": debug_reality, From 1324511b420bae4bfad79776274de850c9fd8d17 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 00:24:29 +0100 Subject: [PATCH 70/86] genstudio update --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8031658..103f425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ packages = [ python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e", source = "gcp" } -# genstudio = {version = "2024.11.022", source = "gcp"} +# genstudio = {version = "2024.12.003", source = "gcp"} genstudio = {path = "../genstudio", develop = true} ipykernel = "^6.29.3" matplotlib = "^3.8.3" From f040d43fc7f9e29d8761d04f57174492f89af372 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 11:40:49 +0100 Subject: [PATCH 71/86] improved key scrubber --- robot_2/emoji.py | 4 +- robot_2/visualization.py | 94 +++++++++++++++++++++++++++++++++++++--- robot_2/where_am_i.py | 82 ++++++++++++----------------------- 3 files changed, 117 insertions(+), 63 deletions(-) diff --git a/robot_2/emoji.py b/robot_2/emoji.py index e388ddb..1af87dc 100644 --- a/robot_2/emoji.py +++ b/robot_2/emoji.py @@ -1,2 +1,4 @@ robot = "🤖" -pencil = "✏️" \ No newline at end of file +pencil = "✏️" +recycle = "♻️" +clipboard = "📋" \ No newline at end of file diff --git a/robot_2/visualization.py b/robot_2/visualization.py index 0a15033..8869c0c 100644 --- a/robot_2/visualization.py +++ b/robot_2/visualization.py @@ -1,6 +1,6 @@ from genstudio.plot import js import genstudio.plot as Plot -from robot_2.emoji import robot, pencil +import robot_2.emoji as emoji from typing import Dict, List, Union, Any WALL_WIDTH = 6 @@ -88,7 +88,7 @@ def create_sliders(): def clear_state(w, _): """Reset visualization state""" - w.state.update(create_initial_state(w.state.current_key) | {"selected_tool": w.state.selected_tool}) + w.state.update(create_initial_state(w.state.current_seed) | {"selected_tool": w.state.selected_tool}) def create_toolbar(): @@ -99,11 +99,11 @@ def create_toolbar(): [selectable_button, { "data-selected": js("$state.selected_tool === 'path'"), "onClick": js("() => $state.selected_tool = 'path'") - }, f"{robot} Path"], + }, f"{emoji.robot} Path"], [selectable_button, { "data-selected": js("$state.selected_tool === 'walls'"), "onClick": js("() => $state.selected_tool = 'walls'") - }, f"{pencil} Walls"], + }, f"{emoji.pencil} Walls"], [selectable_button, { "onClick": clear_state }, "Clear"] @@ -173,7 +173,7 @@ def create_robot_canvas(drawing_system_handler): js("$state.show_true_position"), [Plot.text( js("[[$state.robot_pose.x, $state.robot_pose.y]]"), - text=Plot.constantly(robot), + text=Plot.constantly(emoji.robot), fontSize=30, textAnchor="middle", 
dy="-0.35em", @@ -210,7 +210,7 @@ def create_robot_canvas(drawing_system_handler): + Plot.clip() ) -def create_initial_state(key) -> Dict[str, Any]: +def create_initial_state(seed) -> Dict[str, Any]: """Create initial state for visualization""" return { "walls": [ @@ -233,6 +233,86 @@ def create_initial_state(key) -> Dict[str, Any]: "show_uncertainty": True, "show_true_position": False, "current_line": [], - "current_key": key + "current_seed": seed } + +def key_scrubber(handle_seed_index): + """Create a scrubber UI component for exploring different random seeds. + + The component shows a striped bar that can be clicked to pause/resume and + scrubbed to explore different seeds. A recycle button allows cycling through seeds. + + Args: + handle_seed_index: Callback function that takes a dict with 'key' (current seed) + and 'index' (stripe index or -1 for cycle) and handles seed changes. + + Returns: + A Plot.js component containing the scrubber UI. + """ + return ( + [Plot.js(""" + ({children}) => { + const [inside, setInside] = React.useState(false) + const [waiting, setWaiting] = React.useState(false) + const [paused, setPaused] = React.useState(false) + + const text = paused + ? 'Click to Start' + : inside + ? 'Click to Pause' + : 'Explore Keys' + + const onMouseMove = React.useCallback(async (e) => { + if (paused || waiting) return null; + const rect = e.currentTarget.getBoundingClientRect(); + const x = e.clientX - rect.left; + const stripeIndex = Math.max(0, Math.floor(x / stripeWidth)); + setWaiting(true) + await %1({key: $state.current_seed, index: stripeIndex}); + setWaiting(false) + }) + const stripeWidth = 4; // Width of each stripe in pixels + + return html(["div.flex.flex-col.gap-1", [ + ["div.flex.flex-row.gap-1", [ + ["div.rounded-lg.p-2.delay-100.flex-grow", { + "style": { + background: paused + ? 'repeating-linear-gradient(90deg,#aaa,#aaa 4px,#ddd 4px,#ddd 8px)' + : 'repeating-linear-gradient(90deg,#86efac,#86efac 4px,#bbf7d0 4px,#bbf7d0 8px)', + position: 'relative', + opacity: waiting ? 0.5 : 1, + transition: 'opacity 0.3s ease' + }, + "onMouseEnter": () => !paused && setInside(true), + "onMouseLeave": () => setInside(false), + "onClick": () => setPaused(!paused), + "onMouseMove": onMouseMove + }, text], + ["button.rounded-lg.p-1.text-xl.hover:bg-green-100", { + "onClick": async () => { + setWaiting(true); + await %1({key: $state.current_seed, index: -1}); // Special index to indicate cycle + setWaiting(false); + }, + "style": { + opacity: waiting ? 
0.5 : 1, + transition: 'opacity 0.3s ease' + } + }, %2] + ]], + ["div.text-md.flex.gap-2.mx-auto.p-2.border.hover:border-gray-400.cursor-pointer.w-[140px].text-center", { + "onClick": () => { + navigator.clipboard.writeText($state.current_seed.toString()); + }, + "style": { + cursor: "pointer" + } + }, $state.current_seed, ["div.text-gray-500.ml-auto", "copy"]] + ]]) + } + """, handle_seed_index, emoji.recycle + + )] +) \ No newline at end of file diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index c872ec9..3f774b1 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -14,7 +14,6 @@ # language: python # name: python3 # --- - # %% [markdown] # # Robot Localization: A Robot's Perspective # @@ -56,11 +55,13 @@ import jax.numpy as jnp from typing import List, Tuple import jax +from jax.random import PRNGKey, split import robot_2.visualization as v import robot_2.robot as robot +import robot_2.emoji as emoji -key = jax.random.PRNGKey(0) +key = PRNGKey(0) def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: """Convert wall vertices from UI format to JAX array of wall segments""" @@ -79,14 +80,16 @@ def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: valid_mask = p1[:, 2] == p2[:, 2] return segments * valid_mask[:, None, None] -def debug_reality(widget, e, subkey=None): +def debug_reality(widget, e, seed=None): """Handle updates to robot simulation""" if not widget.state.robot_path: return - current_key = subkey if subkey is not None else key + current_seed = jnp.array(seed if seed is not None else widget.state.current_seed) + assert jnp.issubdtype(current_seed.dtype, jnp.integer), "Seed must be an integer" + + current_key = PRNGKey(current_seed) - settings = robot.RobotSettings( p_noise=widget.state.motion_noise, hd_noise=widget.state.motion_noise, @@ -100,7 +103,7 @@ def debug_reality(widget, e, subkey=None): start_pose = robot.Pose(path[0, :2], 0.0) controls = robot.path_to_controls(path) - key_true, key_possible = jax.random.split(current_key) + key_true, key_possible = split(current_key) (final_pose, _), (poses, readings) = robot.simulate_robot_path( start_pose, n_sensors, controls, walls, settings, key_true @@ -121,7 +124,7 @@ def debug_reality(widget, e, subkey=None): "sensor_readings": readings[-1] if len(readings) > 0 else [], "true_path": [[float(x), float(y)] for x, y in true_path], "show_debug": True, - "current_key": current_key[0] # Send current key to frontend + "current_seed": current_seed # Send current key to frontend }) @@ -139,60 +142,28 @@ def debug_reality(widget, e, subkey=None): toolbar = v.create_toolbar() reality_toggle = v.create_reality_toggle() - -key_refresh = ( - [Plot.js(""" - ({children}) => { - const [inside, setInside] = React.useState(false) - const [waiting, setWaiting] = React.useState(false) - const [paused, setPaused] = React.useState(false) - - const text = paused - ? 'Click to Start' - : inside - ? 'Click to Pause' - : 'Fresh Keys' - - const onMouseMove = React.useCallback(async (e) => { - if (paused || waiting) return null; - const rect = e.currentTarget.getBoundingClientRect(); - const x = e.clientX - rect.left; - const stripeIndex = Math.floor(x / stripeWidth); - setWaiting(true) - await %1({key: $state.current_key, index: stripeIndex}); - setWaiting(false) - }) - - const stripeWidth = 4; // Width of each stripe in pixels - - return html(["div.rounded-lg.p-2.delay-100", { - "style": { - background: paused - ? 
'repeating-linear-gradient(90deg,#aaa,#aaa 4px,#ddd 4px,#ddd 8px)' - : 'repeating-linear-gradient(90deg,#86efac,#86efac 4px,#bbf7d0 4px,#bbf7d0 8px)', - position: 'relative', - opacity: waiting ? 0.5 : 1, - transition: 'opacity 0.3s ease' - }, - "onMouseEnter": () => !paused && setInside(true), - "onMouseLeave": () => setInside(false), - "onClick": () => setPaused(!paused), - "onMouseMove": onMouseMove - }, text]) - } - """, lambda w, e: debug_reality(w, e, subkey=jax.random.split(jax.random.PRNGKey(e.key), e.index + 1)[e.index]) - - )] -) +def handleSeedIndex(w, e): + global key + try: + if e.index == 0: + seed = key[0] + elif e.index == -1: + key = split(key, 2)[0] + seed = key[0] + else: + seed = split(key, e.index)[e.index-1][0] + debug_reality(w, e, seed=seed) + except Exception as err: + print(f"Error handling seed index: {err}, {e.key}, {e.index}") -# +key_scrubber = v.key_scrubber(handleSeedIndex) # Combine all components ( canvas & - (sliders | toolbar | reality_toggle | key_refresh | Plot.js("$state.current_key")) + (sliders | toolbar | reality_toggle | key_scrubber) & {"widths": ["400px", 1]} - | Plot.initialState(v.create_initial_state(key[0]), sync=True) + | Plot.initialState(v.create_initial_state(0), sync=True) | Plot.onChange({ "robot_path": debug_reality, "sensor_noise": debug_reality, @@ -201,3 +172,4 @@ def debug_reality(widget, e, subkey=None): "walls": debug_reality }) ) + From a898409da789b24b158d5eeca95014590f8269de Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 13:07:34 +0100 Subject: [PATCH 72/86] consolidate into where_am_i.py --- robot_2/reality.py | 114 -------- robot_2/robot.py | 101 -------- robot_2/test_reality.py | 39 --- robot_2/test_where_am_i.py | 63 +++++ robot_2/visualization.py | 185 +------------ robot_2/where_am_i.py | 514 ++++++++++++++++++++++++++++++++++--- 6 files changed, 550 insertions(+), 466 deletions(-) delete mode 100644 robot_2/reality.py delete mode 100644 robot_2/robot.py delete mode 100644 robot_2/test_reality.py create mode 100644 robot_2/test_where_am_i.py diff --git a/robot_2/reality.py b/robot_2/reality.py deleted file mode 100644 index 8f6fa1f..0000000 --- a/robot_2/reality.py +++ /dev/null @@ -1,114 +0,0 @@ -# pyright: reportUnknownMemberType=false -""" -Reality Simulation for Robot Localization Tutorial - -This module simulates the "true" state of the world that the robot can only -interact with through noisy sensors and imperfect motion. This separation helps -reinforce that localization is about the robot figuring out where it is using only: - -1. What it THINKS it did (control commands) -2. What it can SENSE (noisy sensor readings) -3. What it KNOWS about the world (wall map) - -The tutorial code should never peek at _true_pose - it must work only with -the information available through execute_control() and get_sensor_readings(). 
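A note on the seed handling in handleSeedIndex above: the standalone sketch below (illustrative names `base_key` and `seed_for_stripe`, not part of the patch) shows how a scrubber stripe index can map deterministically to an integer seed via jax.random.split, so hovering the same stripe twice replays the same simulation.

import jax
from jax.random import PRNGKey, split

base_key = PRNGKey(0)  # hypothetical module-level key, analogous to `key` above

def seed_for_stripe(stripe_index: int) -> int:
    """Derive a reproducible integer seed for a given stripe index."""
    if stripe_index == 0:
        derived = base_key
    else:
        # Split into `stripe_index` subkeys and take the last one,
        # so the same index always yields the same derived key.
        derived = split(base_key, stripe_index)[stripe_index - 1]
    # Use the first uint32 word as the displayable seed, as the UI does.
    return int(derived[0])

print(seed_for_stripe(0), seed_for_stripe(5), seed_for_stripe(5))  # last two values match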
-""" - -import jax.numpy as jnp -import jax.lax as lax -from dataclasses import dataclass -from typing import List, Tuple, final, Optional -import jax.random -from jax.random import PRNGKey -from functools import partial - -WALL_COLLISION_THRESHOLD = 0.01 - -@jax.jit -def execute_control(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", - current_pose: "Pose", control: Tuple[float, float], - key: PRNGKey) -> Tuple["Pose", jnp.ndarray, PRNGKey]: - """Execute a control command with noise, stopping if we hit a wall""" - dist, angle = control - k1, k2, k3 = jax.random.split(key, 3) - - # Add noise to motion - noisy_dist = dist + jax.random.normal(k1) * settings.p_noise - noisy_angle = angle + jax.random.normal(k2) * settings.hd_noise - - # First rotate (can always rotate) - new_pose = current_pose.rotate(noisy_angle) - - # Then try to move forward, checking for collisions - min_dist = compute_distance_to_wall(walls, new_pose, 0.0, settings.sensor_range) - - # Only move as far as we can before hitting a wall - safe_dist = jnp.minimum(noisy_dist, min_dist - WALL_COLLISION_THRESHOLD) - safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards - - new_pose = new_pose.step_along(safe_dist) - - # Get sensor readings from new position - readings, k4 = get_sensor_readings(walls, n_sensors, settings, new_pose, k3) - - return new_pose, readings, k4 - -@jax.jit -def get_sensor_readings(walls: jnp.ndarray, n_sensors: int, settings: "RobotSettings", - pose: "Pose", key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: - """Return noisy distance readings to walls from given pose""" - MAX_SENSORS = 32 # Fixed maximum - key, subkey = jax.random.split(key) - - # Calculate angles based on n_sensors, but generate MAX_SENSORS of them - angle_step = 2 * jnp.pi / n_sensors - angles = jnp.arange(MAX_SENSORS) * angle_step - noise = jax.random.normal(subkey, (MAX_SENSORS,)) * settings.sensor_noise - - readings = jax.vmap(lambda a: compute_distance_to_wall( - walls, pose, a, settings.sensor_range))(angles) - - # Create a mask for the first n_sensors elements - mask = jnp.arange(MAX_SENSORS) < n_sensors - - # Apply mask and pad with zeros - readings = (readings + noise) * mask - - return readings, key - -@jax.jit -def compute_distance_to_wall(walls: jnp.ndarray, pose: "Pose", - sensor_angle: float, sensor_range: float) -> float: - """Compute true distance to nearest wall along sensor ray""" - if walls.shape[0] == 0: # No walls - return sensor_range - - ray_start = pose.p - ray_dir = jnp.array([ - jnp.cos(pose.hd + sensor_angle), - jnp.sin(pose.hd + sensor_angle) - ]) - - # Vectorized computation for all walls at once - p1 = walls[:, 0] # Shape: (N, 2) - p2 = walls[:, 1] # Shape: (N, 2) - - # Wall direction vectors - wall_vec = p2 - p1 # Shape: (N, 2) - - # Vector from wall start to ray start - to_start = ray_start - p1 # Shape: (N, 2) - - # Compute determinant (cross product in 2D) - det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) - - # Compute intersection parameters - u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) - t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) - - # Valid intersections: not parallel, in front of ray, within wall segment - is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) - - # Find minimum valid distance - min_dist = jnp.min(jnp.where(is_valid, t, jnp.inf)) - return jnp.where(jnp.isinf(min_dist), sensor_range, min_dist) \ No newline at end of file diff --git 
a/robot_2/robot.py b/robot_2/robot.py deleted file mode 100644 index e8e28d1..0000000 --- a/robot_2/robot.py +++ /dev/null @@ -1,101 +0,0 @@ -import jax -import jax.numpy as jnp -from functools import partial -from penzai import pz -import genjax -from typing import List, Tuple -from robot_2.reality import execute_control -from jax.random import PRNGKey -from dataclasses import dataclass - - -@pz.pytree_dataclass -class Pose(genjax.PythonicPytree): - """Robot pose with position and heading""" - p: jax.Array # [x, y] - hd: float # heading in radians - - def dp(self): - """Get direction vector from heading""" - return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) - - def step_along(self, s: float) -> "Pose": - """Move forward by distance s""" - return Pose(self.p + s * self.dp(), self.hd) - - def rotate(self, angle: float) -> "Pose": - """Rotate by angle (in radians)""" - return Pose(self.p, self.hd + angle) - -@pz.pytree_dataclass -class RobotSettings(genjax.PythonicPytree): - """Robot configuration and uncertainty settings""" - p_noise: float = 0.1 # Position noise - hd_noise: float = 0.1 # Heading noise - sensor_noise: float = 0.1 # Sensor noise - sensor_range: float = 10.0 # Maximum sensor range - -def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: - """Convert a series of points into (distance, angle) control pairs""" - points = jnp.array([p[:2] for p in path_points]) - deltas = points[1:] - points[:-1] - distances = jnp.linalg.norm(deltas, axis=1) - angles = jnp.arctan2(deltas[:, 1], deltas[:, 0]) - angle_changes = jnp.diff(angles, prepend=0.0) - return jnp.stack([distances, angle_changes], axis=1) - -def sample_single_path(carry, control, walls, n_sensors, settings): - """Single step of path sampling that can be used with scan""" - pose, key = carry - pose, _, key = execute_control( - walls=walls, - n_sensors=n_sensors, - settings=settings, - current_pose=pose, - control=control, - key=key - ) - return (pose, key), pose.p - -@partial(jax.jit, static_argnums=(1)) -def sample_possible_paths(key: jnp.ndarray, n_paths: int, n_sensors: int, - robot_path: jnp.ndarray, walls: jnp.ndarray, - settings: RobotSettings): - """Generate n possible paths given the planned path, respecting walls""" - path_points = robot_path[:, :2] - controls = path_to_controls(path_points) - - start_point = path_points[0] - start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) - - keys = jax.random.split(key, n_paths) - - def sample_path_scan(key): - init_carry = (start_pose, key) - (final_pose, final_key), path_points = jax.lax.scan( - lambda carry, control: sample_single_path(carry, control, walls, n_sensors, settings), - init_carry, - controls - ) - return jnp.concatenate([start_pose.p[None, :], path_points], axis=0) - - paths = jax.vmap(sample_path_scan)(keys) - return paths - -@jax.jit -def simulate_robot_path(start_pose: Pose, n_sensors: int, controls: jnp.ndarray, - walls: jnp.ndarray, settings: RobotSettings, key: jnp.ndarray): - """Simulate robot path with noise and sensor readings""" - def step_fn(carry, control): - pose, k = carry - new_pose, readings, new_key = execute_control( - walls=walls, - n_sensors=n_sensors, - settings=settings, - current_pose=pose, - control=control, - key=k - ) - return (new_pose, new_key), (new_pose, readings) - - return jax.lax.scan(step_fn, (start_pose, key), controls) diff --git a/robot_2/test_reality.py b/robot_2/test_reality.py deleted file mode 100644 index f6605ab..0000000 --- a/robot_2/test_reality.py +++ /dev/null @@ -1,39 +0,0 @@ 
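The replacement test added below in test_where_am_i.py expects the robot to stop WALL_COLLISION_THRESHOLD short of a wall rather than reaching it. The following noise-free sketch (illustrative values only, not part of the patch) mirrors the safe-distance clamping used by execute_control and confirms the expected final position.

import jax.numpy as jnp

WALL_COLLISION_THRESHOLD = 0.15  # value used in the consolidated where_am_i.py below

start_x = 0.5          # robot starts at x = 0.5, heading along +x
wall_x = 1.0           # vertical wall segment at x = 1.0
commanded_dist = 1.0   # control asks for a 1.0 step, no motion noise

dist_to_wall = wall_x - start_x                                            # 0.5
safe_dist = jnp.minimum(commanded_dist, dist_to_wall - WALL_COLLISION_THRESHOLD)
safe_dist = jnp.maximum(safe_dist, 0.0)                                    # never move backwards
final_x = start_x + safe_dist

assert jnp.isclose(final_x, wall_x - WALL_COLLISION_THRESHOLD)             # 0.85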
-import jax.numpy as jnp -import pytest -from robot_2.reality import Reality, Pose - -def test_basic_motion(): - """Test that robot moves as expected without noise""" - # Convert walls to JAX array at creation - now in (N,2,2) shape - walls = jnp.array([ - [[0.0, 0.0], [1.0, 0.0]], # bottom wall - [[1.0, 0.0], [1.0, 1.0]], # right wall - [[1.0, 1.0], [0.0, 1.0]], # top wall - [[0.0, 1.0], [0.0, 0.0]] # left wall - ]) - world = Reality(walls, motion_noise=0.0, sensor_noise=0.0) - - # Move forward 1 unit - ignore readings since we're testing motion - _ = world.execute_control((1.0, 0.0)) - assert world._true_pose.p[0] == pytest.approx(1.5) # Started at 0.5, moved 1.0 - assert world._true_pose.p[1] == pytest.approx(0.5) # Y shouldn't change - - # Rotate 90 degrees (π/2 radians) - _ = world.execute_control((0.0, jnp.pi/2)) - assert world._true_pose.hd == pytest.approx(jnp.pi/2) - -def test_pose_methods(): - """Test Pose step_along and rotate methods""" - p = Pose(jnp.array([1.0, 1.0]), 0.0) - - # Step along heading 0 (right) - p2 = p.step_along(1.0) - assert p2.p[0] == pytest.approx(2.0) - assert p2.p[1] == pytest.approx(1.0) - - # Rotate 90 degrees and step - p3 = p.rotate(jnp.pi/2).step_along(1.0) - assert p3.p[0] == pytest.approx(1.0) - assert p3.p[1] == pytest.approx(2.0) - -pytest.main(["-v"]) # \ No newline at end of file diff --git a/robot_2/test_where_am_i.py b/robot_2/test_where_am_i.py new file mode 100644 index 0000000..2a615ec --- /dev/null +++ b/robot_2/test_where_am_i.py @@ -0,0 +1,63 @@ +import jax.numpy as jnp +import pytest +from robot_2.where_am_i import World, Pose, RobotCapabilities, execute_control +import robot_2.where_am_i as where_am_i +from jax.random import PRNGKey + +def test_basic_motion(): + """Test that robot moves as expected without noise""" + # Convert walls to JAX array at creation - now in (N,2,2) shape + walls = jnp.array([ + [[0.0, 0.0], [1.0, 0.0]], # bottom wall + [[1.0, 0.0], [1.0, 1.0]], # right wall + [[1.0, 1.0], [0.0, 1.0]], # top wall + [[0.0, 1.0], [0.0, 0.0]] # left wall + ]) + world = World(walls) + robot = RobotCapabilities( + p_noise=0.0, + hd_noise=0.0, + sensor_noise=0.0, + n_sensors=8, + sensor_range=10.0 + ) + + start_pose = Pose(jnp.array([0.5, 0.5]), 0.0) + key = PRNGKey(0) + + # Move forward 1 unit + new_pose, readings, key = execute_control( + world=world, + robot=robot, + current_pose=start_pose, + control=(1.0, 0.0), + key=key + ) + assert new_pose.p[0] == pytest.approx(1.0 - where_am_i.WALL_COLLISION_THRESHOLD) # Started at 0.5, blocked by wall at 1.0 + assert new_pose.p[1] == pytest.approx(0.5) # Y shouldn't change + + # Rotate 90 degrees (π/2 radians) + new_pose, readings, key = execute_control( + world=world, + robot=robot, + current_pose=new_pose, + control=(0.0, jnp.pi/2), + key=key + ) + assert new_pose.hd == pytest.approx(jnp.pi/2) + +def test_pose_methods(): + """Test Pose step_along and rotate methods""" + p = Pose(jnp.array([1.0, 1.0]), 0.0) + + # Step along heading 0 (right) + p2 = p.step_along(1.0) + assert p2.p[0] == pytest.approx(2.0) + assert p2.p[1] == pytest.approx(1.0) + + # Rotate 90 degrees and step + p3 = p.rotate(jnp.pi/2).step_along(1.0) + assert p3.p[0] == pytest.approx(1.0) + assert p3.p[1] == pytest.approx(2.0) + +pytest.main(["-v"]) # \ No newline at end of file diff --git a/robot_2/visualization.py b/robot_2/visualization.py index 8869c0c..757f47a 100644 --- a/robot_2/visualization.py +++ b/robot_2/visualization.py @@ -2,9 +2,9 @@ import genstudio.plot as Plot import robot_2.emoji as emoji from typing 
import Dict, List, Union, Any +import jax.numpy as jnp + -WALL_WIDTH = 6 -PATH_WIDTH = 6 def drawing_system(key, on_complete): """Create drawing system for walls and paths""" @@ -60,182 +60,6 @@ def drawing_system(key, on_complete): }) return line + events -def create_sliders(): - """Create control sliders for robot parameters""" - return ( - Plot.Slider( - "sensor_noise", - range=[0, 1], - step=0.02, - label="Sensor Noise:", - showValue=True - ) - | Plot.Slider( - "motion_noise", - range=[0, 1], - step=0.02, - label="Motion Noise:", - showValue=True - ) - | Plot.Slider( - "n_sensors", - range=[4, 32], - step=1, - label="Number of Sensors:", - showValue=True - ) - ) - -def clear_state(w, _): - """Reset visualization state""" - w.state.update(create_initial_state(w.state.current_seed) | {"selected_tool": w.state.selected_tool}) - - -def create_toolbar(): - """Create toolbar with tool selection buttons""" - selectable_button = "button.px-3.py-1.rounded.bg-gray-100.hover:bg-gray-300.data-[selected=true]:bg-gray-300" - - return Plot.html("Select tool:") | ["div.flex.gap-2", - [selectable_button, { - "data-selected": js("$state.selected_tool === 'path'"), - "onClick": js("() => $state.selected_tool = 'path'") - }, f"{emoji.robot} Path"], - [selectable_button, { - "data-selected": js("$state.selected_tool === 'walls'"), - "onClick": js("() => $state.selected_tool = 'walls'") - }, f"{emoji.pencil} Walls"], - [selectable_button, { - "onClick": clear_state - }, "Clear"] - ] - -def create_reality_toggle(): - """Create toggle for showing true position""" - return Plot.html("") | ["label.flex.items-center.gap-2.p-2.bg-gray-100.rounded.hover:bg-gray-300", - ["input", { - "type": "checkbox", - "checked": js("$state.show_true_position"), - "onChange": js("(e) => $state.show_true_position = e.target.checked") - }], "Show true position"] - -def create_sensor_rays(): - """Create visualization for sensor rays""" - return Plot.line( - js(""" - Array.from($state.sensor_readings).map((r, i) => { - const heading = $state.robot_pose.heading || 0; - const n_sensors = $state.n_sensors; - const angle = heading + (i * Math.PI * 2) / n_sensors; - const x = $state.robot_pose.x; - const y = $state.robot_pose.y; - return [ - [x, y, i], - [x + r * Math.cos(angle), - y + r * Math.sin(angle), i] - ] - }).flat() - """), - z="2", - stroke="red", - strokeWidth=1, - marker="circle" - ) - -def create_robot_canvas(drawing_system_handler): - """Create main robot visualization canvas""" - return ( - # Draw completed walls - Plot.line( - js("$state.walls"), - stroke=Plot.constantly("Walls"), - strokeWidth=WALL_WIDTH, - z="2", - render=Plot.renderChildEvents({"onClick": js("""(e) => { - const zs = new Set($state.walls.map(w => w[2])); - const targetZ = [...zs][e.index]; - $state.walls = $state.walls.filter(([x, y, z]) => z !== targetZ) - }""")}) - ) - # Draw current line being drawn - + drawing_system("current_line", drawing_system_handler) - - # Draw planned path - + Plot.line( - js("$state.robot_path"), - stroke=Plot.constantly("Robot Path"), - strokeWidth=2, - r=3, - marker="circle" - ) - - # Draw robot and true path when enabled - + Plot.cond( - js("$state.show_true_position"), - [Plot.text( - js("[[$state.robot_pose.x, $state.robot_pose.y]]"), - text=Plot.constantly(emoji.robot), - fontSize=30, - textAnchor="middle", - dy="-0.35em", - rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI")), - Plot.line( - js("$state.true_path"), - stroke=Plot.constantly("True Path"), - strokeWidth=2 - ), - create_sensor_rays() - ] 
- ) - + Plot.domain([0, 10], [0, 10]) - + Plot.grid() - + Plot.aspectRatio(1) - + Plot.colorMap({ - "Walls": "#666", - "Sensor Rays": "red", - "True Path": "green", - "Robot Path": "blue", - }) - + Plot.colorLegend() - + Plot.line( - js(""" - if (!$state.show_debug || !$state.possible_paths) {return [];}; - return $state.possible_paths.flatMap((path, pathIdx) => - path.map(([x, y]) => [x, y, pathIdx]) - ) - """, expression=False), - stroke="blue", - strokeOpacity=0.2, - z="2" - ) - + Plot.clip() - ) - -def create_initial_state(seed) -> Dict[str, Any]: - """Create initial state for visualization""" - return { - "walls": [ - # Frame around domain (timestamp 0) - [0, 0, 0], [10, 0, 0], # Bottom - [10, 0, 0], [10, 10, 0], # Right - [10, 10, 0], [0, 10, 0], # Top - [0, 10, 0], [0, 0, 0], # Left - ], - "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, - "sensor_noise": 0.1, - "motion_noise": 0.1, - "n_sensors": 8, - "show_sensors": True, - "selected_tool": "path", - "robot_path": [], - "possible_paths": [], - "estimated_pose": None, - "sensor_readings": [], - "show_uncertainty": True, - "show_true_position": False, - "current_line": [], - "current_seed": seed - } - def key_scrubber(handle_seed_index): """Create a scrubber UI component for exploring different random seeds. @@ -315,4 +139,7 @@ def key_scrubber(handle_seed_index): """, handle_seed_index, emoji.recycle )] -) \ No newline at end of file +) + + + \ No newline at end of file diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 3f774b1..407f760 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -50,20 +50,297 @@ # pyright: reportUnusedExpression=false # pyright: reportUnknownMemberType=false +from dataclasses import dataclass +from functools import partial +from typing import List, Tuple, Any, Dict + +import genjax import genstudio.plot as Plot -import numpy as np -import jax.numpy as jnp -from typing import List, Tuple import jax +import jax.numpy as jnp +import numpy as np from jax.random import PRNGKey, split +from penzai import pz +from genstudio.plot import js -import robot_2.visualization as v -import robot_2.robot as robot import robot_2.emoji as emoji +import robot_2.visualization as v key = PRNGKey(0) -def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: + +WALL_COLLISION_THRESHOLD = 0.15 +WALL_WIDTH = 6 +PATH_WIDTH = 6 + +@pz.pytree_dataclass +class Pose(genjax.PythonicPytree): + """Robot pose with position and heading""" + p: jax.numpy.ndarray # [x, y] + hd: float # heading in radians + + def dp(self): + """Get direction vector from heading""" + return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) + + def step_along(self, s: float) -> "Pose": + """Move forward by distance s""" + return Pose(self.p + s * self.dp(), self.hd) + + def rotate(self, angle: float) -> "Pose": + """Rotate by angle (in radians)""" + return Pose(self.p, self.hd + angle) + +@pz.pytree_dataclass +class World(genjax.PythonicPytree): + """The physical environment with walls that robots can collide with""" + walls: jnp.ndarray # [N, 2, 2] array of wall segments + + @jax.jit + def ray_distance(self, ray_start: jnp.ndarray, ray_dir: jnp.ndarray, max_dist: float) -> float: + """Find distance to nearest wall along a ray""" + if self.walls.shape[0] == 0: # No walls + return max_dist + + # Vectorized computation for all walls at once + p1 = self.walls[:, 0] # Shape: (N, 2) + p2 = self.walls[:, 1] # Shape: (N, 2) + + # Wall direction vectors + wall_vec = p2 - p1 # Shape: (N, 2) + + # Vector from wall start to ray 
start + to_start = ray_start - p1 # Shape: (N, 2) + + # Compute determinant (cross product in 2D) + det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) + + # Compute intersection parameters + u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) + t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) + + # Valid intersections: not parallel, in front of ray, within wall segment + is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) + + # Find minimum valid distance + min_dist = jnp.min(jnp.where(is_valid, t * jnp.linalg.norm(ray_dir), jnp.inf)) + return jnp.where(jnp.isinf(min_dist), max_dist, min_dist) + + @jax.jit + def check_movement(self, start_pos: jnp.ndarray, end_pos: jnp.ndarray, + collision_radius: float = WALL_COLLISION_THRESHOLD) -> Tuple[bool, jnp.ndarray]: + """Check if movement between two points collides with walls + + Args: + start_pos: [x, y] starting position + end_pos: [x, y] intended end position + collision_radius: How close we can get to walls + + Returns: + (can_move, safe_pos) where safe_pos is either end_pos or the + furthest safe position along the movement line + """ + movement_dir = end_pos - start_pos + dist = jnp.linalg.norm(movement_dir) + + # Replace if with where + ray_dir = jnp.where( + dist > 1e-6, + movement_dir / dist, + jnp.array([1.0, 0.0]) # Default direction if no movement + ) + + wall_dist = self.ray_distance(start_pos, ray_dir, dist) + + # Stop short of wall by collision_radius + safe_dist = jnp.maximum(0.0, wall_dist - collision_radius) + safe_pos = start_pos + ray_dir * safe_dist + + # Use where to select between start_pos and safe_pos + final_pos = jnp.where( + dist > 1e-6, + safe_pos, + start_pos + ) + + return wall_dist > dist - collision_radius, final_pos + +@pz.pytree_dataclass +class RobotCapabilities(genjax.PythonicPytree): + """Physical capabilities and limitations of the robot""" + p_noise: float # Position noise (std dev in meters) + hd_noise: float # Heading noise (std dev in radians) + sensor_noise: float # Sensor noise (std dev in meters) + n_sensors: int = 8 # Number of distance sensors + sensor_range: float = 10.0 # Maximum sensor range in meters + + def try_move(self, world: World, current_pos: jnp.ndarray, + desired_pos: jnp.ndarray, key: PRNGKey) -> jnp.ndarray: + """Try to move to desired_pos, respecting walls and adding noise""" + # Add motion noise + noise = jax.random.normal(key, shape=(2,)) * self.p_noise + noisy_target = desired_pos + noise + + # Check for collisions + _, safe_pos = world.check_movement(current_pos, noisy_target) + return safe_pos + +def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: + """Convert a series of points into (distance, angle) control pairs""" + points = jnp.array([p[:2] for p in path_points]) + deltas = points[1:] - points[:-1] + distances = jnp.linalg.norm(deltas, axis=1) + angles = jnp.arctan2(deltas[:, 1], deltas[:, 0]) + angle_changes = jnp.diff(angles, prepend=0.0) + return jnp.stack([distances, angle_changes], axis=1) + +@jax.jit +def get_sensor_readings(world: World, robot: RobotCapabilities, + pose: Pose, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: + """Return noisy distance readings to walls from given pose""" + MAX_SENSORS = 32 # Fixed maximum + key, subkey = jax.random.split(key) + + # Calculate angles based on n_sensors, but generate MAX_SENSORS of them + angle_step = 2 * jnp.pi / robot.n_sensors + angles = jnp.arange(MAX_SENSORS) * angle_step + noise = 
jax.random.normal(subkey, (MAX_SENSORS,)) * robot.sensor_noise + + readings = jax.vmap(lambda a: world.ray_distance( + ray_start=pose.p, + ray_dir=jnp.array([ + jnp.cos(pose.hd + a), + jnp.sin(pose.hd + a) + ]), + max_dist=robot.sensor_range + ))(angles) + + # Create a mask for the first n_sensors elements + mask = jnp.arange(MAX_SENSORS) < robot.n_sensors + + # Apply mask and pad with zeros + readings = (readings + noise) * mask + + return readings, key + + +@jax.jit +def execute_control(world: World, robot: RobotCapabilities, + current_pose: Pose, control: Tuple[float, float], + key: PRNGKey) -> Tuple[Pose, jnp.ndarray, PRNGKey]: + """Execute a control command with noise, stopping if we hit a wall + + Args: + control: (distance, angle) pair where: + - angle is how much to turn FIRST + - distance is how far to move AFTER turning + """ + dist, angle = control + k1, k2, k3 = jax.random.split(key, 3) + + # Add noise to motion + noisy_dist = dist + jax.random.normal(k1) * robot.p_noise + noisy_angle = angle + jax.random.normal(k2) * robot.hd_noise + + # First rotate (can always rotate) + new_pose = current_pose.rotate(noisy_angle) + + # Check distance to wall in our current heading direction + min_dist = world.ray_distance( + ray_start=new_pose.p, + ray_dir=new_pose.dp(), + max_dist=robot.sensor_range + ) + + # Only move as far as we can before hitting a wall + safe_dist = jnp.minimum(noisy_dist, min_dist - WALL_COLLISION_THRESHOLD) + safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards + + new_pose = new_pose.step_along(safe_dist) + + # Get sensor readings from new position + readings, k4 = get_sensor_readings(world, robot, new_pose, k3) + + return new_pose, readings, k4 + +def sample_single_path(carry, control, world, robot): + """Single step of path sampling that can be used with scan""" + pose, key = carry + pose, _, key = execute_control( + world=world, + robot=robot, + current_pose=pose, + control=control, + key=key + ) + return (pose, key), pose.p + +@partial(jax.jit, static_argnums=(1)) +def sample_possible_paths(key: jnp.ndarray, n_paths: int, + robot_path: jnp.ndarray, world: World, + robot: RobotCapabilities): + """Generate n possible paths given the planned path, respecting walls + + This simulates multiple possible outcomes given: + 1. The robot's intended path + 2. The physical world constraints + 3. The robot's motion/sensor characteristics + + Returns: + Array of shape [n_paths, n_steps, 2] containing possible trajectories + """ + path_points = robot_path[:, :2] + controls = path_to_controls(path_points) + + start_point = path_points[0] + start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) + + keys = jax.random.split(key, n_paths) + + def sample_path_scan(key): + init_carry = (start_pose, key) + (final_pose, final_key), path_points = jax.lax.scan( + lambda carry, control: sample_single_path(carry, control, world, robot), + init_carry, + controls + ) + return jnp.concatenate([start_pose.p[None, :], path_points], axis=0) + + paths = jax.vmap(sample_path_scan)(keys) + return paths + +@jax.jit +def simulate_robot_path(world: World, robot: RobotCapabilities, + start_pose: Pose, controls: jnp.ndarray, + key: jnp.ndarray): + """Simulate robot path with noise and sensor readings + + This simulates a single execution of the robot's planned path, including: + 1. Noisy motion according to robot capabilities + 2. Wall collisions from the physical world + 3. 
Noisy sensor readings + + Returns: + ((final_pose, final_key), (poses, readings)) where: + - poses contains all intermediate poses + - readings contains sensor readings at each step + """ + def step_fn(carry, control): + pose, k = carry + new_pose, readings, new_key = execute_control( + world=world, + robot=robot, + current_pose=pose, + control=control, + key=k + ) + return (new_pose, new_key), (new_pose, readings) + + return jax.lax.scan(step_fn, (start_pose, key), controls) + + + +def walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: """Convert wall vertices from UI format to JAX array of wall segments""" if not walls_list: return jnp.array([]).reshape((0, 2, 2)) @@ -80,7 +357,7 @@ def convert_walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: valid_mask = p1[:, 2] == p2[:, 2] return segments * valid_mask[:, None, None] -def debug_reality(widget, e, seed=None): +def simulate_robot_uncertainty(widget, e, seed=None): """Handle updates to robot simulation""" if not widget.state.robot_path: return @@ -90,28 +367,33 @@ def debug_reality(widget, e, seed=None): current_key = PRNGKey(current_seed) - settings = robot.RobotSettings( + # Create world and robot objects with ALL parameters + world = World(walls_to_jax(widget.state.walls)) + robot = RobotCapabilities( p_noise=widget.state.motion_noise, hd_noise=widget.state.motion_noise, sensor_noise=widget.state.sensor_noise, + n_sensors=widget.state.n_sensors, + sensor_range=10.0 ) path = jnp.array(widget.state.robot_path, dtype=jnp.float32) - walls = convert_walls_to_jax(widget.state.walls) - n_sensors = int(widget.state.n_sensors) - - start_pose = robot.Pose(path[0, :2], 0.0) - controls = robot.path_to_controls(path) + start_pose = Pose(path[0, :2], 0.0) + controls = path_to_controls(path) key_true, key_possible = split(current_key) - (final_pose, _), (poses, readings) = robot.simulate_robot_path( - start_pose, n_sensors, controls, walls, settings, key_true + (final_pose, _), (poses, readings) = simulate_robot_path( + world=world, # <-- Pass World object + robot=robot, + start_pose=start_pose, + controls=controls, + key=key_true ) true_path = jnp.concatenate([start_pose.p[None, :], jax.vmap(lambda p: p.p)(poses)]) - possible_paths = robot.sample_possible_paths( - key_possible, 20, n_sensors, path, walls, settings + possible_paths = sample_possible_paths( + key_possible, 20, path, world, robot # <-- Pass World object ) widget.state.update({ @@ -124,12 +406,11 @@ def debug_reality(widget, e, seed=None): "sensor_readings": readings[-1] if len(readings) > 0 else [], "true_path": [[float(x), float(y)] for x, y in true_path], "show_debug": True, - "current_seed": current_seed # Send current key to frontend + "current_seed": current_seed }) -# Create the visualization -canvas = v.create_robot_canvas(Plot.js("""({points, simplify}) => { +drawing_system_handler = Plot.js("""({points, simplify}) => { mode = $state.selected_tool if (mode === 'walls') { $state.update(['walls', 'concat', simplify(0.25)]) @@ -137,10 +418,159 @@ def debug_reality(widget, e, seed=None): if (mode === 'path') { $state.robot_path = simplify(0.25) } - }""")) -sliders = v.create_sliders() -toolbar = v.create_toolbar() -reality_toggle = v.create_reality_toggle() + }""") + +sliders = ( + Plot.Slider( + "sensor_noise", + range=[0, 1], + step=0.02, + label="Sensor Noise:", + showValue=True + ) + | Plot.Slider( + "motion_noise", + range=[0, 0.5], + step=0.01, + label="Motion Noise:", + showValue=True + ) + | Plot.Slider( + "n_sensors", + range=[4, 32], + step=1, + 
label="Number of Sensors:", + showValue=True + ) + ) + + +def create_initial_state(seed) -> Dict[str, Any]: + """Create initial state for visualization""" + return { + "walls": [ + # Frame around domain (timestamp 0) + [0, 0, 0], [10, 0, 0], # Bottom + [10, 0, 0], [10, 10, 0], # Right + [10, 10, 0], [0, 10, 0], # Top + [0, 10, 0], [0, 0, 0], # Left + ], + "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, + "sensor_noise": 0.1, + "motion_noise": 0.1, + "n_sensors": 8, + "show_sensors": True, + "selected_tool": "path", + "robot_path": [], + "possible_paths": [], + "estimated_pose": None, + "sensor_readings": [], + "show_uncertainty": True, + "show_true_position": False, + "current_line": [], + "current_seed": seed + } + +true_position_toggle = Plot.html( + ["label.flex.items-center.gap-2.p-2.bg-gray-100.rounded.hover:bg-gray-300", + ["input", { + "type": "checkbox", + "checked": js("$state.show_true_position"), + "onChange": js("(e) => $state.show_true_position = e.target.checked") + }], "Show true position"] +) + +sensor_rays = Plot.line( + js(""" + Array.from($state.sensor_readings).map((r, i) => { + const heading = $state.robot_pose.heading || 0; + const n_sensors = $state.n_sensors; + const angle = heading + (i * Math.PI * 2) / n_sensors; + const x = $state.robot_pose.x; + const y = $state.robot_pose.y; + return [ + [x, y, i], + [x + r * Math.cos(angle), + y + r * Math.sin(angle), i] + ] + }).flat() + """), + z="2", + stroke="red", + strokeWidth=1, + marker="circle" + ) + +true_path = Plot.cond( + js("$state.show_true_position"), + [Plot.text( + js("[[$state.robot_pose.x, $state.robot_pose.y]]"), + text=Plot.constantly(emoji.robot), + fontSize=30, + textAnchor="middle", + dy="-0.35em", + rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI")), + Plot.line( + js("$state.true_path"), + stroke=Plot.constantly("True Path"), + strokeWidth=2 + ), + sensor_rays + ] + ) + +planned_path = Plot.line( + js("$state.robot_path"), + stroke=Plot.constantly("Robot Path"), + strokeWidth=2, + r=3, + marker="circle" + ) + +walls = Plot.line( + js("$state.walls"), + stroke=Plot.constantly("Walls"), + strokeWidth=WALL_WIDTH, + z="2", + render=Plot.renderChildEvents({"onClick": js("""(e) => { + const zs = new Set($state.walls.map(w => w[2])); + const targetZ = [...zs][e.index]; + $state.walls = $state.walls.filter(([x, y, z]) => z !== targetZ) + }""")}) + ) + +possible_paths = Plot.line( + js(""" + if (!$state.show_debug || !$state.possible_paths) {return [];}; + return $state.possible_paths.flatMap((path, pathIdx) => + path.map(([x, y]) => [x, y, pathIdx]) + ) + """, expression=False), + stroke="blue", + strokeOpacity=0.2, + z="2" + ) + +def clear_state(w, _): + """Reset visualization state""" + w.state.update(create_initial_state(w.state.current_seed) | {"selected_tool": w.state.selected_tool}) + + +selectable_button = "button.px-3.py-1.rounded.bg-gray-100.hover:bg-gray-300.data-[selected=true]:bg-gray-300" + +toolbar = Plot.html("Select tool:") | ["div.flex.gap-2", + [selectable_button, { + "data-selected": js("$state.selected_tool === 'path'"), + "onClick": js("() => $state.selected_tool = 'path'") + }, f"{emoji.robot} Path"], + [selectable_button, { + "data-selected": js("$state.selected_tool === 'walls'"), + "onClick": js("() => $state.selected_tool = 'walls'") + }, f"{emoji.pencil} Walls"], + [selectable_button, { + "onClick": clear_state + }, "Clear"] +] def handleSeedIndex(w, e): global key @@ -152,24 +582,42 @@ def handleSeedIndex(w, e): seed = key[0] else: seed = split(key, 
e.index)[e.index-1][0] - debug_reality(w, e, seed=seed) + simulate_robot_uncertainty(w, e, seed=seed) except Exception as err: print(f"Error handling seed index: {err}, {e.key}, {e.index}") key_scrubber = v.key_scrubber(handleSeedIndex) -# Combine all components +canvas = ( + v.drawing_system("current_line", drawing_system_handler) + + walls + + planned_path + + true_path + + possible_paths + + Plot.domain([0, 10], [0, 10]) + + Plot.grid() + + Plot.aspectRatio(1) + + Plot.colorMap({ + "Walls": "#666", + "Sensor Rays": "red", + "True Path": "green", + "Robot Path": "blue", + }) + + Plot.colorLegend() + + Plot.clip() + ) + ( canvas & - (sliders | toolbar | reality_toggle | key_scrubber) + (sliders | toolbar | true_position_toggle | key_scrubber) & {"widths": ["400px", 1]} - | Plot.initialState(v.create_initial_state(0), sync=True) + | Plot.initialState(create_initial_state(0), sync=True) | Plot.onChange({ - "robot_path": debug_reality, - "sensor_noise": debug_reality, - "motion_noise": debug_reality, - "n_sensors": debug_reality, - "walls": debug_reality + "robot_path": simulate_robot_uncertainty, + "sensor_noise": simulate_robot_uncertainty, + "motion_noise": simulate_robot_uncertainty, + "n_sensors": simulate_robot_uncertainty, + "walls": simulate_robot_uncertainty }) ) From ab53ef5b2a00926d5e9a17959f54f0a472885031 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 14:48:53 +0100 Subject: [PATCH 73/86] rotating sensor rays --- robot_2/where_am_i.py | 44 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 407f760..4497ce8 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -367,7 +367,7 @@ def simulate_robot_uncertainty(widget, e, seed=None): current_key = PRNGKey(current_seed) - # Create world and robot objects with ALL parameters + # Create world and robot objects world = World(walls_to_jax(widget.state.walls)) robot = RobotCapabilities( p_noise=widget.state.motion_noise, @@ -384,7 +384,7 @@ def simulate_robot_uncertainty(widget, e, seed=None): key_true, key_possible = split(current_key) (final_pose, _), (poses, readings) = simulate_robot_path( - world=world, # <-- Pass World object + world=world, robot=robot, start_pose=start_pose, controls=controls, @@ -393,7 +393,7 @@ def simulate_robot_uncertainty(widget, e, seed=None): true_path = jnp.concatenate([start_pose.p[None, :], jax.vmap(lambda p: p.p)(poses)]) possible_paths = sample_possible_paths( - key_possible, 20, path, world, robot # <-- Pass World object + key_possible, 20, path, world, robot ) widget.state.update({ @@ -501,6 +501,41 @@ def create_initial_state(seed) -> Dict[str, Any]: marker="circle" ) + +animation_frame = Plot.Slider("frame", fps=30, range=628, controls=False) +rotating_sensor_rays = ( + Plot.line( + js(""" + Array.from($state.sensor_readings).map((r, i) => { + const heading = $state.robot_pose.heading || 0; + const n_sensors = $state.n_sensors; + let angle = heading + (i * Math.PI * 2) / n_sensors; + if (!$state.show_true_position) { + angle += $state.frame * 0.01; + } + const x = $state.robot_pose.x; + const y = $state.robot_pose.y; + return [ + [0, 0, i], + [r * Math.cos(angle), + r * Math.sin(angle), i] + ] + }).flat() + """), + z="2", + stroke="red", + strokeWidth=1, + marker="circle" + ) + + {"height": 200, "width": 200, "className": "bg-gray-100"} + + Plot.aspectRatio(1) + + Plot.domain([-10, 10]) + + Plot.hideAxis() + + Plot.gridX(interval=1) + + Plot.gridY(interval=1) + 
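# Illustrative aside: a small worked check of the rotation speed implied by the numbers above
# (fps=30 and range=628 on the "frame" slider, with 0.01 rad added per frame in the JS):
#
#   628 frames/cycle * 0.01 rad/frame = 6.28 rad ≈ 2*pi   -> one full revolution per slider cycle
#   30 frames/s      * 0.01 rad/frame = 0.3 rad/s         -> roughly 21 s per revolution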
+) + true_path = Plot.cond( js("$state.show_true_position"), [Plot.text( @@ -609,7 +644,7 @@ def handleSeedIndex(w, e): ( canvas & - (sliders | toolbar | true_position_toggle | key_scrubber) + (sliders | toolbar | true_position_toggle | key_scrubber | rotating_sensor_rays) & {"widths": ["400px", 1]} | Plot.initialState(create_initial_state(0), sync=True) | Plot.onChange({ @@ -619,5 +654,6 @@ def handleSeedIndex(w, e): "n_sensors": simulate_robot_uncertainty, "walls": simulate_robot_uncertainty }) + | animation_frame ) From 7e658a5386f6d300500da2da37237b830c2281ef Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 14:56:28 +0100 Subject: [PATCH 74/86] simplify --- robot_2/where_am_i.py | 120 ++++++++++++++++++------------------------ 1 file changed, 51 insertions(+), 69 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 4497ce8..4eff694 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -263,67 +263,17 @@ def execute_control(world: World, robot: RobotCapabilities, return new_pose, readings, k4 -def sample_single_path(carry, control, world, robot): - """Single step of path sampling that can be used with scan""" - pose, key = carry - pose, _, key = execute_control( - world=world, - robot=robot, - current_pose=pose, - control=control, - key=key - ) - return (pose, key), pose.p - -@partial(jax.jit, static_argnums=(1)) -def sample_possible_paths(key: jnp.ndarray, n_paths: int, - robot_path: jnp.ndarray, world: World, - robot: RobotCapabilities): - """Generate n possible paths given the planned path, respecting walls - - This simulates multiple possible outcomes given: - 1. The robot's intended path - 2. The physical world constraints - 3. The robot's motion/sensor characteristics - - Returns: - Array of shape [n_paths, n_steps, 2] containing possible trajectories - """ - path_points = robot_path[:, :2] - controls = path_to_controls(path_points) - - start_point = path_points[0] - start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) - - keys = jax.random.split(key, n_paths) - - def sample_path_scan(key): - init_carry = (start_pose, key) - (final_pose, final_key), path_points = jax.lax.scan( - lambda carry, control: sample_single_path(carry, control, world, robot), - init_carry, - controls - ) - return jnp.concatenate([start_pose.p[None, :], path_points], axis=0) - - paths = jax.vmap(sample_path_scan)(keys) - return paths - @jax.jit def simulate_robot_path(world: World, robot: RobotCapabilities, start_pose: Pose, controls: jnp.ndarray, key: jnp.ndarray): """Simulate robot path with noise and sensor readings - This simulates a single execution of the robot's planned path, including: - 1. Noisy motion according to robot capabilities - 2. Wall collisions from the physical world - 3. 
Noisy sensor readings - Returns: - ((final_pose, final_key), (poses, readings)) where: - - poses contains all intermediate poses - - readings contains sensor readings at each step + Tuple of: + - Array of shape [n_steps, 2] containing positions + - Array of shape [n_steps] containing headings + - Array of shape [n_steps, n_sensors] containing sensor readings """ def step_fn(carry, control): pose, k = carry @@ -336,9 +286,41 @@ def step_fn(carry, control): ) return (new_pose, new_key), (new_pose, readings) - return jax.lax.scan(step_fn, (start_pose, key), controls) - + (_, _), (poses, readings) = jax.lax.scan(step_fn, (start_pose, key), controls) + + # Extract positions and headings + positions = jnp.concatenate([ + start_pose.p[None, :], + jax.vmap(lambda p: p.p)(poses) + ]) + headings = jnp.concatenate([ + jnp.array([start_pose.hd]), + jax.vmap(lambda p: p.hd)(poses) + ]) + + return positions, headings, readings +@partial(jax.jit, static_argnums=(1)) +def sample_possible_paths(key: jnp.ndarray, n_paths: int, + robot_path: jnp.ndarray, world: World, + robot: RobotCapabilities): + """Generate n possible paths given the planned path, respecting walls""" + path_points = robot_path[:, :2] + controls = path_to_controls(path_points) + + start_point = path_points[0] + start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) + + keys = jax.random.split(key, n_paths) + + # Vectorize over different random keys + return jax.vmap(lambda k: simulate_robot_path( + world=world, + robot=robot, + start_pose=start_pose, + controls=controls, + key=k + ))(keys) def walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: """Convert wall vertices from UI format to JAX array of wall segments""" @@ -383,27 +365,27 @@ def simulate_robot_uncertainty(widget, e, seed=None): key_true, key_possible = split(current_key) - (final_pose, _), (poses, readings) = simulate_robot_path( - world=world, - robot=robot, - start_pose=start_pose, - controls=controls, - key=key_true + # Get single true path with full pose information + true_paths, true_headings, true_readings = sample_possible_paths( + key_true, 1, path, world, robot ) + true_path = true_paths[0] # Take first (only) path + final_readings = true_readings[0, -1] # Take last readings + final_heading = true_headings[0, -1] # Take final heading - true_path = jnp.concatenate([start_pose.p[None, :], jax.vmap(lambda p: p.p)(poses)]) - possible_paths = sample_possible_paths( - key_possible, 20, path, world, robot + # Get multiple possible paths + possible_paths, possible_headings, _ = sample_possible_paths( + key_possible, 20, path, world, robot ) widget.state.update({ "robot_pose": { - "x": float(final_pose.p[0]), - "y": float(final_pose.p[1]), - "heading": float(final_pose.hd) + "x": float(true_path[-1, 0]), + "y": float(true_path[-1, 1]), + "heading": float(final_heading) }, "possible_paths": possible_paths, - "sensor_readings": readings[-1] if len(readings) > 0 else [], + "sensor_readings": final_readings, "true_path": [[float(x), float(y)] for x, y in true_path], "show_debug": True, "current_seed": current_seed From 033814efef0bc896f0b533bc841ebab479abe3e3 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 15:19:27 +0100 Subject: [PATCH 75/86] sample true and possible paths together --- robot_2/where_am_i.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 4eff694..6268a9b 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -360,23 
+360,20 @@ def simulate_robot_uncertainty(widget, e, seed=None): ) path = jnp.array(widget.state.robot_path, dtype=jnp.float32) - start_pose = Pose(path[0, :2], 0.0) - controls = path_to_controls(path) - key_true, key_possible = split(current_key) - - # Get single true path with full pose information - true_paths, true_headings, true_readings = sample_possible_paths( - key_true, 1, path, world, robot + # Sample all paths at once (1 true path + N possible paths) + n_possible = 20 + all_paths, all_headings, all_readings = sample_possible_paths( + current_key, n_possible + 1, path, world, robot ) - true_path = true_paths[0] # Take first (only) path - final_readings = true_readings[0, -1] # Take last readings - final_heading = true_headings[0, -1] # Take final heading - # Get multiple possible paths - possible_paths, possible_headings, _ = sample_possible_paths( - key_possible, 20, path, world, robot - ) + # First path is the "true" path + true_path = all_paths[0] + final_readings = all_readings[0, -1] + final_heading = all_headings[0, -1] + + # Remaining paths are possible paths + possible_paths = all_paths[1:] widget.state.update({ "robot_pose": { From 5a2d1dd65cec7fae20ddd53f2c1f99e36caec5eb Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 15:49:28 +0100 Subject: [PATCH 76/86] smaller heading noise --- robot_2/where_am_i.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 6268a9b..c144014 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -353,7 +353,7 @@ def simulate_robot_uncertainty(widget, e, seed=None): world = World(walls_to_jax(widget.state.walls)) robot = RobotCapabilities( p_noise=widget.state.motion_noise, - hd_noise=widget.state.motion_noise, + hd_noise=widget.state.motion_noise * 0.1, sensor_noise=widget.state.sensor_noise, n_sensors=widget.state.n_sensors, sensor_range=10.0 From 497ee258abf7fa9c6376972f0c8c93489a5bfc3a Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 16:51:50 +0100 Subject: [PATCH 77/86] mouse control of sensor view --- robot_2/where_am_i.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index c144014..7cce04f 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -444,6 +444,7 @@ def create_initial_state(seed) -> Dict[str, Any]: "possible_paths": [], "estimated_pose": None, "sensor_readings": [], + "sensor_explore_angle": -1, "show_uncertainty": True, "show_true_position": False, "current_line": [], @@ -481,7 +482,7 @@ def create_initial_state(seed) -> Dict[str, Any]: ) -animation_frame = Plot.Slider("frame", fps=30, range=628, controls=False) + rotating_sensor_rays = ( Plot.line( js(""" @@ -489,8 +490,11 @@ def create_initial_state(seed) -> Dict[str, Any]: const heading = $state.robot_pose.heading || 0; const n_sensors = $state.n_sensors; let angle = heading + (i * Math.PI * 2) / n_sensors; - if (!$state.show_true_position) { - angle += $state.frame * 0.01; + if ($state.sensor_explore_angle > -1) { + angle += $state.sensor_explore_angle + } + else if (!$state.show_true_position) { + angle += $state.current_seed || Math.random() * 2 * Math.PI; } const x = $state.robot_pose.x; const y = $state.robot_pose.y; @@ -506,13 +510,25 @@ def create_initial_state(seed) -> Dict[str, Any]: strokeWidth=1, marker="circle" ) + # move the mouse around the plot to rotate the sensors + + Plot.events({"onMouseMove": + Plot.js("""(e) => { + // Convert mouse 
position to angle from center + // atan2 gives angle in radians from -pi to pi + // Subtract pi/2 to make 12 o'clock 0 radians + const angle = Math.atan2(e.y, e.x) - Math.PI/2; + + // Normalize to 0 to 2pi range + const normalized = (angle + 2*Math.PI) % (2*Math.PI); + + $state.sensor_explore_angle = normalized; + }""")}) + {"height": 200, "width": 200, "className": "bg-gray-100"} + Plot.aspectRatio(1) + Plot.domain([-10, 10]) + Plot.hideAxis() + Plot.gridX(interval=1) + Plot.gridY(interval=1) - ) true_path = Plot.cond( @@ -633,6 +649,5 @@ def handleSeedIndex(w, e): "n_sensors": simulate_robot_uncertainty, "walls": simulate_robot_uncertainty }) - | animation_frame ) From bc829bddd73e3fde1b69a32211c32437a519dfd1 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 17:49:09 +0100 Subject: [PATCH 78/86] heading noise control, more paths --- robot_2/visualization.py | 2 +- robot_2/where_am_i.py | 44 ++++++++++++++++++++++++---------------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/robot_2/visualization.py b/robot_2/visualization.py index 757f47a..4d83152 100644 --- a/robot_2/visualization.py +++ b/robot_2/visualization.py @@ -126,7 +126,7 @@ def key_scrubber(handle_seed_index): } }, %2] ]], - ["div.text-md.flex.gap-2.mx-auto.p-2.border.hover:border-gray-400.cursor-pointer.w-[140px].text-center", { + ["div.text-md.flex.gap-2.p-2.border.hover:border-gray-400.cursor-pointer.w-[140px].text-center", { "onClick": () => { navigator.clipboard.writeText($state.current_seed.toString()); }, diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 7cce04f..fd618b6 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -353,7 +353,7 @@ def simulate_robot_uncertainty(widget, e, seed=None): world = World(walls_to_jax(widget.state.walls)) robot = RobotCapabilities( p_noise=widget.state.motion_noise, - hd_noise=widget.state.motion_noise * 0.1, + hd_noise=widget.state.motion_noise * widget.state.heading_noise_scale, sensor_noise=widget.state.sensor_noise, n_sensors=widget.state.n_sensors, sensor_range=10.0 @@ -362,7 +362,7 @@ def simulate_robot_uncertainty(widget, e, seed=None): path = jnp.array(widget.state.robot_path, dtype=jnp.float32) # Sample all paths at once (1 true path + N possible paths) - n_possible = 20 + n_possible = 40 all_paths, all_headings, all_readings = sample_possible_paths( current_key, n_possible + 1, path, world, robot ) @@ -395,30 +395,37 @@ def simulate_robot_uncertainty(widget, e, seed=None): $state.update(['walls', 'concat', simplify(0.25)]) } if (mode === 'path') { - $state.robot_path = simplify(0.25) + $state.robot_path = simplify(0.5) } }""") sliders = ( Plot.Slider( - "sensor_noise", - range=[0, 1], - step=0.02, - label="Sensor Noise:", - showValue=True - ) - | Plot.Slider( "motion_noise", range=[0, 0.5], step=0.01, label="Motion Noise:", showValue=True ) + & Plot.Slider( + "heading_noise_scale", + range=[0, 1], + step=0.05, + label="Heading Noise Scale:", + showValue=True + ) | Plot.Slider( + "sensor_noise", + range=[0, 1], + step=0.02, + label="Sensor Noise:", + showValue=True + ) + & Plot.Slider( "n_sensors", range=[4, 32], step=1, - label="Number of Sensors:", + label="Sensors:", showValue=True ) ) @@ -437,6 +444,7 @@ def create_initial_state(seed) -> Dict[str, Any]: "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, "sensor_noise": 0.1, "motion_noise": 0.1, + "heading_noise_scale": 0.3, "n_sensors": 8, "show_sensors": True, "selected_tool": "path", @@ -490,12 +498,13 @@ def create_initial_state(seed) -> Dict[str, Any]: 
const heading = $state.robot_pose.heading || 0; const n_sensors = $state.n_sensors; let angle = heading + (i * Math.PI * 2) / n_sensors; - if ($state.sensor_explore_angle > -1) { - angle += $state.sensor_explore_angle - } - else if (!$state.show_true_position) { - angle += $state.current_seed || Math.random() * 2 * Math.PI; - } + if (!$state.show_true_position) { + if ($state.sensor_explore_angle > -1) { + angle += $state.sensor_explore_angle + } else { + angle += $state.current_seed || Math.random() * 2 * Math.PI; + } + } const x = $state.robot_pose.x; const y = $state.robot_pose.y; return [ @@ -646,6 +655,7 @@ def handleSeedIndex(w, e): "robot_path": simulate_robot_uncertainty, "sensor_noise": simulate_robot_uncertainty, "motion_noise": simulate_robot_uncertainty, + "heading_noise_scale": simulate_robot_uncertainty, "n_sensors": simulate_robot_uncertainty, "walls": simulate_robot_uncertainty }) From 715c870e047bef54de21ee36da0757457d10eb6b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:49:58 +0000 Subject: [PATCH 79/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../probcomp-localization-tutorial.py | 26 +- robot_2/emoji.py | 2 +- robot_2/test_where_am_i.py | 53 +- robot_2/visualization.py | 68 +-- robot_2/where_am_i.py | 576 ++++++++++-------- 5 files changed, 382 insertions(+), 343 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index ff8dba2..668ff35 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -139,7 +139,7 @@ def create_segments(points): def make_world(wall_verts, clutters_vec, start, controls) -> tuple[ - dict[str, FloatArray | tuple[float, float, float, float] | float | Pose], + dict[str, FloatArray | tuple[float, float, float, float] | float | Pose], dict[str, Control | Pose], int ]: @@ -805,18 +805,18 @@ def set_angle(widget: Widget, e): angle = float(e["value"]) rotated_trace, rotated_trace_weight_diff = rotate_trace(key, trace, angle) widget.state.update({ - "rotated_poses": pose_plot(rotated_trace.get_retval(), + "rotated_poses": pose_plot(rotated_trace.get_retval(), fill=Plot.constantly("with heading modified")), "angle": angle, "rotated_trace_weight_diff": rotated_trace_weight_diff}) -( +( Plot.initialState({ "poses": pose_plot(trace.get_retval(), fill=Plot.constantly("some pose")), "rotated_poses": pose_plot(rotated_trace.get_retval(), fill=Plot.constantly("with heading modified")), - "rotated_trace_weight_diff": rotated_trace_weight_diff, + "rotated_trace_weight_diff": rotated_trace_weight_diff, "angle": jnp.pi / 2.0 - + }) | Plot.new( world_plot @@ -827,13 +827,13 @@ def set_angle(widget: Widget, e): ) | html(["span.tc", Plot.js("`score ratio: ${$state.rotated_trace_weight_diff.toFixed(2)}`")]) | ( - Plot.js("`angle: ${$state.angle.toFixed(2)}`") - & ["input", {"type": "range", - "name": "angle", - "defaultValue": Plot.js("$state.angle"), - "min": -jnp.pi / 2, - "max": jnp.pi / 2, - "step": 0.1, + Plot.js("`angle: ${$state.angle.toFixed(2)}`") + & ["input", {"type": "range", + "name": "angle", + "defaultValue": Plot.js("$state.angle"), + "min": -jnp.pi / 2, + "max": jnp.pi / 2, + "step": 0.1, "onChange": set_angle }] & {"widths": ["80px", "200px"]} @@ -900,7 +900,7 @@ def set_first_step_angle(widget: Widget, e): }] & 
{"widths": ["80px", "200px"]} ) -).widget() +).widget() # %% [markdown] # ### Ideal sensors diff --git a/robot_2/emoji.py b/robot_2/emoji.py index 1af87dc..b43969f 100644 --- a/robot_2/emoji.py +++ b/robot_2/emoji.py @@ -1,4 +1,4 @@ robot = "🤖" pencil = "✏️" recycle = "♻️" -clipboard = "📋" \ No newline at end of file +clipboard = "📋" diff --git a/robot_2/test_where_am_i.py b/robot_2/test_where_am_i.py index 2a615ec..28c5c1d 100644 --- a/robot_2/test_where_am_i.py +++ b/robot_2/test_where_am_i.py @@ -4,60 +4,59 @@ import robot_2.where_am_i as where_am_i from jax.random import PRNGKey + def test_basic_motion(): """Test that robot moves as expected without noise""" # Convert walls to JAX array at creation - now in (N,2,2) shape - walls = jnp.array([ - [[0.0, 0.0], [1.0, 0.0]], # bottom wall - [[1.0, 0.0], [1.0, 1.0]], # right wall - [[1.0, 1.0], [0.0, 1.0]], # top wall - [[0.0, 1.0], [0.0, 0.0]] # left wall - ]) + walls = jnp.array( + [ + [[0.0, 0.0], [1.0, 0.0]], # bottom wall + [[1.0, 0.0], [1.0, 1.0]], # right wall + [[1.0, 1.0], [0.0, 1.0]], # top wall + [[0.0, 1.0], [0.0, 0.0]], # left wall + ] + ) world = World(walls) robot = RobotCapabilities( - p_noise=0.0, - hd_noise=0.0, - sensor_noise=0.0, - n_sensors=8, - sensor_range=10.0 + p_noise=0.0, hd_noise=0.0, sensor_noise=0.0, n_sensors=8, sensor_range=10.0 ) - + start_pose = Pose(jnp.array([0.5, 0.5]), 0.0) key = PRNGKey(0) - + # Move forward 1 unit new_pose, readings, key = execute_control( - world=world, - robot=robot, - current_pose=start_pose, - control=(1.0, 0.0), - key=key + world=world, robot=robot, current_pose=start_pose, control=(1.0, 0.0), key=key ) - assert new_pose.p[0] == pytest.approx(1.0 - where_am_i.WALL_COLLISION_THRESHOLD) # Started at 0.5, blocked by wall at 1.0 + assert new_pose.p[0] == pytest.approx( + 1.0 - where_am_i.WALL_COLLISION_THRESHOLD + ) # Started at 0.5, blocked by wall at 1.0 assert new_pose.p[1] == pytest.approx(0.5) # Y shouldn't change - + # Rotate 90 degrees (π/2 radians) new_pose, readings, key = execute_control( world=world, robot=robot, current_pose=new_pose, - control=(0.0, jnp.pi/2), - key=key + control=(0.0, jnp.pi / 2), + key=key, ) - assert new_pose.hd == pytest.approx(jnp.pi/2) + assert new_pose.hd == pytest.approx(jnp.pi / 2) + def test_pose_methods(): """Test Pose step_along and rotate methods""" p = Pose(jnp.array([1.0, 1.0]), 0.0) - + # Step along heading 0 (right) p2 = p.step_along(1.0) assert p2.p[0] == pytest.approx(2.0) assert p2.p[1] == pytest.approx(1.0) - + # Rotate 90 degrees and step - p3 = p.rotate(jnp.pi/2).step_along(1.0) + p3 = p.rotate(jnp.pi / 2).step_along(1.0) assert p3.p[0] == pytest.approx(1.0) assert p3.p[1] == pytest.approx(2.0) -pytest.main(["-v"]) # \ No newline at end of file + +pytest.main(["-v"]) # diff --git a/robot_2/visualization.py b/robot_2/visualization.py index 4d83152..ce268c4 100644 --- a/robot_2/visualization.py +++ b/robot_2/visualization.py @@ -5,22 +5,19 @@ import jax.numpy as jnp - def drawing_system(key, on_complete): """Create drawing system for walls and paths""" line = Plot.line( - js(f"$state.{key}"), - stroke="#ccc", - strokeWidth=4, - strokeDasharray="4" + js(f"$state.{key}"), stroke="#ccc", strokeWidth=4, strokeDasharray="4" ) - - events = Plot.events({ - "_initialState": Plot.initialState({key: []}), - "onDrawStart": js(f"""(e) => {{ + + events = Plot.events( + { + "_initialState": Plot.initialState({key: []}), + "onDrawStart": js(f"""(e) => {{ $state.{key} = [[e.x, e.y, e.key]]; }}"""), - "onDraw": js(f"""(e) => {{ + "onDraw": js(f"""(e) 
=> {{ if ($state.{key}.length > 0) {{ const last = $state.{key}[$state.{key}.length - 1]; const dx = e.x - last[0]; @@ -28,7 +25,8 @@ def drawing_system(key, on_complete): $state.update(['{key}', 'append', [e.x, e.y, e.key]]); }} }}"""), - "onDrawEnd": js(f"""(e) => {{ + "onDrawEnd": js( + f"""(e) => {{ if ($state.{key}.length > 1) {{ const points = [...$state.{key}, [e.x, e.y, e.key]]; const ret = {{ @@ -36,19 +34,19 @@ def drawing_system(key, on_complete): simplify: (threshold=0) => {{ const result = [points[0]]; let lastKept = points[0]; - + for (let i = 1; i < points.length - 1; i++) {{ const p = points[i]; const dx = p[0] - lastKept[0]; const dy = p[1] - lastKept[1]; const dist = Math.sqrt(dx*dx + dy*dy); - + if (dist >= threshold) {{ result.push(p); lastKept = p; }} }} - + result.push(points[points.length - 1]); return result; }} @@ -56,37 +54,41 @@ def drawing_system(key, on_complete): %1(ret) }} $state.{key} = []; - }}""", on_complete) - }) + }}""", + on_complete, + ), + } + ) return line + events def key_scrubber(handle_seed_index): """Create a scrubber UI component for exploring different random seeds. - - The component shows a striped bar that can be clicked to pause/resume and + + The component shows a striped bar that can be clicked to pause/resume and scrubbed to explore different seeds. A recycle button allows cycling through seeds. - + Args: handle_seed_index: Callback function that takes a dict with 'key' (current seed) and 'index' (stripe index or -1 for cycle) and handles seed changes. - + Returns: A Plot.js component containing the scrubber UI. """ - return ( - [Plot.js(""" + return [ + Plot.js( + """ ({children}) => { const [inside, setInside] = React.useState(false) const [waiting, setWaiting] = React.useState(false) const [paused, setPaused] = React.useState(false) - - const text = paused + + const text = paused ? 'Click to Start' - : inside + : inside ? 
'Click to Pause' : 'Explore Keys' - + const onMouseMove = React.useCallback(async (e) => { if (paused || waiting) return null; const rect = e.currentTarget.getBoundingClientRect(); @@ -97,7 +99,7 @@ def key_scrubber(handle_seed_index): setWaiting(false) }) const stripeWidth = 4; // Width of each stripe in pixels - + return html(["div.flex.flex-col.gap-1", [ ["div.flex.flex-row.gap-1", [ ["div.rounded-lg.p-2.delay-100.flex-grow", { @@ -136,10 +138,8 @@ def key_scrubber(handle_seed_index): }, $state.current_seed, ["div.text-gray-500.ml-auto", "copy"]] ]]) } - """, handle_seed_index, emoji.recycle - - )] -) - - - \ No newline at end of file + """, + handle_seed_index, + emoji.recycle, + ) + ] diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index fd618b6..cc64188 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -73,118 +73,136 @@ WALL_WIDTH = 6 PATH_WIDTH = 6 + @pz.pytree_dataclass class Pose(genjax.PythonicPytree): """Robot pose with position and heading""" + p: jax.numpy.ndarray # [x, y] - hd: float # heading in radians - + hd: float # heading in radians + def dp(self): """Get direction vector from heading""" return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) - + def step_along(self, s: float) -> "Pose": """Move forward by distance s""" return Pose(self.p + s * self.dp(), self.hd) - + def rotate(self, angle: float) -> "Pose": """Rotate by angle (in radians)""" return Pose(self.p, self.hd + angle) - + + @pz.pytree_dataclass class World(genjax.PythonicPytree): """The physical environment with walls that robots can collide with""" + walls: jnp.ndarray # [N, 2, 2] array of wall segments @jax.jit - def ray_distance(self, ray_start: jnp.ndarray, ray_dir: jnp.ndarray, max_dist: float) -> float: + def ray_distance( + self, ray_start: jnp.ndarray, ray_dir: jnp.ndarray, max_dist: float + ) -> float: """Find distance to nearest wall along a ray""" if self.walls.shape[0] == 0: # No walls return max_dist - + # Vectorized computation for all walls at once p1 = self.walls[:, 0] # Shape: (N, 2) p2 = self.walls[:, 1] # Shape: (N, 2) - + # Wall direction vectors wall_vec = p2 - p1 # Shape: (N, 2) - + # Vector from wall start to ray start to_start = ray_start - p1 # Shape: (N, 2) - + # Compute determinant (cross product in 2D) det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) - + # Compute intersection parameters - u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / (det + 1e-10) - t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / (det + 1e-10) - + u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / ( + det + 1e-10 + ) + t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / ( + det + 1e-10 + ) + # Valid intersections: not parallel, in front of ray, within wall segment is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) - + # Find minimum valid distance min_dist = jnp.min(jnp.where(is_valid, t * jnp.linalg.norm(ray_dir), jnp.inf)) return jnp.where(jnp.isinf(min_dist), max_dist, min_dist) - @jax.jit - def check_movement(self, start_pos: jnp.ndarray, end_pos: jnp.ndarray, - collision_radius: float = WALL_COLLISION_THRESHOLD) -> Tuple[bool, jnp.ndarray]: + @jax.jit + def check_movement( + self, + start_pos: jnp.ndarray, + end_pos: jnp.ndarray, + collision_radius: float = WALL_COLLISION_THRESHOLD, + ) -> Tuple[bool, jnp.ndarray]: """Check if movement between two points collides with walls - + Args: start_pos: [x, y] starting position end_pos: [x, y] intended end 
position collision_radius: How close we can get to walls - + Returns: - (can_move, safe_pos) where safe_pos is either end_pos or the + (can_move, safe_pos) where safe_pos is either end_pos or the furthest safe position along the movement line """ movement_dir = end_pos - start_pos dist = jnp.linalg.norm(movement_dir) - + # Replace if with where ray_dir = jnp.where( dist > 1e-6, movement_dir / dist, - jnp.array([1.0, 0.0]) # Default direction if no movement + jnp.array([1.0, 0.0]), # Default direction if no movement ) - + wall_dist = self.ray_distance(start_pos, ray_dir, dist) - + # Stop short of wall by collision_radius safe_dist = jnp.maximum(0.0, wall_dist - collision_radius) safe_pos = start_pos + ray_dir * safe_dist - + # Use where to select between start_pos and safe_pos - final_pos = jnp.where( - dist > 1e-6, - safe_pos, - start_pos - ) - + final_pos = jnp.where(dist > 1e-6, safe_pos, start_pos) + return wall_dist > dist - collision_radius, final_pos - + + @pz.pytree_dataclass class RobotCapabilities(genjax.PythonicPytree): """Physical capabilities and limitations of the robot""" - p_noise: float # Position noise (std dev in meters) - hd_noise: float # Heading noise (std dev in radians) - sensor_noise: float # Sensor noise (std dev in meters) - n_sensors: int = 8 # Number of distance sensors + + p_noise: float # Position noise (std dev in meters) + hd_noise: float # Heading noise (std dev in radians) + sensor_noise: float # Sensor noise (std dev in meters) + n_sensors: int = 8 # Number of distance sensors sensor_range: float = 10.0 # Maximum sensor range in meters - def try_move(self, world: World, current_pos: jnp.ndarray, - desired_pos: jnp.ndarray, key: PRNGKey) -> jnp.ndarray: + def try_move( + self, + world: World, + current_pos: jnp.ndarray, + desired_pos: jnp.ndarray, + key: PRNGKey, + ) -> jnp.ndarray: """Try to move to desired_pos, respecting walls and adding noise""" # Add motion noise noise = jax.random.normal(key, shape=(2,)) * self.p_noise noisy_target = desired_pos + noise - + # Check for collisions _, safe_pos = world.check_movement(current_pos, noisy_target) return safe_pos + def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: """Convert a series of points into (distance, angle) control pairs""" points = jnp.array([p[:2] for p in path_points]) @@ -194,42 +212,47 @@ def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: angle_changes = jnp.diff(angles, prepend=0.0) return jnp.stack([distances, angle_changes], axis=1) + @jax.jit -def get_sensor_readings(world: World, robot: RobotCapabilities, - pose: Pose, key: PRNGKey) -> Tuple[jnp.ndarray, PRNGKey]: +def get_sensor_readings( + world: World, robot: RobotCapabilities, pose: Pose, key: PRNGKey +) -> Tuple[jnp.ndarray, PRNGKey]: """Return noisy distance readings to walls from given pose""" MAX_SENSORS = 32 # Fixed maximum key, subkey = jax.random.split(key) - + # Calculate angles based on n_sensors, but generate MAX_SENSORS of them angle_step = 2 * jnp.pi / robot.n_sensors angles = jnp.arange(MAX_SENSORS) * angle_step noise = jax.random.normal(subkey, (MAX_SENSORS,)) * robot.sensor_noise - - readings = jax.vmap(lambda a: world.ray_distance( - ray_start=pose.p, - ray_dir=jnp.array([ - jnp.cos(pose.hd + a), - jnp.sin(pose.hd + a) - ]), - max_dist=robot.sensor_range - ))(angles) - + + readings = jax.vmap( + lambda a: world.ray_distance( + ray_start=pose.p, + ray_dir=jnp.array([jnp.cos(pose.hd + a), jnp.sin(pose.hd + a)]), + max_dist=robot.sensor_range, + ) + )(angles) + # Create a mask for the 
first n_sensors elements mask = jnp.arange(MAX_SENSORS) < robot.n_sensors - + # Apply mask and pad with zeros readings = (readings + noise) * mask - + return readings, key @jax.jit -def execute_control(world: World, robot: RobotCapabilities, - current_pose: Pose, control: Tuple[float, float], - key: PRNGKey) -> Tuple[Pose, jnp.ndarray, PRNGKey]: +def execute_control( + world: World, + robot: RobotCapabilities, + current_pose: Pose, + control: Tuple[float, float], + key: PRNGKey, +) -> Tuple[Pose, jnp.ndarray, PRNGKey]: """Execute a control command with noise, stopping if we hit a wall - + Args: control: (distance, angle) pair where: - angle is how much to turn FIRST @@ -237,118 +260,116 @@ def execute_control(world: World, robot: RobotCapabilities, """ dist, angle = control k1, k2, k3 = jax.random.split(key, 3) - + # Add noise to motion noisy_dist = dist + jax.random.normal(k1) * robot.p_noise noisy_angle = angle + jax.random.normal(k2) * robot.hd_noise - + # First rotate (can always rotate) new_pose = current_pose.rotate(noisy_angle) - + # Check distance to wall in our current heading direction min_dist = world.ray_distance( - ray_start=new_pose.p, - ray_dir=new_pose.dp(), - max_dist=robot.sensor_range + ray_start=new_pose.p, ray_dir=new_pose.dp(), max_dist=robot.sensor_range ) - + # Only move as far as we can before hitting a wall safe_dist = jnp.minimum(noisy_dist, min_dist - WALL_COLLISION_THRESHOLD) safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards - + new_pose = new_pose.step_along(safe_dist) - + # Get sensor readings from new position readings, k4 = get_sensor_readings(world, robot, new_pose, k3) - + return new_pose, readings, k4 + @jax.jit -def simulate_robot_path(world: World, robot: RobotCapabilities, - start_pose: Pose, controls: jnp.ndarray, - key: jnp.ndarray): +def simulate_robot_path( + world: World, + robot: RobotCapabilities, + start_pose: Pose, + controls: jnp.ndarray, + key: jnp.ndarray, +): """Simulate robot path with noise and sensor readings - + Returns: Tuple of: - Array of shape [n_steps, 2] containing positions - Array of shape [n_steps] containing headings - Array of shape [n_steps, n_sensors] containing sensor readings """ + def step_fn(carry, control): pose, k = carry new_pose, readings, new_key = execute_control( - world=world, - robot=robot, - current_pose=pose, - control=control, - key=k + world=world, robot=robot, current_pose=pose, control=control, key=k ) return (new_pose, new_key), (new_pose, readings) - + (_, _), (poses, readings) = jax.lax.scan(step_fn, (start_pose, key), controls) - + # Extract positions and headings - positions = jnp.concatenate([ - start_pose.p[None, :], - jax.vmap(lambda p: p.p)(poses) - ]) - headings = jnp.concatenate([ - jnp.array([start_pose.hd]), - jax.vmap(lambda p: p.hd)(poses) - ]) - + positions = jnp.concatenate([start_pose.p[None, :], jax.vmap(lambda p: p.p)(poses)]) + headings = jnp.concatenate( + [jnp.array([start_pose.hd]), jax.vmap(lambda p: p.hd)(poses)] + ) + return positions, headings, readings + @partial(jax.jit, static_argnums=(1)) -def sample_possible_paths(key: jnp.ndarray, n_paths: int, - robot_path: jnp.ndarray, world: World, - robot: RobotCapabilities): +def sample_possible_paths( + key: jnp.ndarray, + n_paths: int, + robot_path: jnp.ndarray, + world: World, + robot: RobotCapabilities, +): """Generate n possible paths given the planned path, respecting walls""" path_points = robot_path[:, :2] controls = path_to_controls(path_points) - + start_point = path_points[0] start_pose = 
Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) - + keys = jax.random.split(key, n_paths) - + # Vectorize over different random keys - return jax.vmap(lambda k: simulate_robot_path( - world=world, - robot=robot, - start_pose=start_pose, - controls=controls, - key=k - ))(keys) + return jax.vmap( + lambda k: simulate_robot_path( + world=world, robot=robot, start_pose=start_pose, controls=controls, key=k + ) + )(keys) + def walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: """Convert wall vertices from UI format to JAX array of wall segments""" if not walls_list: return jnp.array([]).reshape((0, 2, 2)) - + points = jnp.array(walls_list, dtype=jnp.float32) p1 = points[:-1] p2 = points[1:] - - segments = jnp.stack([ - p1[:, :2], - p2[:, :2] - ], axis=1) - + + segments = jnp.stack([p1[:, :2], p2[:, :2]], axis=1) + valid_mask = p1[:, 2] == p2[:, 2] return segments * valid_mask[:, None, None] + def simulate_robot_uncertainty(widget, e, seed=None): """Handle updates to robot simulation""" if not widget.state.robot_path: return - + current_seed = jnp.array(seed if seed is not None else widget.state.current_seed) assert jnp.issubdtype(current_seed.dtype, jnp.integer), "Seed must be an integer" - + current_key = PRNGKey(current_seed) - + # Create world and robot objects world = World(walls_to_jax(widget.state.walls)) robot = RobotCapabilities( @@ -356,41 +377,43 @@ def simulate_robot_uncertainty(widget, e, seed=None): hd_noise=widget.state.motion_noise * widget.state.heading_noise_scale, sensor_noise=widget.state.sensor_noise, n_sensors=widget.state.n_sensors, - sensor_range=10.0 + sensor_range=10.0, ) - + path = jnp.array(widget.state.robot_path, dtype=jnp.float32) - + # Sample all paths at once (1 true path + N possible paths) n_possible = 40 all_paths, all_headings, all_readings = sample_possible_paths( current_key, n_possible + 1, path, world, robot ) - + # First path is the "true" path true_path = all_paths[0] final_readings = all_readings[0, -1] final_heading = all_headings[0, -1] - + # Remaining paths are possible paths possible_paths = all_paths[1:] - - widget.state.update({ - "robot_pose": { - "x": float(true_path[-1, 0]), - "y": float(true_path[-1, 1]), - "heading": float(final_heading) - }, - "possible_paths": possible_paths, - "sensor_readings": final_readings, - "true_path": [[float(x), float(y)] for x, y in true_path], - "show_debug": True, - "current_seed": current_seed - }) + + widget.state.update( + { + "robot_pose": { + "x": float(true_path[-1, 0]), + "y": float(true_path[-1, 1]), + "heading": float(final_heading), + }, + "possible_paths": possible_paths, + "sensor_readings": final_readings, + "true_path": [[float(x), float(y)] for x, y in true_path], + "show_debug": True, + "current_seed": current_seed, + } + ) drawing_system_handler = Plot.js("""({points, simplify}) => { - mode = $state.selected_tool + mode = $state.selected_tool if (mode === 'walls') { $state.update(['walls', 'concat', simplify(0.25)]) } @@ -399,36 +422,17 @@ def simulate_robot_uncertainty(widget, e, seed=None): } }""") -sliders = ( - Plot.Slider( - "motion_noise", - range=[0, 0.5], - step=0.01, - label="Motion Noise:", - showValue=True - ) - & Plot.Slider( - "heading_noise_scale", - range=[0, 1], - step=0.05, - label="Heading Noise Scale:", - showValue=True - ) - | Plot.Slider( - "sensor_noise", - range=[0, 1], - step=0.02, - label="Sensor Noise:", - showValue=True - ) - & Plot.Slider( - "n_sensors", - range=[4, 32], - step=1, - label="Sensors:", - showValue=True - ) - ) +sliders = Plot.Slider( + 
"motion_noise", range=[0, 0.5], step=0.01, label="Motion Noise:", showValue=True +) & Plot.Slider( + "heading_noise_scale", + range=[0, 1], + step=0.05, + label="Heading Noise Scale:", + showValue=True, +) | Plot.Slider( + "sensor_noise", range=[0, 1], step=0.02, label="Sensor Noise:", showValue=True +) & Plot.Slider("n_sensors", range=[4, 32], step=1, label="Sensors:", showValue=True) def create_initial_state(seed) -> Dict[str, Any]: @@ -436,10 +440,14 @@ def create_initial_state(seed) -> Dict[str, Any]: return { "walls": [ # Frame around domain (timestamp 0) - [0, 0, 0], [10, 0, 0], # Bottom - [10, 0, 0], [10, 10, 0], # Right - [10, 10, 0], [0, 10, 0], # Top - [0, 10, 0], [0, 0, 0], # Left + [0, 0, 0], + [10, 0, 0], # Bottom + [10, 0, 0], + [10, 10, 0], # Right + [10, 10, 0], + [0, 10, 0], # Top + [0, 10, 0], + [0, 0, 0], # Left ], "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, "sensor_noise": 0.1, @@ -454,22 +462,29 @@ def create_initial_state(seed) -> Dict[str, Any]: "sensor_readings": [], "sensor_explore_angle": -1, "show_uncertainty": True, - "show_true_position": False, + "show_true_position": False, "current_line": [], - "current_seed": seed + "current_seed": seed, } - + + true_position_toggle = Plot.html( - ["label.flex.items-center.gap-2.p-2.bg-gray-100.rounded.hover:bg-gray-300", - ["input", { - "type": "checkbox", - "checked": js("$state.show_true_position"), - "onChange": js("(e) => $state.show_true_position = e.target.checked") - }], "Show true position"] + [ + "label.flex.items-center.gap-2.p-2.bg-gray-100.rounded.hover:bg-gray-300", + [ + "input", + { + "type": "checkbox", + "checked": js("$state.show_true_position"), + "onChange": js("(e) => $state.show_true_position = e.target.checked"), + }, + ], + "Show true position", + ] ) sensor_rays = Plot.line( - js(""" + js(""" Array.from($state.sensor_readings).map((r, i) => { const heading = $state.robot_pose.heading || 0; const n_sensors = $state.n_sensors; @@ -478,17 +493,16 @@ def create_initial_state(seed) -> Dict[str, Any]: const y = $state.robot_pose.y; return [ [x, y, i], - [x + r * Math.cos(angle), + [x + r * Math.cos(angle), y + r * Math.sin(angle), i] ] }).flat() """), - z="2", - stroke="red", - strokeWidth=1, - marker="circle" - ) - + z="2", + stroke="red", + strokeWidth=1, + marker="circle", +) rotating_sensor_rays = ( @@ -504,12 +518,12 @@ def create_initial_state(seed) -> Dict[str, Any]: } else { angle += $state.current_seed || Math.random() * 2 * Math.PI; } - } + } const x = $state.robot_pose.x; const y = $state.robot_pose.y; return [ [0, 0, i], - [r * Math.cos(angle), + [r * Math.cos(angle), r * Math.sin(angle), i] ] }).flat() @@ -517,21 +531,24 @@ def create_initial_state(seed) -> Dict[str, Any]: z="2", stroke="red", strokeWidth=1, - marker="circle" + marker="circle", ) # move the mouse around the plot to rotate the sensors - + Plot.events({"onMouseMove": - Plot.js("""(e) => { + + Plot.events( + { + "onMouseMove": Plot.js("""(e) => { // Convert mouse position to angle from center // atan2 gives angle in radians from -pi to pi // Subtract pi/2 to make 12 o'clock 0 radians const angle = Math.atan2(e.y, e.x) - Math.PI/2; - + // Normalize to 0 to 2pi range const normalized = (angle + 2*Math.PI) % (2*Math.PI); - + $state.sensor_explore_angle = normalized; - }""")}) + }""") + } + ) + {"height": 200, "width": 200, "className": "bg-gray-100"} + Plot.aspectRatio(1) + Plot.domain([-10, 10]) @@ -541,78 +558,97 @@ def create_initial_state(seed) -> Dict[str, Any]: ) true_path = Plot.cond( - 
js("$state.show_true_position"), - [Plot.text( - js("[[$state.robot_pose.x, $state.robot_pose.y]]"), - text=Plot.constantly(emoji.robot), - fontSize=30, - textAnchor="middle", - dy="-0.35em", - rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI")), - Plot.line( - js("$state.true_path"), - stroke=Plot.constantly("True Path"), - strokeWidth=2 - ), - sensor_rays - ] - ) + js("$state.show_true_position"), + [ + Plot.text( + js("[[$state.robot_pose.x, $state.robot_pose.y]]"), + text=Plot.constantly(emoji.robot), + fontSize=30, + textAnchor="middle", + dy="-0.35em", + rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI"), + ), + Plot.line( + js("$state.true_path"), stroke=Plot.constantly("True Path"), strokeWidth=2 + ), + sensor_rays, + ], +) planned_path = Plot.line( - js("$state.robot_path"), - stroke=Plot.constantly("Robot Path"), - strokeWidth=2, - r=3, - marker="circle" - ) + js("$state.robot_path"), + stroke=Plot.constantly("Robot Path"), + strokeWidth=2, + r=3, + marker="circle", +) walls = Plot.line( - js("$state.walls"), - stroke=Plot.constantly("Walls"), - strokeWidth=WALL_WIDTH, - z="2", - render=Plot.renderChildEvents({"onClick": js("""(e) => { + js("$state.walls"), + stroke=Plot.constantly("Walls"), + strokeWidth=WALL_WIDTH, + z="2", + render=Plot.renderChildEvents( + { + "onClick": js("""(e) => { const zs = new Set($state.walls.map(w => w[2])); const targetZ = [...zs][e.index]; $state.walls = $state.walls.filter(([x, y, z]) => z !== targetZ) - }""")}) - ) + }""") + } + ), +) possible_paths = Plot.line( - js(""" + js( + """ if (!$state.show_debug || !$state.possible_paths) {return [];}; - return $state.possible_paths.flatMap((path, pathIdx) => + return $state.possible_paths.flatMap((path, pathIdx) => path.map(([x, y]) => [x, y, pathIdx]) ) - """, expression=False), - stroke="blue", - strokeOpacity=0.2, - z="2" - ) + """, + expression=False, + ), + stroke="blue", + strokeOpacity=0.2, + z="2", +) + def clear_state(w, _): """Reset visualization state""" - w.state.update(create_initial_state(w.state.current_seed) | {"selected_tool": w.state.selected_tool}) - - + w.state.update( + create_initial_state(w.state.current_seed) + | {"selected_tool": w.state.selected_tool} + ) + + selectable_button = "button.px-3.py-1.rounded.bg-gray-100.hover:bg-gray-300.data-[selected=true]:bg-gray-300" - -toolbar = Plot.html("Select tool:") | ["div.flex.gap-2", - [selectable_button, { - "data-selected": js("$state.selected_tool === 'path'"), - "onClick": js("() => $state.selected_tool = 'path'") - }, f"{emoji.robot} Path"], - [selectable_button, { - "data-selected": js("$state.selected_tool === 'walls'"), - "onClick": js("() => $state.selected_tool = 'walls'") - }, f"{emoji.pencil} Walls"], - [selectable_button, { - "onClick": clear_state - }, "Clear"] + +toolbar = Plot.html("Select tool:") | [ + "div.flex.gap-2", + [ + selectable_button, + { + "data-selected": js("$state.selected_tool === 'path'"), + "onClick": js("() => $state.selected_tool = 'path'"), + }, + f"{emoji.robot} Path", + ], + [ + selectable_button, + { + "data-selected": js("$state.selected_tool === 'walls'"), + "onClick": js("() => $state.selected_tool = 'walls'"), + }, + f"{emoji.pencil} Walls", + ], + [selectable_button, {"onClick": clear_state}, "Clear"], ] - + + def handleSeedIndex(w, e): - global key + global key try: if e.index == 0: seed = key[0] @@ -620,44 +656,48 @@ def handleSeedIndex(w, e): key = split(key, 2)[0] seed = key[0] else: - seed = split(key, e.index)[e.index-1][0] + seed = split(key, 
e.index)[e.index - 1][0] simulate_robot_uncertainty(w, e, seed=seed) except Exception as err: print(f"Error handling seed index: {err}, {e.key}, {e.index}") + key_scrubber = v.key_scrubber(handleSeedIndex) canvas = ( - v.drawing_system("current_line", drawing_system_handler) - + walls - + planned_path - + true_path - + possible_paths - + Plot.domain([0, 10], [0, 10]) - + Plot.grid() - + Plot.aspectRatio(1) - + Plot.colorMap({ + v.drawing_system("current_line", drawing_system_handler) + + walls + + planned_path + + true_path + + possible_paths + + Plot.domain([0, 10], [0, 10]) + + Plot.grid() + + Plot.aspectRatio(1) + + Plot.colorMap( + { "Walls": "#666", "Sensor Rays": "red", "True Path": "green", "Robot Path": "blue", - }) - + Plot.colorLegend() - + Plot.clip() + } ) + + Plot.colorLegend() + + Plot.clip() +) ( - canvas & - (sliders | toolbar | true_position_toggle | key_scrubber | rotating_sensor_rays) + canvas + & (sliders | toolbar | true_position_toggle | key_scrubber | rotating_sensor_rays) & {"widths": ["400px", 1]} | Plot.initialState(create_initial_state(0), sync=True) - | Plot.onChange({ - "robot_path": simulate_robot_uncertainty, - "sensor_noise": simulate_robot_uncertainty, - "motion_noise": simulate_robot_uncertainty, - "heading_noise_scale": simulate_robot_uncertainty, - "n_sensors": simulate_robot_uncertainty, - "walls": simulate_robot_uncertainty - }) + | Plot.onChange( + { + "robot_path": simulate_robot_uncertainty, + "sensor_noise": simulate_robot_uncertainty, + "motion_noise": simulate_robot_uncertainty, + "heading_noise_scale": simulate_robot_uncertainty, + "n_sensors": simulate_robot_uncertainty, + "walls": simulate_robot_uncertainty, + } + ) ) - From 3457e65406838198bb9fb199370387069113363f Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 2 Dec 2024 17:52:58 +0100 Subject: [PATCH 80/86] revert changes --- .../probcomp-localization-tutorial.py | 8 ++------ robot_2/emoji.py | 2 ++ robot_2/visualization.py | 2 -- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 668ff35..3ac2e13 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -138,11 +138,7 @@ def create_segments(points): return jnp.stack([points, jnp.roll(points, shift=-1, axis=0)], axis=1) -def make_world(wall_verts, clutters_vec, start, controls) -> tuple[ - dict[str, FloatArray | tuple[float, float, float, float] | float | Pose], - dict[str, Control | Pose], - int -]: +def make_world(wall_verts, clutters_vec, start, controls): """ Constructs the world by creating segments for walls and clutters, calculates the bounding box, and prepares the simulation parameters. @@ -1483,7 +1479,7 @@ def resample( ) winners = jax.vmap(genjax.categorical.sampler)( jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) - ) # indices returned are relative to the start of the K-segment from which they were drawn. + # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. 
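# Illustrative aside (hypothetical sizes, chosen only for this example): the
# categorical draw above is vmapped over the (K, N) reshape of `log_weights`,
# so it yields one winner per block of N particles, indexed relative to that
# block. Adding `jnp.arange(0, N * K, N)` shifts each block's winner back to a
# flat index into all N * K samples, e.g.
#   N, K = 3, 2
#   local_winners = jnp.array([2, 0])        # one index in [0, N) per block
#   local_winners + jnp.arange(0, N * K, N)  # -> Array([2, 3]): flat indices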
winners += jnp.arange(0, N * K, N) selected = jax.tree.map(lambda x: x[winners], samples) diff --git a/robot_2/emoji.py b/robot_2/emoji.py index b43969f..7d899cd 100644 --- a/robot_2/emoji.py +++ b/robot_2/emoji.py @@ -1,3 +1,5 @@ +# either Cursor or Claude has a bug where strings containing emoji break responses. +# so I keep them in a separate file. robot = "🤖" pencil = "✏️" recycle = "♻️" diff --git a/robot_2/visualization.py b/robot_2/visualization.py index ce268c4..bcc20fb 100644 --- a/robot_2/visualization.py +++ b/robot_2/visualization.py @@ -1,8 +1,6 @@ from genstudio.plot import js import genstudio.plot as Plot import robot_2.emoji as emoji -from typing import Dict, List, Union, Any -import jax.numpy as jnp def drawing_system(key, on_complete): From c360accbbb6b41a60c084c649d6002092c359b9c Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Wed, 4 Dec 2024 17:52:39 +0100 Subject: [PATCH 81/86] plot order --- robot_2/where_am_i.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index cc64188..10104ba 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -560,6 +560,10 @@ def create_initial_state(seed) -> Dict[str, Any]: true_path = Plot.cond( js("$state.show_true_position"), [ + Plot.line( + js("$state.true_path"), stroke=Plot.constantly("True Path"), strokeWidth=2 + ), + sensor_rays, Plot.text( js("[[$state.robot_pose.x, $state.robot_pose.y]]"), text=Plot.constantly(emoji.robot), @@ -568,10 +572,6 @@ def create_initial_state(seed) -> Dict[str, Any]: dy="-0.35em", rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI"), ), - Plot.line( - js("$state.true_path"), stroke=Plot.constantly("True Path"), strokeWidth=2 - ), - sensor_rays, ], ) @@ -668,8 +668,8 @@ def handleSeedIndex(w, e): v.drawing_system("current_line", drawing_system_handler) + walls + planned_path - + true_path + possible_paths + + true_path + Plot.domain([0, 10], [0, 10]) + Plot.grid() + Plot.aspectRatio(1) @@ -683,6 +683,7 @@ def handleSeedIndex(w, e): ) + Plot.colorLegend() + Plot.clip() + + Plot.gridX(interval=1) ) ( From e2fedadae632fcd41d86ef1e79a75f67e4dbd0be Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Wed, 4 Dec 2024 17:56:27 +0100 Subject: [PATCH 82/86] formatting / ruff --- .../probcomp-localization-tutorial.py | 4 ++-- robot_2/where_am_i.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 3ac2e13..bfd2738 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -1478,7 +1478,7 @@ def resample( jax.random.split(key1, N * K), constraints, (motion_settings,) ) winners = jax.vmap(genjax.categorical.sampler)( - jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) + jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))) # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. 
winners += jnp.arange(0, N * K, N) @@ -1669,7 +1669,7 @@ def grid_sample(gf, pose_grid, observation): def flatten_pose_cube(pose_grid, cube_step_size, scores): n_indices = 2 * cube_step_size + 1 best_heading_indices = jnp.argmax( - scores.reshape(n_indices * n_indices, n_indices), axis=1 + scores.reshape(n_indices * n_indices, n_indices), axis=1) # those were block relative; linearize them by adding back block indices bs = best_heading_indices + jnp.arange(0, n_indices**3, n_indices) return Pose(pose_grid.p[bs], pose_grid.hd[bs]), scores[bs] diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index 10104ba..c1c6160 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -50,7 +50,6 @@ # pyright: reportUnusedExpression=false # pyright: reportUnknownMemberType=false -from dataclasses import dataclass from functools import partial from typing import List, Tuple, Any, Dict @@ -58,7 +57,6 @@ import genstudio.plot as Plot import jax import jax.numpy as jnp -import numpy as np from jax.random import PRNGKey, split from penzai import pz from genstudio.plot import js From c78f8581c1b25ab3d00e8a8437e2970beea903a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:56:45 +0000 Subject: [PATCH 83/86] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../probcomp-localization-tutorial.py | 157 ++++++++++++------ 1 file changed, 102 insertions(+), 55 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index bfd2738..832f6bf 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -787,7 +787,10 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) -def rotate_trace(key: PRNGKey, trace: genjax.Trace[Pose], angle) -> tuple[genjax.Trace[Pose], genjax.Weight]: + +def rotate_trace( + key: PRNGKey, trace: genjax.Trace[Pose], angle +) -> tuple[genjax.Trace[Pose], genjax.Weight]: """Returns a modified trace with the heading set to the given angle (in radians), along with the weight difference.""" key, sub_key = jax.random.split(key) rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( @@ -795,25 +798,37 @@ def rotate_trace(key: PRNGKey, trace: genjax.Trace[Pose], angle) -> tuple[genjax ) return rotated_trace, rotated_trace_weight_diff + rotated_trace, rotated_trace_weight_diff = rotate_trace(key, trace, jnp.pi / 2.0) + def set_angle(widget: Widget, e): angle = float(e["value"]) rotated_trace, rotated_trace_weight_diff = rotate_trace(key, trace, angle) - widget.state.update({ - "rotated_poses": pose_plot(rotated_trace.get_retval(), - fill=Plot.constantly("with heading modified")), - "angle": angle, - "rotated_trace_weight_diff": rotated_trace_weight_diff}) + widget.state.update( + { + "rotated_poses": pose_plot( + rotated_trace.get_retval(), + fill=Plot.constantly("with heading modified"), + ), + "angle": angle, + "rotated_trace_weight_diff": rotated_trace_weight_diff, + } + ) -( - Plot.initialState({ - "poses": pose_plot(trace.get_retval(), fill=Plot.constantly("some pose")), - "rotated_poses": pose_plot(rotated_trace.get_retval(), fill=Plot.constantly("with heading modified")), - "rotated_trace_weight_diff": rotated_trace_weight_diff, - "angle": jnp.pi / 2.0 - }) 
+( + Plot.initialState( + { + "poses": pose_plot(trace.get_retval(), fill=Plot.constantly("some pose")), + "rotated_poses": pose_plot( + rotated_trace.get_retval(), + fill=Plot.constantly("with heading modified"), + ), + "rotated_trace_weight_diff": rotated_trace_weight_diff, + "angle": jnp.pi / 2.0, + } + ) | Plot.new( world_plot + Plot.js("$state.poses") @@ -821,17 +836,26 @@ def set_angle(widget: Widget, e): + Plot.color_map({"some pose": "green", "with heading modified": "red"}) + Plot.title("Modifying a heading") ) - | html(["span.tc", Plot.js("`score ratio: ${$state.rotated_trace_weight_diff.toFixed(2)}`")]) + | html( + [ + "span.tc", + Plot.js("`score ratio: ${$state.rotated_trace_weight_diff.toFixed(2)}`"), + ] + ) | ( Plot.js("`angle: ${$state.angle.toFixed(2)}`") - & ["input", {"type": "range", - "name": "angle", - "defaultValue": Plot.js("$state.angle"), - "min": -jnp.pi / 2, - "max": jnp.pi / 2, - "step": 0.1, - "onChange": set_angle - }] + & [ + "input", + { + "type": "range", + "name": "angle", + "defaultValue": Plot.js("$state.angle"), + "min": -jnp.pi / 2, + "max": jnp.pi / 2, + "step": 0.1, + "onChange": set_angle, + }, + ] & {"widths": ["80px", "200px"]} ) ).widget() @@ -843,7 +867,10 @@ def set_angle(widget: Widget, e): key, sub_key = jax.random.split(key) trace = generate_path_trace(sub_key) -def rotate_first_step(key: PRNGKey, trace: genjax.Trace[Pose], angle) -> tuple[genjax.Trace[Pose], genjax.Weight]: + +def rotate_first_step( + key: PRNGKey, trace: genjax.Trace[Pose], angle +) -> tuple[genjax.Trace[Pose], genjax.Weight]: """Returns a modified trace with the first step's heading set to the given angle (in radians), along with the weight difference.""" key, sub_key = jax.random.split(key) rotated_trace, rotated_trace_weight_diff, _, _ = trace.update( @@ -851,31 +878,39 @@ def rotate_first_step(key: PRNGKey, trace: genjax.Trace[Pose], angle) -> tuple[g ) return rotated_trace, rotated_trace_weight_diff + def set_first_step_angle(widget: Widget, e): angle = float(e["value"]) rotated_trace, rotated_trace_weight_diff = rotate_first_step(key, trace, angle) - widget.state.update({ - "rotated_path": [ - pose_plot(pose, fill=Plot.constantly("with heading modified")) - for pose in path_from_trace(rotated_trace) - ], - "angle": angle, - "rotated_trace_weight_diff": rotated_trace_weight_diff - }) + widget.state.update( + { + "rotated_path": [ + pose_plot(pose, fill=Plot.constantly("with heading modified")) + for pose in path_from_trace(rotated_trace) + ], + "angle": angle, + "rotated_trace_weight_diff": rotated_trace_weight_diff, + } + ) + ( - Plot.initialState({ - "original_path": [ - pose_plot(pose, fill=Plot.constantly("some path")) - for pose in path_from_trace(trace) - ], - "rotated_path": [ - pose_plot(pose, fill=Plot.constantly("with heading modified")) - for pose in path_from_trace(rotate_first_step(key, trace, jnp.pi / 2.0)[0]) - ], - "rotated_trace_weight_diff": rotate_first_step(key, trace, jnp.pi / 2.0)[1], - "angle": jnp.pi / 2.0 - }) + Plot.initialState( + { + "original_path": [ + pose_plot(pose, fill=Plot.constantly("some path")) + for pose in path_from_trace(trace) + ], + "rotated_path": [ + pose_plot(pose, fill=Plot.constantly("with heading modified")) + for pose in path_from_trace( + rotate_first_step(key, trace, jnp.pi / 2.0)[0] + ) + ], + "rotated_trace_weight_diff": rotate_first_step(key, trace, jnp.pi / 2.0)[1], + "angle": jnp.pi / 2.0, + } + ) | Plot.new( world_plot + Plot.js("$state.rotated_path") @@ -883,17 +918,26 @@ def set_first_step_angle(widget: 
Widget, e): + Plot.color_map({"some path": "green", "with heading modified": "red"}) + Plot.title("Modifying first step heading") ) - | html(["span.tc", Plot.js("`score ratio: ${$state.rotated_trace_weight_diff.toFixed(2)}`")]) + | html( + [ + "span.tc", + Plot.js("`score ratio: ${$state.rotated_trace_weight_diff.toFixed(2)}`"), + ] + ) | ( Plot.js("`angle: ${$state.angle.toFixed(2)}`") - & ["input", {"type": "range", - "name": "angle", - "defaultValue": Plot.js("$state.angle"), - "min": -jnp.pi / 2, - "max": jnp.pi / 2, - "step": 0.1, - "onChange": set_first_step_angle - }] + & [ + "input", + { + "type": "range", + "name": "angle", + "defaultValue": Plot.js("$state.angle"), + "min": -jnp.pi / 2, + "max": jnp.pi / 2, + "step": 0.1, + "onChange": set_first_step_angle, + }, + ] & {"widths": ["80px", "200px"]} ) ).widget() @@ -1072,7 +1116,8 @@ def full_model(motion_settings): return ( full_model_kernel.partial_apply(motion_settings).scan()( robot_inputs["start"], robot_inputs["controls"] - ) @ "steps" + ) + @ "steps" ) @@ -1478,7 +1523,8 @@ def resample( jax.random.split(key1, N * K), constraints, (motion_settings,) ) winners = jax.vmap(genjax.categorical.sampler)( - jax.random.split(key2, K), jnp.reshape(log_weights, (K, N))) + jax.random.split(key2, K), jnp.reshape(log_weights, (K, N)) + ) # indices returned are relative to the start of the K-segment from which they were drawn. # globalize the indices by adding back the index of the start of each segment. winners += jnp.arange(0, N * K, N) @@ -1669,7 +1715,8 @@ def grid_sample(gf, pose_grid, observation): def flatten_pose_cube(pose_grid, cube_step_size, scores): n_indices = 2 * cube_step_size + 1 best_heading_indices = jnp.argmax( - scores.reshape(n_indices * n_indices, n_indices), axis=1) + scores.reshape(n_indices * n_indices, n_indices), axis=1 + ) # those were block relative; linearize them by adding back block indices bs = best_heading_indices + jnp.arange(0, n_indices**3, n_indices) return Pose(pose_grid.p[bs], pose_grid.hd[bs]), scores[bs] From 5257b313cea762ef50a3a904c981a286693abee8 Mon Sep 17 00:00:00 2001 From: Matt Huebert Date: Mon, 27 Jan 2025 06:17:25 +0100 Subject: [PATCH 84/86] robot viz in genjax (GEN-884) (#20) --- poetry.lock | 35 +- pyproject.toml | 5 +- robot_2/bench.py | 290 ++++++++++++++++ robot_2/test_where_am_i.py | 88 +++-- robot_2/visualization.py | 24 +- robot_2/where_am_i.py | 666 ++++++++++++++++++++++--------------- 6 files changed, 775 insertions(+), 333 deletions(-) create mode 100644 robot_2/bench.py diff --git a/poetry.lock b/poetry.lock index 35d94d9..81e73d6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -688,13 +688,13 @@ files = [ [[package]] name = "genjax" -version = "0.7.0.post4.dev0+eacb241e" +version = "0.8.0" description = "Probabilistic programming with Gen, built on top of JAX." 
optional = false python-versions = ">=3.10,<3.13" files = [ - {file = "genjax-0.7.0.post4.dev0+eacb241e-py3-none-any.whl", hash = "sha256:c6374155c6b772e65919115613264e372606844438e75a6ef1d3db0350d5c79f"}, - {file = "genjax-0.7.0.post4.dev0+eacb241e.tar.gz", hash = "sha256:d738d029a7a5a40390236ab35b8d1e8745c7ba02cc5fc1d2723382c1f8b0cb01"}, + {file = "genjax-0.8.0-py3-none-any.whl", hash = "sha256:ac7d38e24036afdeba6bfd200180e939a950d0e78c91133604b02cda05c4e968"}, + {file = "genjax-0.8.0.tar.gz", hash = "sha256:f207dbe3021750c445cec2cf14d79da69f0b234e10385a56e8f2f06247331931"}, ] [package.dependencies] @@ -708,8 +708,8 @@ tensorflow-probability = ">=0.23.0,<0.24.0" treescope = ">=0.1.5,<0.2.0" [package.extras] -all = ["genstudio (==2024.09.003)"] -genstudio = ["genstudio (==2024.09.003)"] +all = ["genstudio (==2024.11.015)"] +genstudio = ["genstudio (==2024.11.015)"] [package.source] type = "legacy" @@ -718,23 +718,26 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.11.021" +version = "2024.12.4" description = "" optional = false python-versions = ">=3.10,<3.13" -files = [] -develop = true +files = [ + {file = "genstudio-2024.12.4-py3-none-any.whl", hash = "sha256:280154e6facb55a73b66b8b6229c847c8c57377faf3022ff73f564d252e52a8f"}, + {file = "genstudio-2024.12.4.tar.gz", hash = "sha256:f3c975e61c068e7c2a6a75be018018df16e4bb38a693eadeddcfff86dcff2da7"}, +] [package.dependencies] -anywidget = "^0.9.10" -html2image = "^2.0.4.3" -orjson = "^3.10.6" -pillow = "^10.4.0" -traitlets = "^5.14.3" +anywidget = ">=0.9.10,<0.10.0" +html2image = ">=2.0.4.3,<3.0.0.0" +orjson = ">=3.10.6,<4.0.0" +pillow = ">=10.4.0,<11.0.0" +traitlets = ">=5.14.3,<6.0.0" [package.source] -type = "directory" -url = "../genstudio" +type = "legacy" +url = "https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple" +reference = "gcp" [[package]] name = "html2image" @@ -2713,4 +2716,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "3d9a2ab9db39381cf0b6982d7546fd98231b1d83e37bac488f6d0dbc879aa7e8" +content-hash = "48a726e720b1311bc7f5b1738c8b8576f08d70d07f93c24cd2cae92a2cc85e9f" diff --git a/pyproject.toml b/pyproject.toml index 103f425..930f6c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,8 @@ packages = [ [tool.poetry.dependencies] python = ">=3.11,<3.13" jupytext = "^1.16.1" -genjax = {version = "0.7.0.post4.dev0+eacb241e", source = "gcp" } -# genstudio = {version = "2024.12.003", source = "gcp"} -genstudio = {path = "../genstudio", develop = true} +genjax = {version = "0.8.0", source = "gcp" } +genstudio = {version = "2024.12.004", source = "gcp"} ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7" diff --git a/robot_2/bench.py b/robot_2/bench.py new file mode 100644 index 0000000..f487e7d --- /dev/null +++ b/robot_2/bench.py @@ -0,0 +1,290 @@ +# %% +import jax.numpy as jnp +import time +from robot_2.where_am_i import ( + RobotCapabilities, + simulate_robot, + World, + walls_to_jax, + PRNGKey, +) +import jax.random + +# Benchmark State: +# Walls: +walls = [ + [0, 0, 0], + [10, 0, 0], + [10, 0, 0], + [10, 10, 0], + [10, 10, 0], + [0, 10, 0], + [0, 10, 0], + [0, 0, 0], + [8.125, 9.03125, 1974574552], + [7.875, 9.0625, 1974574552], + [7.53125, 9.0625, 1974574552], + [7.09375, 9.09375, 1974574552], + [6.8125, 9.09375, 1974574552], + [6.5, 9.125, 1974574552], + [6.1875, 9.125, 1974574552], + [5.71875, 9.15625, 1974574552], + [5.40625, 9.15625, 1974574552], + [5.09375, 9.15625, 1974574552], + [4.75, 9.15625, 1974574552], + 
[4.4375, 9.125, 1974574552], + [4, 9.09375, 1974574552], + [3.71875, 9.0625, 1974574552], + [3.46875, 9.03125, 1974574552], + [3.21875, 9, 1974574552], + [2.96875, 8.9375, 1974574552], + [2.5625, 8.84375, 1974574552], + [2.09375, 8.75, 1974574552], + [1.71875, 8.625, 1974574552], + [1.75, 8.34375, 1974574552], + [2.03125, 8.25, 1974574552], + [2.34375, 8.1875, 1974574552], + [2.71875, 8.15625, 1974574552], + [3.0625, 8.125, 1974574552], + [3.40625, 8.09375, 1974574552], + [3.96875, 8.0625, 1974574552], + [4.3125, 8.03125, 1974574552], + [4.625, 8.03125, 1974574552], + [4.90625, 8, 1974574552], + [5.3125, 8, 1974574552], + [5.59375, 8, 1974574552], + [5.875, 7.9375, 1974574552], + [6.15625, 7.875, 1974574552], + [6.4375, 7.75, 1974574552], + [6.6875, 7.59375, 1974574552], + [6.8125, 7.375, 1974574552], + [6.6875, 7.15625, 1974574552], + [6.40625, 7, 1974574552], + [5.78125, 6.875, 1974574552], + [5.4375, 6.8125, 1974574552], + [5.09375, 6.75, 1974574552], + [4.34375, 6.625, 1974574552], + [3.90625, 6.5625, 1974574552], + [3.65625, 6.5, 1974574552], + [3.1875, 6.4375, 1974574552], + [2.625, 6.28125, 1974574552], + [2.65625, 5.96875, 1974574552], + [2.9375, 5.84375, 1974574552], + [3.25, 5.78125, 1974574552], + [3.53125, 5.75, 1974574552], + [4.03125, 5.71875, 1974574552], + [4.46875, 5.6875, 1974574552], + [5.21875, 5.6875, 1974574552], + [5.75, 5.6875, 1974574552], + [6.0625, 5.6875, 1974574552], + [6.3125, 5.6875, 1974574552], + [6.5625, 5.65625, 1974574552], + [6.59375, 5.3125, 1974574552], + [6.375, 5.1875, 1974574552], + [6, 5.0625, 1974574552], + [5.3125, 4.9375, 1974574552], + [4.9375, 4.875, 1974574552], + [4.1875, 4.8125, 1974574552], + [3.875, 4.75, 1974574552], + [3.46875, 4.71875, 1974574552], + [3.0625, 4.65625, 1974574552], + [2.75, 4.59375, 1974574552], + [2.46875, 4.5, 1974574552], + [2.78125, 4.1875, 1974574552], + [3.15625, 4.125, 1974574552], + [3.53125, 4.0625, 1974574552], + [4, 4, 1974574552], + [4.46875, 3.9375000000000004, 1974574552], + [5.25, 3.8749999999999996, 1974574552], + [5.75, 3.84375, 1974574552], + [6, 3.84375, 1974574552], + [6.28125, 3.84375, 1974574552], + [6.75, 3.84375, 1974574552], + [6.5, 3.53125, 1974574552], + [6.25, 3.5, 1974574552], + [5.875, 3.4375, 1974574552], + [5.59375, 3.4062499999999996, 1974574552], + [5.0625, 3.375, 1974574552], + [4.625, 3.375, 1974574552], + [3.9375, 3.34375, 1974574552], + [2.875, 3.3125000000000004, 1974574552], + [1.75, 3.1562500000000004, 1974574552], + [1.5, 3.0625, 1974574552], + [1.53125, 2.7812499999999996, 1974574552], + [1.8125, 2.65625, 1974574552], + [2.0625, 2.6249999999999996, 1974574552], + [2.625, 2.5, 1974574552], + [2.875, 2.4687499999999996, 1974574552], + [3.375, 2.40625, 1974574552], + [3.9375, 2.3750000000000004, 1974574552], + [4.8125, 2.34375, 1974574552], + [5.34375, 2.34375, 1974574552], + [5.875, 2.34375, 1974574552], + [6.34375, 2.34375, 1974574552], + [6.59375, 2.34375, 1974574552], + [7.15625, 2.34375, 1974574552], + [7.4375, 2.28125, 1974574552], + [7.1875, 2.03125, 1974574552], + [6.75, 1.9374999999999998, 1974574552], + [6.4375, 1.875, 1974574552], + [5.59375, 1.7812499999999998, 1974574552], + [5.125, 1.7812499999999998, 1974574552], + [4.625, 1.7812499999999998, 1974574552], + [3.96875, 1.71875, 1974574552], + [3.4375, 1.6562500000000002, 1974574552], + [3.125, 1.5937500000000004, 1974574552], + [2.71875, 1.4687499999999998, 1974574552], + [2.40625, 1.3437500000000002, 1974574552], + [2.0625, 1.2187499999999996, 1974574552], + [2.5, 1.0312500000000002, 1974574552], + [2.84375, 
1.0312500000000002, 1974574552], + [3.375, 0.9999999999999998, 1974574552], + [3.96875, 0.9687500000000004, 1974574552], + [4.6875, 0.9375, 1974574552], + [5.21875, 0.9375, 1974574552], + [5.84375, 0.9687500000000004, 1974574552], + [6.4375, 1.0312500000000002, 1974574552], + [7, 1.1562499999999998, 1974574552], + [7.34375, 1.2187499999999996, 1974574552], + [7.40625, 1.25, 1974574552], +] + +# Robot Path: +robot_path = [ + [2.5, 0.2500000000000002, 1097479840], + [3.03125, 0.2500000000000002, 1097479840], + [3.625, 0.2500000000000002, 1097479840], + [4.125, 0.2500000000000002, 1097479840], + [4.625, 0.2500000000000002, 1097479840], + [5.375, 0.21874999999999978, 1097479840], + [5.90625, 0.18750000000000044, 1097479840], + [6.4375, 0.18750000000000044, 1097479840], + [6.96875, 0.18750000000000044, 1097479840], + [7.46875, 0.28124999999999956, 1097479840], + [7.9375, 0.46875, 1097479840], + [8, 0.9999999999999998, 1097479840], + [7.65625, 1.40625, 1097479840], + [7.15625, 1.5625, 1097479840], + [6.5625, 1.5625, 1097479840], + [6, 1.5312499999999996, 1097479840], + [5.5, 1.5000000000000002, 1097479840], + [4.9375, 1.40625, 1097479840], + [4.40625, 1.3749999999999996, 1097479840], + [4.9375, 1.3437500000000002, 1097479840], + [5.46875, 1.4375000000000004, 1097479840], + [6.0625, 1.4375000000000004, 1097479840], + [6.625, 1.4687499999999998, 1097479840], + [7.1875, 1.5000000000000002, 1097479840], + [7.65625, 1.71875, 1097479840], + [7.8125, 2.2187500000000004, 1097479840], + [7.71875, 2.71875, 1097479840], + [7.25, 2.90625, 1097479840], + [6.6875, 3.03125, 1097479840], + [6.03125, 3.125, 1097479840], + [5.4375, 3.125, 1097479840], + [4.875, 3.0625, 1097479840], + [4.28125, 3.03125, 1097479840], + [3.71875, 3.0000000000000004, 1097479840], + [4.34375, 2.96875, 1097479840], + [4.96875, 2.96875, 1097479840], + [5.65625, 3.0000000000000004, 1097479840], + [6.25, 3.0000000000000004, 1097479840], + [7.09375, 3.0000000000000004, 1097479840], + [7.75, 3.03125, 1097479840], + [8.25, 3.125, 1097479840], + [8.1875, 3.6875, 1097479840], + [7.59375, 4, 1097479840], + [6.96875, 4.21875, 1097479840], + [6.375, 4.375, 1097479840], + [5.875, 4.4375, 1097479840], + [5.375, 4.5, 1097479840], + [6, 4.5625, 1097479840], + [6.53125, 4.6875, 1097479840], + [7.0625, 4.78125, 1097479840], + [7.4375, 5.1875, 1097479840], + [7.1875, 5.6875, 1097479840], + [6.75, 6.03125, 1097479840], + [6.125, 6.34375, 1097479840], + [5.59375, 6.46875, 1097479840], + [5.0625, 6.46875, 1097479840], + [5.59375, 6.46875, 1097479840], + [6.78125, 6.5625, 1097479840], + [7.3125, 6.65625, 1097479840], + [7.9375, 6.90625, 1097479840], + [8.15625, 7.40625, 1097479840], + [7.6875, 7.71875, 1097479840], + [7.09375, 7.90625, 1097479840], + [6.46875, 8.0625, 1097479840], + [5.96875, 8.15625, 1097479840], + [5.375, 8.1875, 1097479840], + [4.6875, 8.3125, 1097479840], + [4.1875, 8.34375, 1097479840], + [4.6875, 8.4375, 1097479840], + [5.46875, 8.46875, 1097479840], + [5.96875, 8.46875, 1097479840], + [6.53125, 8.46875, 1097479840], + [7.09375, 8.46875, 1097479840], + [7.59375, 8.53125, 1097479840], + [8.15625, 8.625, 1097479840], + [8.65625, 8.8125, 1097479840], + [8.28125, 9.34375, 1097479840], + [7.75, 9.4375, 1097479840], + [7.15625, 9.5, 1097479840], + [6.625, 9.5, 1097479840], + [6, 9.5, 1097479840], + [5.375, 9.53125, 1097479840], + [4.8125, 9.53125, 1097479840], + [4.1875, 9.53125, 1097479840], + [3.5625, 9.4375, 1097479840], + [3, 9.4375, 1097479840], + [2.46875, 9.40625, 1097479840], + [1.84375, 9.34375, 1097479840], + [1.4375, 9.34375, 
1097479840], +] + + +def perturb_walls(w, idx, amount=0.1): + # Add small random offset to one wall endpoint + w[idx][1] = w[idx][1] + amount + return w + + +def get_robot(p_noise): + return RobotCapabilities( + p_noise=jnp.array(p_noise), + hd_noise=jnp.array(0.03), + sensor_noise=jnp.array(0.1), + n_sensors=jnp.array(8), + sensor_range=jnp.array(10.0), + ) + + +# Test 2: Perturb robot capabilities + + +# Create random keys +keys = jax.random.split(PRNGKey(0), 100) + +# Test 1: Run simulation with perturbed walls +print("Test 1: Perturbing walls") +start = time.time() +all_paths_1, all_readings_1 = jax.vmap( + lambda k: simulate_robot( + World(*walls_to_jax(walls)), + get_robot(0.1), + jnp.array(robot_path), + k, + ) +)(keys) +print(f"Wall perturbation test took {time.time() - start:.3f} seconds") + +# Test 2: Run simulation with perturbed robot capabilities +print("\nTest 2: Perturbing robot capabilities") +start = time.time() +p_noises = jnp.linspace(0.05, 0.15, 100) # Range of p_noise values +all_paths_2, all_readings_2 = jax.vmap( + lambda k, p: simulate_robot( + World(*walls_to_jax(walls)), get_robot(p), jnp.array(robot_path), k + ) +)(keys, p_noises) +print(f"Robot capability perturbation test took {time.time() - start:.3f} seconds") diff --git a/robot_2/test_where_am_i.py b/robot_2/test_where_am_i.py index 28c5c1d..de4eecd 100644 --- a/robot_2/test_where_am_i.py +++ b/robot_2/test_where_am_i.py @@ -1,62 +1,94 @@ import jax.numpy as jnp +import jax import pytest -from robot_2.where_am_i import World, Pose, RobotCapabilities, execute_control +from robot_2.where_am_i import ( + World, + Pose, + RobotCapabilities, + execute_control, + walls_to_jax, +) import robot_2.where_am_i as where_am_i from jax.random import PRNGKey def test_basic_motion(): """Test that robot moves as expected without noise""" - # Convert walls to JAX array at creation - now in (N,2,2) shape - walls = jnp.array( - [ - [[0.0, 0.0], [1.0, 0.0]], # bottom wall - [[1.0, 0.0], [1.0, 1.0]], # right wall - [[1.0, 1.0], [0.0, 1.0]], # top wall - [[0.0, 1.0], [0.0, 0.0]], # left wall - ] - ) - world = World(walls) + # Create walls in UI format first + walls_list = [ + [0.0, 0.0, 0], # Bottom wall + [1.0, 0.0, 0], + [1.0, 0.0, 1], # Right wall + [1.0, 1.0, 1], + [1.0, 1.0, 2], # Top wall + [0.0, 1.0, 2], + [0.0, 1.0, 3], # Left wall + [0.0, 0.0, 3], + ] + + # Convert to JAX format + walls, wall_vecs = walls_to_jax(walls_list) + world = World(walls, wall_vecs) + robot = RobotCapabilities( - p_noise=0.0, hd_noise=0.0, sensor_noise=0.0, n_sensors=8, sensor_range=10.0 + p_noise=jnp.array(0.0), + hd_noise=jnp.array(0.0), + sensor_noise=jnp.array(0.0), + n_sensors=jnp.array(8), + sensor_range=jnp.array(10.0), ) - start_pose = Pose(jnp.array([0.5, 0.5]), 0.0) + start_pose = Pose(jnp.array([0.5, 0.5]), jnp.array(0.0)) key = PRNGKey(0) - + exec_sim = jax.jit(execute_control.simulate) # Move forward 1 unit - new_pose, readings, key = execute_control( - world=world, robot=robot, current_pose=start_pose, control=(1.0, 0.0), key=key - ) + result = exec_sim(key, (world, robot, start_pose, jnp.array([1.0, 0.0]))) + new_pose = result.get_retval()[0] + assert new_pose.p[0] == pytest.approx( 1.0 - where_am_i.WALL_COLLISION_THRESHOLD ) # Started at 0.5, blocked by wall at 1.0 assert new_pose.p[1] == pytest.approx(0.5) # Y shouldn't change # Rotate 90 degrees (π/2 radians) - new_pose, readings, key = execute_control( - world=world, - robot=robot, - current_pose=new_pose, - control=(0.0, jnp.pi / 2), - key=key, - ) + result = exec_sim(key, 
(world, robot, new_pose, jnp.array([0.0, jnp.pi / 2]))) + + new_pose = result.get_retval()[0] assert new_pose.hd == pytest.approx(jnp.pi / 2) def test_pose_methods(): """Test Pose step_along and rotate methods""" - p = Pose(jnp.array([1.0, 1.0]), 0.0) + p = Pose(jnp.array([1.0, 1.0]), jnp.array(0.0)) # Step along heading 0 (right) - p2 = p.step_along(1.0) + p2 = p.step_along(jnp.array(1.0)) assert p2.p[0] == pytest.approx(2.0) assert p2.p[1] == pytest.approx(1.0) # Rotate 90 degrees and step - p3 = p.rotate(jnp.pi / 2).step_along(1.0) + p3 = p.rotate(jnp.array(jnp.pi / 2)).step_along(jnp.array(1.0)) assert p3.p[0] == pytest.approx(1.0) assert p3.p[1] == pytest.approx(2.0) -pytest.main(["-v"]) # +def test_walls_to_jax(): + """Test wall conversion from UI format to JAX format""" + walls_list = [[0.0, 0.0, 0], [1.0, 0.0, 0], [1.0, 0.0, 1], [1.0, 1.0, 1]] + + walls, wall_vecs = walls_to_jax(walls_list) + + # Check shapes + assert walls.shape == (3, 2, 2) # 3 segments, 2 points per segment, 2 coordinates + assert wall_vecs.shape == (3, 2) # 3 segments, 2 coordinates per vector + + # Check first wall segment + assert jnp.allclose(walls[0, 0], jnp.array([0.0, 0.0])) + assert jnp.allclose(walls[0, 1], jnp.array([1.0, 0.0])) + + # Check wall vector + assert jnp.allclose(wall_vecs[0], jnp.array([1.0, 0.0])) + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/robot_2/visualization.py b/robot_2/visualization.py index bcc20fb..8df3a78 100644 --- a/robot_2/visualization.py +++ b/robot_2/visualization.py @@ -60,7 +60,7 @@ def drawing_system(key, on_complete): return line + events -def key_scrubber(handle_seed_index): +def seed_scrubber(handle_seed_index): """Create a scrubber UI component for exploring different random seeds. The component shows a striped bar that can be clicked to pause/resume and @@ -85,7 +85,7 @@ def key_scrubber(handle_seed_index): ? 'Click to Start' : inside ? 
'Click to Pause' - : 'Explore Keys' + : 'Explore Seeds' const onMouseMove = React.useCallback(async (e) => { if (paused || waiting) return null; @@ -99,7 +99,15 @@ def key_scrubber(handle_seed_index): const stripeWidth = 4; // Width of each stripe in pixels return html(["div.flex.flex-col.gap-1", [ - ["div.flex.flex-row.gap-1", [ + ["div.flex.flex-row.gap-1", + ["div.text-md.flex.gap-2.p-2.border.hover:border-gray-400.cursor-pointer.font-mono.text-center.w-[140px]", { + "onClick": () => { + navigator.clipboard.writeText($state.current_seed.toString()); + }, + "style": { + cursor: "pointer" + } + }, $state.current_seed, ["div.text-gray-500.ml-auto", "copy"]],[ ["div.rounded-lg.p-2.delay-100.flex-grow", { "style": { background: paused @@ -125,15 +133,7 @@ def key_scrubber(handle_seed_index): transition: 'opacity 0.3s ease' } }, %2] - ]], - ["div.text-md.flex.gap-2.p-2.border.hover:border-gray-400.cursor-pointer.w-[140px].text-center", { - "onClick": () => { - navigator.clipboard.writeText($state.current_seed.toString()); - }, - "style": { - cursor: "pointer" - } - }, $state.current_seed, ["div.text-gray-500.ml-auto", "copy"]] + ]] ]]) } """, diff --git a/robot_2/where_am_i.py b/robot_2/where_am_i.py index c1c6160..445b592 100644 --- a/robot_2/where_am_i.py +++ b/robot_2/where_am_i.py @@ -44,321 +44,404 @@ # - Blue line: What the robot THINKS it's doing (following commands perfectly) # - Red rays: What the robot actually SEES (sensor readings) # - Blue cloud: Where the robot MIGHT be (uncertainty) -# - Green line: Where the robot figures it ACTUALLY is +# - Black line: Where the robot figures it ACTUALLY is # %% # pyright: reportUnusedExpression=false # pyright: reportUnknownMemberType=false -from functools import partial from typing import List, Tuple, Any, Dict import genjax import genstudio.plot as Plot import jax import jax.numpy as jnp +import time from jax.random import PRNGKey, split from penzai import pz from genstudio.plot import js +from functools import partial import robot_2.emoji as emoji import robot_2.visualization as v key = PRNGKey(0) - -WALL_COLLISION_THRESHOLD = 0.15 -WALL_WIDTH = 6 -PATH_WIDTH = 6 +WALL_COLLISION_THRESHOLD = jnp.array(0.15) +WALL_BOUNCE = 0.15 +MAX_SENSORS = 32 @pz.pytree_dataclass class Pose(genjax.PythonicPytree): """Robot pose with position and heading""" - p: jax.numpy.ndarray # [x, y] - hd: float # heading in radians + p: jnp.ndarray # [x, y] + hd: jnp.ndarray # heading in radians def dp(self): """Get direction vector from heading""" return jnp.array([jnp.cos(self.hd), jnp.sin(self.hd)]) - def step_along(self, s: float) -> "Pose": + def step_along(self, s: jnp.ndarray) -> "Pose": """Move forward by distance s""" return Pose(self.p + s * self.dp(), self.hd) - def rotate(self, angle: float) -> "Pose": + def rotate(self, angle: jnp.ndarray) -> "Pose": """Rotate by angle (in radians)""" return Pose(self.p, self.hd + angle) + def for_json(self): + if len(self.p.shape) == 1: + return [*self.p, self.hd] + heading_expanded = jnp.expand_dims(self.hd, axis=-1) # Add last dimension + return jnp.concatenate([self.p, heading_expanded], axis=-1) + + +def calculate_bounce_point( + collision_point: jnp.ndarray, + ray_dir: jnp.ndarray, + wall_vec: jnp.ndarray, + bounce_amount: jnp.ndarray, +) -> jnp.ndarray: + """Calculate bounce point for a single wall collision + + Args: + collision_point: Point of collision with wall + ray_dir: Direction of incoming ray + wall_vec: Vector along wall direction + bounce_amount: How far to bounce + + Returns: + Point after bouncing 
off wall + """ + wall_normal = jnp.array([-wall_vec[1], wall_vec[0]]) / ( + jnp.linalg.norm(wall_vec) + 1e-10 + ) + # Ensure wall normal points away from approach direction + wall_normal = jnp.where( + jnp.dot(ray_dir, wall_normal) > 0, -wall_normal, wall_normal + ) + return collision_point + bounce_amount * wall_normal + + +def compute_wall_normal(wall_direction: jnp.ndarray) -> jnp.ndarray: + """Compute unit normal vector to wall direction""" + return jnp.array([-wall_direction[1], wall_direction[0]]) / ( + jnp.linalg.norm(wall_direction) + 1e-10 + ) + @pz.pytree_dataclass class World(genjax.PythonicPytree): """The physical environment with walls that robots can collide with""" walls: jnp.ndarray # [N, 2, 2] array of wall segments + wall_vecs: jnp.ndarray # [N, 2] array of wall direction vectors + bounce: jnp.ndarray = WALL_BOUNCE # How much to bounce off walls + __hash__ = None + + def physical_step( + self, start_pos: jnp.ndarray, end_pos: jnp.ndarray + ) -> jnp.ndarray: + """Compute physical step with wall collisions and bounces + + Args: + start_pos: Starting position [x, y] + end_pos: Intended end position [x, y] + heading: Current heading in radians + + Returns: + New pose after movement, considering wall collisions + """ + # Calculate step properties + step_direction = end_pos - start_pos + step_length = jnp.linalg.norm(step_direction) + + # Get distance to nearest wall + ray_dir = step_direction / (step_length + 1e-10) # Avoid division by zero + wall_dist, wall_idx = self.ray_distance(start_pos, ray_dir, step_length) + + # Find collision point + collision_point = start_pos + ray_dir * wall_dist + + # Calculate bounce point if wall hit + bounce_pos = calculate_bounce_point( + collision_point, ray_dir, self.wall_vecs[wall_idx], self.bounce + ) + + # Define conditions for position selection + conditions = [ + step_length < 1e-6, # No movement case + wall_dist >= step_length, # No collision case + wall_idx >= 0, # Wall collision case + ] + + positions = [ + start_pos, # For no movement + end_pos, # For no collision + bounce_pos, # For wall collision + ] + + final_pos = jnp.select(conditions, positions, default=end_pos) + + return final_pos - @jax.jit def ray_distance( - self, ray_start: jnp.ndarray, ray_dir: jnp.ndarray, max_dist: float - ) -> float: - """Find distance to nearest wall along a ray""" + self, ray_start: jnp.ndarray, ray_dir: jnp.ndarray, max_dist: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Find distance to nearest wall along a ray and which wall was hit + + Args: + ray_start: Starting point of ray + ray_dir: Direction of ray (normalized) + max_dist: Maximum distance to check + + Returns: + Tuple of (distance, wall_idx) where: + - distance: Distance to nearest wall, or max_dist + threshold if no wall hit + - wall_idx: Index of wall that was hit, or -1 if no wall hit + """ if self.walls.shape[0] == 0: # No walls - return max_dist + return max_dist + WALL_COLLISION_THRESHOLD, jnp.array(-1) # Vectorized computation for all walls at once p1 = self.walls[:, 0] # Shape: (N, 2) - p2 = self.walls[:, 1] # Shape: (N, 2) - - # Wall direction vectors - wall_vec = p2 - p1 # Shape: (N, 2) # Vector from wall start to ray start to_start = ray_start - p1 # Shape: (N, 2) - # Compute determinant (cross product in 2D) - det = wall_vec[:, 0] * (-ray_dir[1]) - wall_vec[:, 1] * (-ray_dir[0]) + # Compute determinant (cross product in 2D) using pre-computed wall vectors + det = self.wall_vecs[:, 0] * (-ray_dir[1]) - self.wall_vecs[:, 1] * ( + -ray_dir[0] + ) # Compute intersection 
parameters u = (to_start[:, 0] * (-ray_dir[1]) - to_start[:, 1] * (-ray_dir[0])) / ( det + 1e-10 ) - t = (wall_vec[:, 0] * to_start[:, 1] - wall_vec[:, 1] * to_start[:, 0]) / ( - det + 1e-10 - ) + t = ( + self.wall_vecs[:, 0] * to_start[:, 1] + - self.wall_vecs[:, 1] * to_start[:, 0] + ) / (det + 1e-10) # Valid intersections: not parallel, in front of ray, within wall segment is_valid = (jnp.abs(det) > 1e-10) & (t >= 0) & (u >= 0) & (u <= 1) + distances = jnp.where(is_valid, t * jnp.linalg.norm(ray_dir), jnp.inf) - # Find minimum valid distance - min_dist = jnp.min(jnp.where(is_valid, t * jnp.linalg.norm(ray_dir), jnp.inf)) - return jnp.where(jnp.isinf(min_dist), max_dist, min_dist) + # Find closest valid wall + closest_idx = jnp.argmin(distances) + min_dist = distances[closest_idx] - @jax.jit - def check_movement( - self, - start_pos: jnp.ndarray, - end_pos: jnp.ndarray, - collision_radius: float = WALL_COLLISION_THRESHOLD, - ) -> Tuple[bool, jnp.ndarray]: - """Check if movement between two points collides with walls + # Return -1 as wall index if no valid intersection found + wall_idx = jnp.where(jnp.isinf(min_dist), -1, closest_idx) + final_dist = jnp.where( + jnp.isinf(min_dist), max_dist + WALL_COLLISION_THRESHOLD, min_dist + ) - Args: - start_pos: [x, y] starting position - end_pos: [x, y] intended end position - collision_radius: How close we can get to walls + return final_dist, wall_idx - Returns: - (can_move, safe_pos) where safe_pos is either end_pos or the - furthest safe position along the movement line - """ - movement_dir = end_pos - start_pos - dist = jnp.linalg.norm(movement_dir) - - # Replace if with where - ray_dir = jnp.where( - dist > 1e-6, - movement_dir / dist, - jnp.array([1.0, 0.0]), # Default direction if no movement - ) - wall_dist = self.ray_distance(start_pos, ray_dir, dist) +def walls_to_jax(walls_list: List[List[float]]) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Convert wall vertices from UI format to JAX arrays of wall segments and direction vectors""" + if not walls_list: + empty = jnp.array([]).reshape((0, 2, 2)) + return empty, jnp.array([]).reshape((0, 2)) - # Stop short of wall by collision_radius - safe_dist = jnp.maximum(0.0, wall_dist - collision_radius) - safe_pos = start_pos + ray_dir * safe_dist + points = jnp.array(walls_list, dtype=jnp.float32) + p1 = points[:-1] + p2 = points[1:] + + segments = jnp.stack([p1[:, :2], p2[:, :2]], axis=1) + valid_mask = p1[:, 2] == p2[:, 2] - # Use where to select between start_pos and safe_pos - final_pos = jnp.where(dist > 1e-6, safe_pos, start_pos) + # Compute wall direction vectors + wall_segments = segments * valid_mask[:, None, None] + wall_vecs = (wall_segments[:, 1] - wall_segments[:, 0]) * valid_mask[:, None] - return wall_dist > dist - collision_radius, final_pos + return wall_segments, wall_vecs @pz.pytree_dataclass class RobotCapabilities(genjax.PythonicPytree): """Physical capabilities and limitations of the robot""" - p_noise: float # Position noise (std dev in meters) - hd_noise: float # Heading noise (std dev in radians) - sensor_noise: float # Sensor noise (std dev in meters) - n_sensors: int = 8 # Number of distance sensors - sensor_range: float = 10.0 # Maximum sensor range in meters - - def try_move( - self, - world: World, - current_pos: jnp.ndarray, - desired_pos: jnp.ndarray, - key: PRNGKey, - ) -> jnp.ndarray: - """Try to move to desired_pos, respecting walls and adding noise""" - # Add motion noise - noise = jax.random.normal(key, shape=(2,)) * self.p_noise - noisy_target = desired_pos + 
noise + p_noise: jnp.ndarray # Position noise (std dev in meters) + hd_noise: jnp.ndarray # Heading noise (std dev in radians) + sensor_noise: jnp.ndarray # Sensor noise (std dev in meters) + n_sensors: jnp.ndarray = 8 # Number of distance sensors + sensor_range: jnp.ndarray = 10.0 # Maximum sensor range in meters - # Check for collisions - _, safe_pos = world.check_movement(current_pos, noisy_target) - return safe_pos +def path_to_controls(path_points: jnp.ndarray) -> Tuple[Pose, jnp.ndarray]: + """Convert a series of points into a starting pose and list of (distance, angle) control pairs -def path_to_controls(path_points: List[List[float]]) -> jnp.ndarray: - """Convert a series of points into (distance, angle) control pairs""" + Each control pair is (distance, angle) where: + - distance: move this far in current heading + - angle: after moving, turn this much (relative to current heading) + """ points = jnp.array([p[:2] for p in path_points]) deltas = points[1:] - points[:-1] distances = jnp.linalg.norm(deltas, axis=1) + + # Calculate absolute angles for each segment angles = jnp.arctan2(deltas[:, 1], deltas[:, 0]) - angle_changes = jnp.diff(angles, prepend=0.0) - return jnp.stack([distances, angle_changes], axis=1) - - -@jax.jit -def get_sensor_readings( - world: World, robot: RobotCapabilities, pose: Pose, key: PRNGKey -) -> Tuple[jnp.ndarray, PRNGKey]: - """Return noisy distance readings to walls from given pose""" - MAX_SENSORS = 32 # Fixed maximum - key, subkey = jax.random.split(key) - - # Calculate angles based on n_sensors, but generate MAX_SENSORS of them - angle_step = 2 * jnp.pi / robot.n_sensors - angles = jnp.arange(MAX_SENSORS) * angle_step - noise = jax.random.normal(subkey, (MAX_SENSORS,)) * robot.sensor_noise - - readings = jax.vmap( - lambda a: world.ray_distance( - ray_start=pose.p, - ray_dir=jnp.array([jnp.cos(pose.hd + a), jnp.sin(pose.hd + a)]), - max_dist=robot.sensor_range, - ) - )(angles) - # Create a mask for the first n_sensors elements - mask = jnp.arange(MAX_SENSORS) < robot.n_sensors + # Start facing the first segment + start_pose = Pose(p=points[0], hd=angles[0]) + + # For each segment, we need: + # - distance: length of current segment + # - angle: turn needed after this segment to face next segment + angle_changes = jnp.diff(angles, append=angles[-1]) - # Apply mask and pad with zeros - readings = (readings + noise) * mask + controls = jnp.stack([distances, angle_changes], axis=1) - return readings, key + return start_pose, controls -@jax.jit +@genjax.gen def execute_control( world: World, robot: RobotCapabilities, current_pose: Pose, - control: Tuple[float, float], - key: PRNGKey, -) -> Tuple[Pose, jnp.ndarray, PRNGKey]: - """Execute a control command with noise, stopping if we hit a wall - - Args: - control: (distance, angle) pair where: - - angle is how much to turn FIRST - - distance is how far to move AFTER turning - """ + control: jnp.ndarray, +): + """Execute a control command with physical step and noise""" dist, angle = control - k1, k2, k3 = jax.random.split(key, 3) - - # Add noise to motion - noisy_dist = dist + jax.random.normal(k1) * robot.p_noise - noisy_angle = angle + jax.random.normal(k2) * robot.hd_noise - # First rotate (can always rotate) - new_pose = current_pose.rotate(noisy_angle) - - # Check distance to wall in our current heading direction - min_dist = world.ray_distance( - ray_start=new_pose.p, ray_dir=new_pose.dp(), max_dist=robot.sensor_range + # Calculate noisy intended position + planned_pos = current_pose.p + dist * 
current_pose.dp() + noisy_pos = ( + genjax.mv_normal_diag(planned_pos, robot.p_noise * jnp.ones(2)) @ "p_noise" ) + noisy_angle = genjax.normal(current_pose.hd + angle, robot.hd_noise) @ "hd_noise" + physical_pos = world.physical_step(current_pose.p, noisy_pos) - # Only move as far as we can before hitting a wall - safe_dist = jnp.minimum(noisy_dist, min_dist - WALL_COLLISION_THRESHOLD) - safe_dist = jnp.maximum(safe_dist, 0) # Don't move backwards - - new_pose = new_pose.step_along(safe_dist) + final_pose = Pose(p=physical_pos, hd=noisy_angle) - # Get sensor readings from new position - readings, k4 = get_sensor_readings(world, robot, new_pose, k3) + return final_pose - return new_pose, readings, k4 - -@jax.jit -def simulate_robot_path( - world: World, - robot: RobotCapabilities, - start_pose: Pose, - controls: jnp.ndarray, - key: jnp.ndarray, +@genjax.gen +def sample_robot_path( + world: World, robot: RobotCapabilities, start: Pose, controls: jnp.ndarray ): - """Simulate robot path with noise and sensor readings + """Simulate robot path with noise and sensor readings using genjax + + Args: + world: World containing walls + robot: Robot capabilities and noise parameters + start: Starting pose + controls: Array of (distance, angle) control pairs Returns: Tuple of: - - Array of shape [n_steps, 2] containing positions - - Array of shape [n_steps] containing headings - - Array of shape [n_steps, n_sensors] containing sensor readings + - Array of poses for each step (including start pose) + - Array of sensor readings for each step """ + # Prepend a no-op control to get initial readings + noop = jnp.array([0.0, 0.0]) + all_controls = jnp.concatenate([noop[None], controls]) - def step_fn(carry, control): - pose, k = carry - new_pose, readings, new_key = execute_control( - world=world, robot=robot, current_pose=pose, control=control, key=k - ) - return (new_pose, new_key), (new_pose, readings) + path = ( + execute_control.partial_apply(world, robot).accumulate()(start, all_controls) + @ "path" + ) - (_, _), (poses, readings) = jax.lax.scan(step_fn, (start_pose, key), controls) + return path - # Extract positions and headings - positions = jnp.concatenate([start_pose.p[None, :], jax.vmap(lambda p: p.p)(poses)]) - headings = jnp.concatenate( - [jnp.array([start_pose.hd]), jax.vmap(lambda p: p.hd)(poses)] + +@genjax.gen +def pose_reading( + world: World, robot: RobotCapabilities, pose: Pose, angle: jnp.ndarray +) -> jnp.ndarray: + """Get a single noisy sensor reading at the given angle relative to robot heading""" + + ray_dir = jnp.array([jnp.cos(pose.hd + angle), jnp.sin(pose.hd + angle)]) + + distance, idx = world.ray_distance( + ray_start=pose.p, ray_dir=ray_dir, max_dist=robot.sensor_range ) - return positions, headings, readings + noisy_distance = genjax.normal(distance, robot.sensor_noise) @ "reading" + return noisy_distance -@partial(jax.jit, static_argnums=(1)) -def sample_possible_paths( - key: jnp.ndarray, - n_paths: int, - robot_path: jnp.ndarray, + +@genjax.gen +def pose_readings(world: World, robot: RobotCapabilities, pose: Pose): + sensor_angles = jnp.arange(MAX_SENSORS) * 2 * jnp.pi / robot.n_sensors + sensor_mask = jnp.arange(MAX_SENSORS) < robot.n_sensors + + readings = ( + pose_reading.partial_apply(world, robot, pose) + .mask() + .vmap()(sensor_mask, sensor_angles) + @ "readings" + ) + + return readings.value + + +@genjax.gen +def generate_true_path( world: World, robot: RobotCapabilities, + start_pose: Pose, + controls: jnp.ndarray, ): - """Generate n possible paths given the 
planned path, respecting walls""" - path_points = robot_path[:, :2] - controls = path_to_controls(path_points) + path = sample_robot_path(world, robot, start_pose, controls) @ "true_path" + readings = pose_readings.partial_apply(world, robot).vmap()(path) @ "sensor" - start_point = path_points[0] - start_pose = Pose(jnp.array(start_point, dtype=jnp.float32), 0.0) + return path, readings.value - keys = jax.random.split(key, n_paths) - # Vectorize over different random keys - return jax.vmap( - lambda k: simulate_robot_path( - world=world, robot=robot, start_pose=start_pose, controls=controls, key=k - ) - )(keys) +@partial(jax.jit, static_argnums=5) +def generate_possible_paths( + world: World, + robot: RobotCapabilities, + start_pose: Pose, + controls: jnp.ndarray, + key: jnp.ndarray, + n_possible: int = 40, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Core robot simulation logic that can be used for both visualization and benchmarking + + Args: + world: World containing walls + robot: Robot capabilities and noise parameters + path: Array of path points + key: Random key for simulation + n_possible: Number of possible paths to simulate (default: 40) + Returns: + Tuple of: + - paths: Array of shape (n_possible + 1, n_steps, 3) containing all simulated paths + - readings: Array of shape (n_possible + 1, n_steps, n_sensors) containing sensor readings + """ -def walls_to_jax(walls_list: List[List[float]]) -> jnp.ndarray: - """Convert wall vertices from UI format to JAX array of wall segments""" - if not walls_list: - return jnp.array([]).reshape((0, 2, 2)) + # Sample all paths at once (1 true path + N possible paths) + k1, k2 = jax.random.split(key) - points = jnp.array(walls_list, dtype=jnp.float32) - p1 = points[:-1] - p2 = points[1:] + # Vectorize simulation across keys + paths_k = jax.random.split(k1, n_possible) + possible_paths = jax.vmap(sample_robot_path.simulate, in_axes=(0, None))( + paths_k, (world, robot, start_pose, controls) + ).get_retval() - segments = jnp.stack([p1[:, :2], p2[:, :2]], axis=1) + true_path = possible_paths[0] + possible_paths = possible_paths[1:] - valid_mask = p1[:, 2] == p2[:, 2] - return segments * valid_mask[:, None, None] + readings_k = jax.random.split(k2, len(true_path)) + readings = jax.vmap(pose_readings.partial_apply(world, robot).simulate)( + readings_k, (true_path,) + ) + return possible_paths, true_path, readings.get_retval() -def simulate_robot_uncertainty(widget, e, seed=None): + +def update_robot_simulation(widget, e, seed=None): """Handle updates to robot simulation""" if not widget.state.robot_path: return @@ -368,44 +451,39 @@ def simulate_robot_uncertainty(widget, e, seed=None): current_key = PRNGKey(current_seed) + k1, k2 = jax.random.split(current_key, 2) + # Create world and robot objects - world = World(walls_to_jax(widget.state.walls)) + world = World(*walls_to_jax(widget.state.walls)) robot = RobotCapabilities( - p_noise=widget.state.motion_noise, - hd_noise=widget.state.motion_noise * widget.state.heading_noise_scale, - sensor_noise=widget.state.sensor_noise, - n_sensors=widget.state.n_sensors, - sensor_range=10.0, + p_noise=jnp.array(widget.state.motion_noise, dtype=jnp.float32), + hd_noise=jnp.array( + widget.state.motion_noise * widget.state.heading_noise_scale, + dtype=jnp.float32, + ), + sensor_noise=jnp.array(widget.state.sensor_noise, dtype=jnp.float32), + n_sensors=jnp.array(widget.state.n_sensors, dtype=jnp.int32), + sensor_range=jnp.array(10.0, dtype=jnp.float32), ) - path = jnp.array(widget.state.robot_path, dtype=jnp.float32) 
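# Illustrative aside (hypothetical input, based on the path_to_controls
# definition earlier in this file): converting a short path into a start pose
# plus (distance, angle) controls.
#   pts = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
#   start_pose, controls = path_to_controls(pts)
#   # start_pose ~ Pose(p=[0., 0.], hd=0.)   facing the first segment
#   # controls   ~ [[1.0, pi/2],             advance 1, then turn left 90 deg
#   #               [1.0, 0.0 ]]             advance 1, no turn after the last leg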
+ paths = jnp.array(widget.state.robot_path, dtype=jnp.float32) + (start_pose, controls) = path_to_controls(paths[:, :2]) - # Sample all paths at once (1 true path + N possible paths) - n_possible = 40 - all_paths, all_headings, all_readings = sample_possible_paths( - current_key, n_possible + 1, path, world, robot + # Use the factored out simulation function and measure time + start_time = time.time() + paths, true_path, readings = generate_possible_paths( + world, robot, start_pose, controls, k1 ) - - # First path is the "true" path - true_path = all_paths[0] - final_readings = all_readings[0, -1] - final_heading = all_headings[0, -1] - - # Remaining paths are possible paths - possible_paths = all_paths[1:] + simulation_time = (time.time() - start_time) * 1000 # Convert to milliseconds widget.state.update( { - "robot_pose": { - "x": float(true_path[-1, 0]), - "y": float(true_path[-1, 1]), - "heading": float(final_heading), - }, - "possible_paths": possible_paths, - "sensor_readings": final_readings, - "true_path": [[float(x), float(y)] for x, y in true_path], + "possible_paths": paths, + "true_path": true_path, + "robot_readings": readings[-1][: robot.n_sensors], "show_debug": True, "current_seed": current_seed, + "simulation_time": simulation_time, } ) @@ -447,19 +525,18 @@ def create_initial_state(seed) -> Dict[str, Any]: [0, 10, 0], [0, 0, 0], # Left ], - "robot_pose": {"x": 0.5, "y": 0.5, "heading": 0}, + "robot_pose": Plot.js("$state.true_path?.[$state.true_path.length-1]"), + "true_path": None, "sensor_noise": 0.1, "motion_noise": 0.1, "heading_noise_scale": 0.3, "n_sensors": 8, - "show_sensors": True, "selected_tool": "path", "robot_path": [], "possible_paths": [], "estimated_pose": None, - "sensor_readings": [], + "robot_readings": None, "sensor_explore_angle": -1, - "show_uncertainty": True, "show_true_position": False, "current_line": [], "current_seed": seed, @@ -482,20 +559,21 @@ def create_initial_state(seed) -> Dict[str, Any]: ) sensor_rays = Plot.line( - js(""" - Array.from($state.sensor_readings).map((r, i) => { - const heading = $state.robot_pose.heading || 0; - const n_sensors = $state.n_sensors; + js( + """ + const readings = $state.robot_readings + if (!readings) return; + const n_sensors = readings.length; + const [x, y, heading] = $state.robot_pose; + return Array.from($state.robot_readings).flatMap((r, i) => { const angle = heading + (i * Math.PI * 2) / n_sensors; - const x = $state.robot_pose.x; - const y = $state.robot_pose.y; - return [ - [x, y, i], - [x + r * Math.cos(angle), - y + r * Math.sin(angle), i] - ] - }).flat() - """), + const from = [x, y, i] + const to = [x + r * Math.cos(angle), y + r * Math.sin(angle), i] + return [from, to] + }) + """, + expression=False, + ), z="2", stroke="red", strokeWidth=1, @@ -505,27 +583,32 @@ def create_initial_state(seed) -> Dict[str, Any]: rotating_sensor_rays = ( Plot.line( - js(""" - Array.from($state.sensor_readings).map((r, i) => { - const heading = $state.robot_pose.heading || 0; - const n_sensors = $state.n_sensors; - let angle = heading + (i * Math.PI * 2) / n_sensors; + js( + """ + const readings = $state.robot_readings; + if (!readings) return; + const n_sensors = readings.length; + const [x, y, heading] = $state.robot_pose; + + let angle_modifier = 0 if (!$state.show_true_position) { - if ($state.sensor_explore_angle > -1) { - angle += $state.sensor_explore_angle + const explore_angle = $state.sensor_explore_angle; + if (explore_angle > -1) { + angle_modifier = explore_angle } else { - angle += 
$state.current_seed || Math.random() * 2 * Math.PI; + angle_modifier = $state.current_seed || Math.random() * 2 * Math.PI; } } - const x = $state.robot_pose.x; - const y = $state.robot_pose.y; - return [ - [0, 0, i], - [r * Math.cos(angle), - r * Math.sin(angle), i] - ] - }).flat() - """), + return Array.from($state.robot_readings).flatMap((r, i) => { + let angle = heading + (i * Math.PI * 2) / n_sensors; + angle += angle_modifier; + const from = [0, 0, i] + const to = [r * Math.cos(angle), r * Math.sin(angle), i] + return [from, to] + }) + """, + expression=False, + ), z="2", stroke="red", strokeWidth=1, @@ -556,19 +639,19 @@ def create_initial_state(seed) -> Dict[str, Any]: ) true_path = Plot.cond( - js("$state.show_true_position"), + js("$state.show_true_position && $state.robot_pose"), [ Plot.line( js("$state.true_path"), stroke=Plot.constantly("True Path"), strokeWidth=2 ), sensor_rays, Plot.text( - js("[[$state.robot_pose.x, $state.robot_pose.y]]"), + js("[$state.robot_pose]"), text=Plot.constantly(emoji.robot), fontSize=30, textAnchor="middle", dy="-0.35em", - rotate=js("(-$state.robot_pose.heading + Math.PI/2) * 180 / Math.PI"), + rotate=js("(-$state.robot_pose[2] + Math.PI/2) * 180 / Math.PI"), ), ], ) @@ -584,7 +667,7 @@ def create_initial_state(seed) -> Dict[str, Any]: walls = Plot.line( js("$state.walls"), stroke=Plot.constantly("Walls"), - strokeWidth=WALL_WIDTH, + strokeWidth=6, z="2", render=Plot.renderChildEvents( { @@ -600,10 +683,10 @@ def create_initial_state(seed) -> Dict[str, Any]: possible_paths = Plot.line( js( """ - if (!$state.show_debug || !$state.possible_paths) {return [];}; - return $state.possible_paths.flatMap((path, pathIdx) => - path.map(([x, y]) => [x, y, pathIdx]) - ) + if (!$state.show_debug || !$state.possible_paths) {return [];}; + return $state.possible_paths.map((path, pathIdx) => + path.map(([x, y]) => [x, y, pathIdx]) + ).flat() """, expression=False, ), @@ -621,6 +704,13 @@ def clear_state(w, _): ) +def print_state(w, _): + """Print current walls and robot path in a format suitable for benchmarking""" + print("# Benchmark State:") + print(f"walls = {w.state.walls}") + print(f"robot_path = {w.state.robot_path}") + + selectable_button = "button.px-3.py-1.rounded.bg-gray-100.hover:bg-gray-300.data-[selected=true]:bg-gray-300" toolbar = Plot.html("Select tool:") | [ @@ -642,26 +732,41 @@ def clear_state(w, _): f"{emoji.pencil} Walls", ], [selectable_button, {"onClick": clear_state}, "Clear"], + [selectable_button, {"onClick": print_state}, "Print State"], ] def handleSeedIndex(w, e): + # Called by seed scrubber UI with: + # w: Plot.js widget + # e: {index: stripe index, key: current seed} + # The index indicates which stripe was clicked: + # -1: Recycle button was clicked, cycle to next seed + # 0: First stripe clicked, use first seed + # >0: Other stripes clicked, use seed at that position + # Need to: Get seed from global key based on index, update simulation global key + seed = None try: if e.index == 0: + # For first stripe, use first seed from key seed = key[0] elif e.index == -1: + # For recycle button, split key into 2 parts and use first seed + # This cycles through seeds by taking first part key = split(key, 2)[0] seed = key[0] else: + # For other stripes, split key into e.index parts + # Use seed at position (index-1) since we're 0-based seed = split(key, e.index)[e.index - 1][0] - simulate_robot_uncertainty(w, e, seed=seed) except Exception as err: + # Log any errors that occur during seed selection print(f"Error handling seed index: {err}, 
{e.key}, {e.index}") + # Update simulation with the selected seed + update_robot_simulation(w, e, seed=seed) -key_scrubber = v.key_scrubber(handleSeedIndex) - canvas = ( v.drawing_system("current_line", drawing_system_handler) + walls @@ -675,7 +780,7 @@ def handleSeedIndex(w, e): { "Walls": "#666", "Sensor Rays": "red", - "True Path": "green", + "True Path": "black", "Robot Path": "blue", } ) @@ -685,18 +790,31 @@ def handleSeedIndex(w, e): ) ( - canvas - & (sliders | toolbar | true_position_toggle | key_scrubber | rotating_sensor_rays) + ( + canvas + | Plot.js( + "$state.simulation_time && `${$state.simulation_time?.toFixed(2)} ms`" + ) + ) + & ( + sliders + | toolbar + | true_position_toggle + | rotating_sensor_rays + | v.seed_scrubber(handleSeedIndex) + ) & {"widths": ["400px", 1]} - | Plot.initialState(create_initial_state(0), sync=True) + | Plot.initialState( + create_initial_state(7 + 5 + 14), sync={"current_seed", "selected_tool"} + ) | Plot.onChange( { - "robot_path": simulate_robot_uncertainty, - "sensor_noise": simulate_robot_uncertainty, - "motion_noise": simulate_robot_uncertainty, - "heading_noise_scale": simulate_robot_uncertainty, - "n_sensors": simulate_robot_uncertainty, - "walls": simulate_robot_uncertainty, + "robot_path": update_robot_simulation, + "sensor_noise": update_robot_simulation, + "motion_noise": update_robot_simulation, + "heading_noise_scale": update_robot_simulation, + "n_sensors": update_robot_simulation, + "walls": update_robot_simulation, } ) ) From 538f3aee01f987ebbe652335199f0c728a44d174 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 27 Jan 2025 13:55:33 +0100 Subject: [PATCH 85/86] attempt to fix localization tutorial merge --- .../probcomp-localization-tutorial.py | 521 +++++++----------- 1 file changed, 208 insertions(+), 313 deletions(-) diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 832f6bf..16073fc 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -490,7 +490,7 @@ def pose_plot(p, fill: str | Any = "black", **opts): @genjax.gen -def step_proposal(motion_settings, start, control): +def step_model(motion_settings, start, control): p = ( genjax.mv_normal_diag( start.p + control.ds * start.dp(), motion_settings["p_noise"] * jnp.ones(2) @@ -509,7 +509,7 @@ def step_proposal(motion_settings, start, control): # %% key = jax.random.PRNGKey(0) -step_proposal.simulate( +step_model.simulate( key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]) ).get_retval() @@ -534,15 +534,20 @@ def make_circle(p, r): # Generate N_samples of starting poses from the prior N_samples = 50 key, sub_key = jax.random.split(key) -pose_samples = jax.vmap(step_proposal.simulate, in_axes=(0, None))( +pose_samples = jax.vmap(step_model.simulate, in_axes=(0, None))( jax.random.split(sub_key, N_samples), (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) +def pose_list_to_plural_pose(pl: list[Pose]) -> Pose: + return Pose(jnp.array([pose.p for pose in pl]), [pose.hd for pose in pl]) + + def poses_to_plots(poses: Iterable[Pose], **plot_opts): return [pose_plot(pose, **plot_opts) for pose in poses] + # Plot the world, starting pose samples, and 95% confidence region # Calculate the radius of the 95% confidence region def confidence_circle(pose: Pose, p_noise: float): @@ -574,7 +579,7 @@ def confidence_circle(pose: Pose, 
p_noise: float): # %% # `simulate` takes the GF plus a tuple of args to pass to it. key, sub_key = jax.random.split(key) -trace: genjax.Trace[Pose] = step_proposal.simulate( +trace = step_model.simulate( sub_key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) @@ -701,7 +706,7 @@ def confidence_circle(pose: Pose, p_noise: float): # %% path_model = ( - step_proposal.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan() + step_model.partial_apply(default_motion_settings).map(lambda r: (r, r)).scan() ) @@ -781,8 +786,9 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): # # One could, for instance, consider just the placement of the first step, and replace its stochastic choice of heading with an updated value. The original trace was typical under the pose prior model, whereas the modified one may be rather less likely. This plot is annotated with log of how much unlikelier, the score ratio: # %% + key, sub_key = jax.random.split(key) -trace = step_proposal.simulate( +trace = step_model.simulate( sub_key, (default_motion_settings, robot_inputs["start"], robot_inputs["controls"][0]), ) @@ -1106,17 +1112,17 @@ def noisy_sensor(pose): @genjax.gen def full_model_kernel(motion_settings, state, control): - pose = step_proposal(motion_settings, state, control) @ "pose" + pose = step_model(motion_settings, state, control) @ "pose" sensor_model(pose, sensor_angles) @ "sensor" - return pose, pose + return pose @genjax.gen def full_model(motion_settings): return ( - full_model_kernel.partial_apply(motion_settings).scan()( - robot_inputs["start"], robot_inputs["controls"] - ) + full_model_kernel.partial_apply(motion_settings) + .map(lambda r: (r, r)) + .scan()(robot_inputs["start"], robot_inputs["controls"]) @ "steps" ) @@ -1212,13 +1218,11 @@ def animate_full_trace(trace, frame_key=None): # Encode sensor readings into choice map. -def constraint_from_sensors(readings): - angle_indices = jnp.arange(len(sensor_angles)) - return jax.vmap( - lambda ix, v: C["steps", ix, "sensor", angle_indices, "distance"].set(v) - )(jnp.arange(T), readings) + C["initial", "sensor", angle_indices, "distance"].set( - readings[0] - ) +def constraint_from_sensors(readings, t: int = T): + return C["steps", jnp.arange(t + 1), "sensor", :, "distance"].set(readings[: t + 1]) + # return jax.vmap( + # lambda v: C["steps", :, "sensor", :, "distance"].set(v) + # )(readings[:t]) constraints_low_deviation = constraint_from_sensors(observations_low_deviation) @@ -1515,9 +1519,12 @@ def constraint_from_path(path): # %% -def resample( +def importance_sample( key: PRNGKey, constraints: genjax.ChoiceMap, motion_settings, N: int, K: int ): + """Produce N importance samples of depth K from the model. 
That is, N times, we + generate K importance samples conditioned by the constraints, and categorically + select one of them.""" key1, key2 = jax.random.split(key) samples, log_weights = jax.vmap(model_importance, in_axes=(0, None, None))( jax.random.split(key1, N * K), constraints, (motion_settings,) @@ -1532,7 +1539,7 @@ def resample( return selected -jit_resample = jax.jit(resample, static_argnums=(3, 4)) +jit_resample = jax.jit(importance_sample, static_argnums=(3, 4)) key, sub_key = jax.random.split(key) low_posterior = jit_resample( @@ -1546,21 +1553,24 @@ def resample( # %% -def animate_path_as_line(path, **options): - x_coords = path.p[:, 0] - y_coords = path.p[:, 1] - return Plot.line({"x": x_coords, "y": y_coords}, {"curve": "linear", **options}) +def path_to_polyline(path, **options): + if len(path.p.shape) > 1: + x_coords = path.p[:, 0] + y_coords = path.p[:, 1] + return Plot.line({"x": x_coords, "y": y_coords}, {"curve": "linear", **options}) + else: + return Plot.dot([path.p], fill=options["stroke"], r=2, **options) # ( world_plot + [ - animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="green") + path_to_polyline(path, opacity=0.2, strokeWidth=2, stroke="green") for path in jax.vmap(get_path)(low_posterior) ] + [ - animate_path_as_line(path, opacity=0.2, strokeWidth=2, stroke="blue") + path_to_polyline(path, opacity=0.2, strokeWidth=2, stroke="blue") for path in jax.vmap(get_path)(high_posterior) ] + poses_to_plots( @@ -1584,312 +1594,197 @@ def animate_path_as_line(path, **options): # Let's pause a moment to examine this chart. If the robot had no sensors, it would have no alternative but to estimate its position by integrating the control inputs to produce the integrated path in gray. In the low deviation setting, Gen has helped the robot to see that about halfway through its journey, noise in the control-effector relationship has caused the robot to deviate to the south slightly, and *the sensor data combined with importance sampling is enough* to give accurate results in the low deviation setting. # But in the high deviation setting, the loose nature of the paths in the blue posterior indicate that the robot has not discovered its true position by using importance sampling with the noisy sensor data. In the high deviation setting, more refined inference technique will be required. # -# Let's approach the problem step by step instead of trying to infer the whole path. -# To get started we'll work with the initial point, and then improve it. Once that's done, -# we can chain together such improved moves to hopefully get a better inference of the -# actual path. - -# One thing we'll need is a path to improve. We can select one of the importance samples we generated -# earlier. - - -# %% -def select_by_weight(key: PRNGKey, weights: FloatArray, things): - """Makes a categorical selection from the vector object `things` - weighted by `weights`. The selected object is returned (with its - outermost axis removed) with its weight.""" - chosen = jax.random.categorical(key, weights) - return jax.tree.map(lambda v: v[chosen], things), weights[chosen] - - -# %% [markdown] -# Select an importance sample by weight in both the low and high deviation settings. It will be handy -# to have one path to work with to test our improvements. 
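# %% [markdown]
# To make the weighted-selection step concrete, here is a tiny, self-contained
# sketch. The log-weights below are toy numbers (not drawn from the model); the
# point is that `jax.random.categorical` selects an index directly from
# *unnormalized* log-weights, which is the operation `importance_sample` and
# `select_by_weight` above rely on.

# %%
# Toy illustration only: categorical sampling over unnormalized log-weights is
# equivalent to normalizing them with softmax and sampling that distribution.
toy_log_weights = jnp.array([-1.2, -0.3, -4.0, -0.9])
toy_index = jax.random.categorical(jax.random.PRNGKey(314), toy_log_weights)
toy_probs = jax.nn.softmax(toy_log_weights)  # the implied selection probabilities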
-# %% -key, k1, k2 = jax.random.split(key, 3) -low_deviation_path, _ = select_by_weight(k1, low_weights, low_deviation_paths) -high_deviation_path, _ = select_by_weight(k2, high_weights, high_deviation_paths) - - -# %% [markdown] -# Create a choicemap that will enforce the given sensor observation -# %% -def observation_to_choicemap(observation, pose=None): - sensor_cm = C["sensor", :, "distance"].set(observation) - pose_cm = ( - C["pose", "p"].set(pose.p) + C["pose", "hd"].set(pose.hd) - if pose is not None - else C.n() - ) - return sensor_cm + pose_cm - - -# %% [markdown] -# Let's visualize a cloud of possible poses by coloring the elements proportional to their -# plausibility under the sensor readingss. -# %% -def step_sample(key: PRNGKey, N: int, gf, observation): - tr, ws = jax.vmap(gf.importance, in_axes=(0, None, None))( - jax.random.split(key, N), observation_to_choicemap(observation), () - ) - return tr.get_retval()[0], ws - - -def weighted_small_pose_plot(proposal, truth, weights, poses, zoom=1): - max_logw = jnp.max(weights) - lse_ws = max_logw + jnp.log(jnp.sum(jnp.exp(weights - max_logw))) - scaled_ws = jnp.exp(weights - lse_ws) - max_scaled_w: FloatArray = jnp.max(scaled_ws) - scaled_ws /= max_scaled_w - # the following hack "boosts" lower scores a bit, to give us more visibility into - # the density of the nearby cloud. Aesthetically, I found too many points were - # invisible without some adjustment, since the score distribution is concentrated - # closely around 1.0 - scaled_ws = scaled_ws**0.3 - z = 0.03 * zoom - return Plot.new( - [pose_plot(p, fill=w, zoom=z) for p, w in zip(poses, scaled_ws)] - + pose_plot(proposal, fill="red", zoom=z) - + pose_plot(truth, fill="green", zoom=z) - ) + { - "color": {"type": "linear", "scheme": "OrRd"}, - "height": 400, - "width": 400, - "aspectRatio": 1, - } - - -# %% -key, sub_key = jax.random.split(key) -step_poses, step_scores = step_sample( - sub_key, - 1000, - full_model_kernel( - motion_settings_low_deviation, - robot_inputs["start"], - robot_inputs["controls"][0], - ), - observations_low_deviation[0], -) -# %% -weighted_small_pose_plot( - path_low_deviation[0], robot_inputs["start"], step_scores, step_poses -) - +# Let's approach the problem step by step instead of trying to infer the whole path at once. +# The technique we will use is called Sequential Importance Sampling or a +# [Particle Filter](https://en.wikipedia.org/wiki/Particle_filter). It works like this. +# +# When we designed the step model for the robot, we arranged things so that the model +# could be used with `scan`: the model takes a *state* and a *control input* to produce +# a new *state*. Imagine at some time step $t$ that we use importance sampling with this +# model at a pose $\mathbf{z}_t$ and control input $\mathbf{u}_t$, scored with respect to the +# sensor observations $\mathbf{y}_t$ observed at that time. We will get a weighted collection +# of possible updated poses $\mathbf{z}_t^N$ and weights $w^N$. +# +# The particle filter "winnows" this set by replacing it with $N$ weighted selections +# *with replacement* from this collection. This may select better candidates several +# times, and is likely to drop poor candidates from the collection. We can arrange to +# to this at each time step with a little preparation: we start by "cloning" our idea +# of the robot's initial position into an N vector and this becomes the initial particle +# collection. At each step, we generate an importance sample and winnow it. +# +# This can also be done as a scan. 
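# %% [markdown]
# Before packaging this up, here is a minimal sketch (assumptions: the
# `full_model_kernel`, `C`, and motion-settings objects from earlier cells, and
# a `poses` argument holding N current particles as a vectorized `Pose`) of a
# single propose-and-winnow step. The function name and particle count are
# illustrative only; the reusable, scannable version follows below.

# %%
def winnow_step_sketch(key, poses, control, observation, motion_settings, N=100):
    # Propose: importance-sample a candidate next pose for each current
    # particle, constrained by this step's sensor readings.
    k1, k2 = jax.random.split(key)
    samples, log_weights = jax.vmap(
        lambda k, pose: full_model_kernel.importance(
            k,
            C["sensor", :, "distance"].set(observation),
            (motion_settings, pose, control),
        )
    )(jax.random.split(k1, N), poses)
    # Winnow: resample particle indices categorically by weight, with replacement.
    indices = jax.vmap(genjax.categorical.sampler, in_axes=(0, None))(
        jax.random.split(k2, N), log_weights
    )
    return jax.tree.map(lambda v: v[indices], samples).get_retval()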
Our previous attempt used `scan` to produce candidate
# paths from start to end, and these were scored for importance using all of the sensor
# readings at once. The results were better than guesses, but not accurate, in the
# high deviation setting.
#
# The technique we will use here discards steps with low likelihood at each step, and
# reinforces steps with high likelihood, allowing better particles to proportionately
# search more of the probability space while discarding unpromising particles.
#
# The following class attempts to generalize this idea:


# %%
StateT = TypeVar("StateT")
ControlT = TypeVar("ControlT")


class SequentialImportanceSampling(Generic[StateT, ControlT]):
    """
    Given:
    - a functional wrapper for the importance method of a generative function
    - an initial state of type StateT, which should be a PyTree $z_0$
    - a vector of control inputs, also a PyTree $u_i$, of shape $(T, \ldots)$
    - an array of observations $y_i$, also of shape $(T, \ldots)$
    perform the inference technique known as Sequential Importance Sampling.

    The signature of the GFI importance method is
        key -> constraint -> args -> (trace, weight)
    For importance sampling, this is vmapped over key to get
        [keys] -> constraint -> args -> ([trace], [weight])
    The functional wrapper's purpose is to maneuver the state and control
    inputs into whatever argument shape the underlying model is expecting,
    and to turn the observation at step $t$ into a choicemap asserting
    that constraint.

    After the object is constructed, SIS can be performed at any importance
    depth with the `run` method, which will perform the following steps:

    - inflate the initial value to a vector of size N of identical initial
      values
    - vmap over N keys generated from the supplied key
    - each vmap cell will scan over the control inputs and observations

    Between each step, categorical sampling with replacement is performed to
    create a particle filter. Favorable importance draws are likely to
    be replicated, and unfavorable ones discarded. The resampled vector of
    states is sent to the next step, while the values drawn from the
    importance sample and the indices chosen are emitted from the scan step,
    where, at the end of the process, they will be available as matrices
    of shape (N, T).
+ """ -def grid_sample(gf, pose_grid, observation): - scores, _retvals = jax.vmap( - lambda pose: gf.assess(observation_to_choicemap(observation, pose), ()) - )(pose_grid) - return scores + def __init__( + self, + importance: Callable[ + [PRNGKey, StateT, ControlT, Array], tuple[genjax.Trace[StateT], float] + ], + init: StateT, + controls: ControlT, + observations: Array, + ): + self.importance = jax.jit(importance) + self.init = init + self.controls = controls + self.observations = observations + + class Result(Generic[StateT]): + """This object contains all of the information generated by the SIS scan, + and offers some convenient methods to reconstruct the paths explored + (`flood_fill`) or ultimately chosen (`backtrack`). + """ + def __init__( + self, N: int, end: StateT, samples: genjax.Trace[StateT], indices: IntArray + ): + self.N = N + self.end = end + self.samples = samples + self.indices = indices + + def flood_fill(self) -> list[list[StateT]]: + samples = self.samples.get_retval() + active_paths = [[p] for p in samples[0]] + complete_paths = [] + for i in range(1, len(samples)): + indices = self.indices[i - 1] + counts = jnp.bincount(indices, length=self.N) + new_active_paths = self.N * [None] + for j in range(self.N): + if counts[j] == 0: + complete_paths.append(active_paths[j]) + new_active_paths[j] = active_paths[indices[j]] + [samples[i][j]] + active_paths = new_active_paths + + return complete_paths + active_paths + + def backtrack(self) -> list[list[StateT]]: + paths = [[p] for p in self.end] + samples = self.samples.get_retval() + for i in reversed(range(len(samples))): + for j in range(len(paths)): + paths[j].append(samples[i][self.indices[i][j].item()]) + for p in paths: + p.reverse() + return paths + + def run(self, key: PRNGKey, N: int) -> dict: + def step(state, update): + key, control, observation = update + ks = jax.random.split(key, (2, N)) + sample, log_weights = jax.vmap(self.importance, in_axes=(0, 0, None, None))( + ks[0], state, control, observation + ) + indices = jax.vmap(genjax.categorical.sampler, in_axes=(0, None))( + ks[1], log_weights + ) + resample = jax.tree.map(lambda v: v[indices], sample) + return resample.get_retval(), (sample, indices) -# %% -# Our grid of nearby poses is actually a cube when we take into consideration the -# heading deltas. In order to get a 2d density to visualize, we flatten the cube by -# taking the "best" of the headings by score at each point. (Note: for the inference -# that follows, we will work with the full cube). -def flatten_pose_cube(pose_grid, cube_step_size, scores): - n_indices = 2 * cube_step_size + 1 - best_heading_indices = jnp.argmax( - scores.reshape(n_indices * n_indices, n_indices), axis=1 - ) - # those were block relative; linearize them by adding back block indices - bs = best_heading_indices + jnp.arange(0, n_indices**3, n_indices) - return Pose(pose_grid.p[bs], pose_grid.hd[bs]), scores[bs] + init_array = jax.tree.map( + lambda a: jnp.broadcast_to(a, (N,) + a.shape), self.init + ) + end, (samples, indices) = jax.lax.scan( + step, + init_array, + ( + jax.random.split(key, len(self.controls)), + self.controls, + self.observations, + ), + ) + return SequentialImportanceSampling.Result(N, end, samples, indices) -# %% [markdown] -# Prepare a plot showing the density of nearby improvements available using the grid -# search and importance sampling techniques. # %% -# Test our code for visualizing the Boltzmann and grid searches at the initial pose. 
-def first_step_chart(key): - cube_step_size = 6 - pose_grid = grid_of_nearby_poses( - path_low_deviation[0], cube_step_size, motion_settings_low_deviation - ) - gf = full_model_kernel( - motion_settings_low_deviation, +def localization_sis(motion_settings, observations): + return SequentialImportanceSampling( + lambda key, pose, control, observation: full_model_kernel.importance( + key, + C["sensor", :, "distance"].set(observation), + (motion_settings, pose, control), + ), robot_inputs["start"], - robot_inputs["controls"][0], - ) - score_grid = grid_sample( - gf, - pose_grid, - observations_low_deviation[0], - ) - step_poses, step_scores = step_sample( - key, - 1000, - gf, - observations_low_deviation[0], - ) - pose_plane, score_plane = flatten_pose_cube(pose_grid, cube_step_size, score_grid) - return weighted_small_pose_plot( - path_low_deviation[0], robot_inputs["start"], score_plane, pose_plane - ) & weighted_small_pose_plot( - path_low_deviation[0], robot_inputs["start"], step_scores, step_poses + robot_inputs["controls"], + observations, ) -key, sub_key = jax.random.split(key) -first_step_chart(sub_key) -# %% [markdown] -# Now let's try doing the whole path. We want to produce something that is ultimately -# scan-compatible, so it should have the form state -> update -> new_state. The state -# is obviously the pose; the update will include the sensor readings at the current -# position and the control input for the next step. - -# Step 1. retire assess_model and use full_model_kernel in both bz and grid improvers. -# Step 2. add the [pose,weight] of `pose` to the vector sampled by select_by_weight in the bz case -# Step 3. How is the weight computed for `pose` ? -# what we have now + correction term -# pose.weight = full_model_kernel.assess(p, (cm,)) - - # %% -def improved_path(key: PRNGKey, motion_settings: dict, observations: FloatArray): - cube_step_size = 8 - - def grid_search_step(k: PRNGKey, gf, center_pose, observation): - pose_grid = grid_of_nearby_poses(center_pose, cube_step_size, motion_settings) - nearby_weights = grid_sample(gf, pose_grid, observation) - return nearby_weights, pose_grid - - def improved_step(state, update): - observation, control, key = update - gf = full_model_kernel(motion_settings, state, control) - # Run a sample and pick an element by weight. 
- k1, k2, k3 = jax.random.split(key, 3) - poses, scores = step_sample(k1, 1000, gf, observation) - new_pose, new_weight = select_by_weight(k2, scores, poses) - weights2, poses2 = grid_search_step(k2, gf, new_pose, observation) - # Note that `new_pose` will be among the poses considered by grid_search_step, - # so the possibility exists to remain stationary, as Bayesian inference requires - chosen_pose, _ = select_by_weight(k3, weights2, poses2) - flat_poses, flat_scores = flatten_pose_cube(poses2, cube_step_size, weights2) - return chosen_pose, (new_pose, chosen_pose, flat_scores, flat_poses, new_weight) - - sub_keys = jax.random.split(key, T + 1) - return jax.lax.scan( - improved_step, - robot_inputs["start"], - ( - observations, # observation at time t - robot_inputs["controls"], # guides step from t to t+1 - sub_keys[1:], - ), - ) - - -jit_improved_path = jax.jit(improved_path) -# %% -key, sub_key = jax.random.split(key) -_, improved_low = jit_improved_path( - sub_key, motion_settings_low_deviation, observations_low_deviation -) key, sub_key = jax.random.split(key) -_, improved_high = jit_improved_path( - sub_key, motion_settings_high_deviation, observations_high_deviation -) - +smc_result = localization_sis( + motion_settings_high_deviation, observations_high_deviation +).run(sub_key, 100) -# %% -def path_comparison_plot(*plots): - types = ["improved", "integrated", "importance", "true"] - plot = world_plot - plot += [ - animate_path_as_line(p, strokeWidth=2, stroke=Plot.constantly(t)) - for p, t in zip(plots, types) +( + world_plot + + path_to_polyline(path_high_deviation, stroke="blue", strokeWidth=2) + + [ + path_to_polyline(pose_list_to_plural_pose(p), opacity=0.1, stroke="green") + for p in smc_result.flood_fill() ] - plot += [poses_to_plots(p, fill=Plot.constantly(t)) for p, t in zip(plots, types)] - return plot + Plot.color_map( - { - "integrated": "green", - "improved": "blue", - "true": "black", - "importance": "red", - } - ) - - -# %% -path_comparison_plot( - improved_low[0], path_integrated, low_deviation_path, path_low_deviation ) # %% -path_comparison_plot( - improved_high[0], path_integrated, high_deviation_path, path_high_deviation +# Try it in the low deviation setting +key, sub_key = jax.random.split(key) +low_smc_result = localization_sis( + motion_settings_low_deviation, observations_low_deviation +).run(sub_key, 20) +( + world_plot + + path_to_polyline(path_low_deviation, stroke="blue", strokeWidth=2) + + [ + path_to_polyline(pose_list_to_plural_pose(p), opacity=0.1, stroke="green") + for p in low_smc_result.flood_fill() + ] ) -# %% [markdown] -# To see how the grid search improves poses, we play back the grid-search path -# next to an importance sample path. You can see the grid search has a better fit -# of sensor data to wall position at a variety of time steps. 
-# %% -Plot.Row( - animate_path_and_sensors( - improved_high[0], - observations_high_deviation, - motion_settings_high_deviation, - frame_key="frame", - ), - animate_path_and_sensors( - high_deviation_path, - observations_high_deviation, - motion_settings_high_deviation, - frame_key="frame", - ), -) | Plot.Slider("frame", 0, T, fps=2) - - -# %% -# Finishing touch: weave together the improved plot and the improvement steps -# into a slider animation -# Plot.Frames( -# [weighted_small_pose_plot(improved_high[0][k], path_high_deviation[k], improved_high[2][k], improved_high[1][k]) for k in range(T)], -# ) -# %% -def wsp_frame(k): - return path_comparison_plot( - improved_high[0][: k + 1], - path_integrated[: k + 1], - high_deviation_path[: k + 1], - path_high_deviation[: k + 1], - ) & weighted_small_pose_plot( - improved_high[1][k], - path_high_deviation[k], - improved_high[2][k], - improved_high[3][k], - zoom=4, - ) - - -# %% -Plot.Frames([wsp_frame(k) for k in range(1, 6)]) - -# %% From 8d960a7f609caac96a50267bee3d1df01bf5a156 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Mon, 27 Jan 2025 13:57:15 +0100 Subject: [PATCH 86/86] bump genstudio --- poetry.lock | 15 +++++---------- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/poetry.lock b/poetry.lock index 81e73d6..279a6cb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -718,13 +718,13 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.12.4" +version = "2025.1.9" description = "" optional = false -python-versions = ">=3.10,<3.13" +python-versions = "<3.13,>=3.11" files = [ - {file = "genstudio-2024.12.4-py3-none-any.whl", hash = "sha256:280154e6facb55a73b66b8b6229c847c8c57377faf3022ff73f564d252e52a8f"}, - {file = "genstudio-2024.12.4.tar.gz", hash = "sha256:f3c975e61c068e7c2a6a75be018018df16e4bb38a693eadeddcfff86dcff2da7"}, + {file = "genstudio-2025.1.9-py3-none-any.whl", hash = "sha256:4c96b97e7ede81cc89869c9351394a35115dea33b1c9e0d9d46fdd842a00764a"}, + {file = "genstudio-2025.1.9.tar.gz", hash = "sha256:d0b1e22905952e4dc9b3ed5c404e60147dbc03b5905ad8fbc2d891ca3da83597"}, ] [package.dependencies] @@ -734,11 +734,6 @@ orjson = ">=3.10.6,<4.0.0" pillow = ">=10.4.0,<11.0.0" traitlets = ">=5.14.3,<6.0.0" -[package.source] -type = "legacy" -url = "https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple" -reference = "gcp" - [[package]] name = "html2image" version = "2.0.5" @@ -2716,4 +2711,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "48a726e720b1311bc7f5b1738c8b8576f08d70d07f93c24cd2cae92a2cc85e9f" +content-hash = "e8eb20d0158bf03c5e88d32f576ac06238c8e2acaaf7873160fd54f7a209bc96" diff --git a/pyproject.toml b/pyproject.toml index 930f6c6..35ff186 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ packages = [ python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.8.0", source = "gcp" } -genstudio = {version = "2024.12.004", source = "gcp"} +genstudio = "2025.1.9" ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7"