Skip to content

Commit 18ef2ba

Browse files
- make JSON load from parent directory instead of URL (change this back when the repo goes public)
- add some notes Jay and I took.
1 parent 407db7b commit 18ef2ba

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

genjax-localization-tutorial/probcomp-localization-tutorial.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import jax
4040
import jax.numpy as jnp
4141
import genjax
42-
from urllib.request import urlopen
4342
from genjax import SelectionBuilder as S
4443
from genjax import ChoiceMapBuilder as C
4544
from genjax.typing import Array, FloatArray, PRNGKey, IntArray
@@ -205,19 +204,20 @@ def load_world(file_name):
205204
Returns:
206205
- tuple: A tuple containing the world configuration, the initial state, and the total number of control steps.
207206
"""
208-
with urlopen(
209-
"https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json"
210-
) as url:
211-
data = json.load(url)
212-
213-
walls_vec = jnp.array(data["wall_verts"])
214-
clutters_vec = jnp.array(data["clutter_vert_groups"])
207+
# TODO: change these to urlopen when the repo becomes public
208+
with open("../world.json") as f:
209+
world = json.load(f)
210+
with open("../robot_program.json") as f:
211+
robot_program = json.load(f)
212+
213+
walls_vec = jnp.array(world["wall_verts"])
214+
clutters_vec = jnp.array(world["clutter_vert_groups"])
215215
start = Pose(
216-
jnp.array(data["start_pose"]["p"], dtype=float),
217-
jnp.array(data["start_pose"]["hd"], dtype=float),
216+
jnp.array(robot_program["start_pose"]["p"], dtype=float),
217+
jnp.array(robot_program["start_pose"]["hd"], dtype=float),
218218
)
219219

220-
cs = jnp.array([[c["ds"], c["dhd"]] for c in data["program_controls"]])
220+
cs = jnp.array([[c["ds"], c["dhd"]] for c in robot_program["program_controls"]])
221221
controls = Control(cs[:, 0], cs[:, 1])
222222

223223
return make_world(walls_vec, clutters_vec, start, controls)
@@ -675,8 +675,6 @@ def confidence_circle(pose: Pose, p_noise: float):
675675

676676
# %% [hide-input]
677677

678-
# jax.random.split(jax.random.PRNGKey(3333), N_samples).shape
679-
680678
ps0 = jax.tree.map(lambda v: v[0], pose_samples)
681679
(
682680
ps0.project(jax.random.PRNGKey(2), S[()]),
@@ -1108,9 +1106,6 @@ def animate_full_trace(trace, frame_key=None):
11081106

11091107
def constraint_from_sensors(readings, t: int = T):
11101108
return C["steps", jnp.arange(t + 1), "sensor", :, "distance"].set(readings[: t + 1])
1111-
# return jax.vmap(
1112-
# lambda v: C["steps", :, "sensor", :, "distance"].set(v)
1113-
# )(readings[:t])
11141109

11151110

11161111
constraints_low_deviation = constraint_from_sensors(observations_low_deviation)
@@ -1222,7 +1217,7 @@ def constraint_from_path(path):
12221217
c_hds = jax.vmap(lambda ix, hd: C["steps", ix, "pose", "hd"].set(hd))(
12231218
jnp.arange(T), path.hd
12241219
)
1225-
return c_ps + c_hds # + c_p + c_hd
1220+
return c_ps + c_hds
12261221

12271222

12281223
constraints_path_integrated = constraint_from_path(path_integrated)
@@ -1662,6 +1657,16 @@ def localization_sis(motion_settings, observations):
16621657
for p in smc_result.flood_fill()
16631658
]
16641659
)
1660+
1661+
# jay/colin: the inference is pretty good. We could:
1662+
# - add grid search to refine each step, or
1663+
# - let the robot adjust its next control input to make use of
1664+
# the updated information about its actual pose that the
1665+
# inference over the sensor data has revealed.
1666+
# the point being:
1667+
# We can say "using the inference, watch the robot succeed
1668+
# in entering the room. Without that, the robot's mission
1669+
# was bound to fail in the high deviation scenario."
16651670
# %%
16661671
# Try it in the low deviation setting
16671672
key, sub_key = jax.random.split(key)

0 commit comments

Comments
 (0)