
Commit a60fd0e

Merge pull request #23 from chi-collective/colin/loc-smc
SMC inference in genjax for localization
2 parents a7fc36f + 18ef2ba commit a60fd0e

File tree

3 files changed: +635 −1517 lines


genjax-localization-tutorial/probcomp-localization-tutorial.py

Lines changed: 38 additions & 20 deletions
@@ -39,7 +39,6 @@
 import jax
 import jax.numpy as jnp
 import genjax
-from urllib.request import urlopen
 from genjax import SelectionBuilder as S
 from genjax import ChoiceMapBuilder as C
 from genjax.typing import Array, FloatArray, PRNGKey, IntArray
@@ -50,7 +49,7 @@
 import os
 
 html = Plot.Hiccup
-Plot.configure({"display_as": "html", "dev": False})
+# Plot.configure({"display_as": "html", "dev": False})
 
 # Ensure a location for image generation.
 os.makedirs("imgs", exist_ok=True)
@@ -205,19 +204,20 @@ def load_world(file_name):
     Returns:
     - tuple: A tuple containing the world configuration, the initial state, and the total number of control steps.
     """
-    with urlopen(
-        "https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json"
-    ) as url:
-        data = json.load(url)
-
-    walls_vec = jnp.array(data["wall_verts"])
-    clutters_vec = jnp.array(data["clutter_vert_groups"])
+    # TODO: change these to urlopen when the repo becomes public
+    with open("../world.json") as f:
+        world = json.load(f)
+    with open("../robot_program.json") as f:
+        robot_program = json.load(f)
+
+    walls_vec = jnp.array(world["wall_verts"])
+    clutters_vec = jnp.array(world["clutter_vert_groups"])
     start = Pose(
-        jnp.array(data["start_pose"]["p"], dtype=float),
-        jnp.array(data["start_pose"]["hd"], dtype=float),
+        jnp.array(robot_program["start_pose"]["p"], dtype=float),
+        jnp.array(robot_program["start_pose"]["hd"], dtype=float),
     )
 
-    cs = jnp.array([[c["ds"], c["dhd"]] for c in data["program_controls"]])
+    cs = jnp.array([[c["ds"], c["dhd"]] for c in robot_program["program_controls"]])
     controls = Control(cs[:, 0], cs[:, 1])
 
     return make_world(walls_vec, clutters_vec, start, controls)
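
This hunk swaps the remote fetch for local files and leaves a TODO to restore urlopen once the repository is public. A minimal sketch of what that restoration might look like; the URLs are placeholders, not published locations from this commit:

# Sketch only; the URLs below are placeholders until the repo is public.
import json
from urllib.request import urlopen

def load_remote_json(url):
    # Fetch a JSON resource over HTTP and decode it.
    with urlopen(url) as response:
        return json.load(response)

# world = load_remote_json("https://raw.githubusercontent.com/<org>/<repo>/main/world.json")
# robot_program = load_remote_json("https://raw.githubusercontent.com/<org>/<repo>/main/robot_program.json")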
@@ -675,8 +675,6 @@ def confidence_circle(pose: Pose, p_noise: float):
 
 # %% [hide-input]
 
-# jax.random.split(jax.random.PRNGKey(3333), N_samples).shape
-
 ps0 = jax.tree.map(lambda v: v[0], pose_samples)
 (
     ps0.project(jax.random.PRNGKey(2), S[()]),
@@ -1079,10 +1077,10 @@ def animate_full_trace(trace, frame_key=None):
     "p_noise": 0.05,
     "hd_noise": (1 / 10.0) * 2 * jnp.pi / 360,
 }
-key, k_low, k_high = jax.random.split(key, 3)
+motion_settings_high_deviation = {"p_noise": 0.25, "hd_noise": 2 * jnp.pi / 360}
 
+key, k_low, k_high = jax.random.split(key, 3)
 trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,))
-motion_settings_high_deviation = {"p_noise": 0.25, "hd_noise": 2 * jnp.pi / 360}
 trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,))
 
 animate_full_trace(trace_low_deviation)
@@ -1108,9 +1106,6 @@ def animate_full_trace(trace, frame_key=None):
 
 def constraint_from_sensors(readings, t: int = T):
     return C["steps", jnp.arange(t + 1), "sensor", :, "distance"].set(readings[: t + 1])
-    # return jax.vmap(
-    #     lambda v: C["steps", :, "sensor", :, "distance"].set(v)
-    # )(readings[:t])
 
 
 constraints_low_deviation = constraint_from_sensors(observations_low_deviation)
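
The choice map built here constrains every sensor distance address up to step t. A rough usage sketch, not part of the commit, assuming GenJAX's importance interface takes (key, constraints, args) and returns a (trace, log-weight) pair as in the version this tutorial targets:

# Sketch only: condition the full model on observed sensor readings.
key, sub_key = jax.random.split(key)
trace_conditioned, log_weight = full_model.importance(
    sub_key,
    constraints_low_deviation,            # choice map built above
    (motion_settings_low_deviation,),     # same args tuple as `simulate`
)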
@@ -1222,7 +1217,7 @@ def constraint_from_path(path):
     c_hds = jax.vmap(lambda ix, hd: C["steps", ix, "pose", "hd"].set(hd))(
         jnp.arange(T), path.hd
     )
-    return c_ps + c_hds  # + c_p + c_hd
+    return c_ps + c_hds
 
 
 constraints_path_integrated = constraint_from_path(path_integrated)
@@ -1662,6 +1657,16 @@ def localization_sis(motion_settings, observations):
         for p in smc_result.flood_fill()
     ]
 )
+
+# jay/colin: the inference is pretty good. We could:
+# - add grid search to refine each step, or
+# - let the robot adjust its next control input to make use of
+#   the updated information about its actual pose that the
+#   inference over the sensor data has revealed.
+# the point being:
+# We can say "using the inference, watch the robot succeed
+#   in entering the room. Without that, the robot's mission
+#   was bound to fail in the high deviation scenario."
 # %%
 # Try it in the low deviation setting
 key, sub_key = jax.random.split(key)
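
The "grid search to refine each step" idea from the note above could look roughly like the sketch below; it is illustrative only and not part of the commit. `score_fn` is a hypothetical callable mapping a candidate Pose to a log score (for example, the sensor likelihood), and the grid widths are arbitrary:

# Illustrative only: local grid refinement around a pose estimate.
def grid_refine(pose, score_fn, half_width=0.2, hd_half_width=0.05, n=5):
    # Offsets in x, y, and heading around the current estimate.
    dxs, dys, dhds = jnp.meshgrid(
        jnp.linspace(-half_width, half_width, n),
        jnp.linspace(-half_width, half_width, n),
        jnp.linspace(-hd_half_width, hd_half_width, n),
    )
    # Batched Pose holding n**3 candidates centred on `pose`.
    candidates = Pose(
        pose.p + jnp.stack([dxs.ravel(), dys.ravel()], axis=-1),
        pose.hd + dhds.ravel(),
    )
    scores = jax.vmap(score_fn)(candidates)
    best = jnp.argmax(scores)
    # Return the single highest-scoring candidate pose.
    return jax.tree.map(lambda v: v[best], candidates)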
@@ -1676,3 +1681,16 @@ def localization_sis(motion_settings, observations):
         for p in low_smc_result.flood_fill()
     ]
 )
+
+# %%
+# demo: recycle traces
+key, k_low, k_high = jax.random.split(key, 3)
+trace_low_deviation = full_model.simulate(k_low, (motion_settings_low_deviation,))
+trace_high_deviation = full_model.simulate(k_high, (motion_settings_high_deviation,))
+path_low_deviation = get_path(trace_low_deviation)
+path_high_deviation = get_path(trace_high_deviation)
+# ...using these data.
+observations_low_deviation = get_sensors(trace_low_deviation)
+observations_high_deviation = get_sensors(trace_high_deviation)
+
+# %%
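
The recycled traces provide fresh ground-truth paths to compare inference against. A small hypothetical helper for that comparison, not in the commit; `inferred_path` stands for whatever batched Pose estimate an SMC run yields:

# Hypothetical helper, not part of the commit.
def mean_position_error(inferred_path, true_path):
    # Average Euclidean distance between corresponding positions over time.
    return jnp.mean(jnp.linalg.norm(inferred_path.p - true_path.p, axis=-1))

# e.g. mean_position_error(inferred_path, path_high_deviation)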
