|
39 | 39 | import jax
|
40 | 40 | import jax.numpy as jnp
|
41 | 41 | import genjax
|
42 |
| -from urllib.request import urlopen |
43 | 42 | from genjax import SelectionBuilder as S
|
44 | 43 | from genjax import ChoiceMapBuilder as C
|
45 | 44 | from genjax.typing import Array, FloatArray, PRNGKey, IntArray
|
@@ -205,19 +204,20 @@ def load_world(file_name):
|
205 | 204 | Returns:
|
206 | 205 | - tuple: A tuple containing the world configuration, the initial state, and the total number of control steps.
|
207 | 206 | """
|
208 |
| - with urlopen( |
209 |
| - "https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json" |
210 |
| - ) as url: |
211 |
| - data = json.load(url) |
212 |
| - |
213 |
| - walls_vec = jnp.array(data["wall_verts"]) |
214 |
| - clutters_vec = jnp.array(data["clutter_vert_groups"]) |
| 207 | + # TODO: change these to urlopen when the repo becomes public |
| 208 | + with open("../world.json") as f: |
| 209 | + world = json.load(f) |
| 210 | + with open("../robot_program.json") as f: |
| 211 | + robot_program = json.load(f) |
| 212 | + |
| 213 | + walls_vec = jnp.array(world["wall_verts"]) |
| 214 | + clutters_vec = jnp.array(world["clutter_vert_groups"]) |
215 | 215 | start = Pose(
|
216 |
| - jnp.array(data["start_pose"]["p"], dtype=float), |
217 |
| - jnp.array(data["start_pose"]["hd"], dtype=float), |
| 216 | + jnp.array(robot_program["start_pose"]["p"], dtype=float), |
| 217 | + jnp.array(robot_program["start_pose"]["hd"], dtype=float), |
218 | 218 | )
|
219 | 219 |
|
220 |
| - cs = jnp.array([[c["ds"], c["dhd"]] for c in data["program_controls"]]) |
| 220 | + cs = jnp.array([[c["ds"], c["dhd"]] for c in robot_program["program_controls"]]) |
221 | 221 | controls = Control(cs[:, 0], cs[:, 1])
|
222 | 222 |
|
223 | 223 | return make_world(walls_vec, clutters_vec, start, controls)
|
@@ -675,8 +675,6 @@ def confidence_circle(pose: Pose, p_noise: float):
|
675 | 675 |
|
676 | 676 | # %% [hide-input]
|
677 | 677 |
|
678 |
| -# jax.random.split(jax.random.PRNGKey(3333), N_samples).shape |
679 |
| - |
680 | 678 | ps0 = jax.tree.map(lambda v: v[0], pose_samples)
|
681 | 679 | (
|
682 | 680 | ps0.project(jax.random.PRNGKey(2), S[()]),
|
@@ -1108,9 +1106,6 @@ def animate_full_trace(trace, frame_key=None):
|
1108 | 1106 |
|
1109 | 1107 | def constraint_from_sensors(readings, t: int = T):
|
1110 | 1108 | return C["steps", jnp.arange(t + 1), "sensor", :, "distance"].set(readings[: t + 1])
|
1111 |
| - # return jax.vmap( |
1112 |
| - # lambda v: C["steps", :, "sensor", :, "distance"].set(v) |
1113 |
| - # )(readings[:t]) |
1114 | 1109 |
|
1115 | 1110 |
|
1116 | 1111 | constraints_low_deviation = constraint_from_sensors(observations_low_deviation)
|
@@ -1222,7 +1217,7 @@ def constraint_from_path(path):
|
1222 | 1217 | c_hds = jax.vmap(lambda ix, hd: C["steps", ix, "pose", "hd"].set(hd))(
|
1223 | 1218 | jnp.arange(T), path.hd
|
1224 | 1219 | )
|
1225 |
| - return c_ps + c_hds # + c_p + c_hd |
| 1220 | + return c_ps + c_hds |
1226 | 1221 |
|
1227 | 1222 |
|
1228 | 1223 | constraints_path_integrated = constraint_from_path(path_integrated)
|
@@ -1662,6 +1657,16 @@ def localization_sis(motion_settings, observations):
|
1662 | 1657 | for p in smc_result.flood_fill()
|
1663 | 1658 | ]
|
1664 | 1659 | )
|
| 1660 | + |
| 1661 | +# jay/colin: the inference is pretty good. We could: |
| 1662 | +# - add grid search to refine each step, or |
| 1663 | +# - let the robot adjust its next control input to make use of |
| 1664 | +# the updated information about its actual pose that the |
| 1665 | +# inference over the sensor data has revealed. |
| 1666 | +# The point being: |
| 1667 | +# with the inference, we can watch the robot succeed |
| 1668 | +# in entering the room; without it, the robot's mission |
| 1669 | +# was bound to fail in the high-deviation scenario. |
1665 | 1670 | # %%
|
1666 | 1671 | # Try it in the low deviation setting
|
1667 | 1672 | key, sub_key = jax.random.split(key)
|
|
0 commit comments