39
39
import jax
40
40
import jax .numpy as jnp
41
41
import genjax
42
- from urllib .request import urlopen
43
42
from genjax import SelectionBuilder as S
44
43
from genjax import ChoiceMapBuilder as C
45
44
from genjax .typing import Array , FloatArray , PRNGKey , IntArray
50
49
import os
51
50
52
51
html = Plot .Hiccup
53
- Plot .configure ({"display_as" : "html" , "dev" : False })
52
+ # Plot.configure({"display_as": "html", "dev": False})
54
53
55
54
# Ensure a location for image generation.
56
55
os .makedirs ("imgs" , exist_ok = True )
@@ -205,19 +204,20 @@ def load_world(file_name):
205
204
Returns:
206
205
- tuple: A tuple containing the world configuration, the initial state, and the total number of control steps.
207
206
"""
208
- with urlopen (
209
- "https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json"
210
- ) as url :
211
- data = json .load (url )
212
-
213
- walls_vec = jnp .array (data ["wall_verts" ])
214
- clutters_vec = jnp .array (data ["clutter_vert_groups" ])
207
+ # TODO: change these to urlopen when the repo becomes public
208
+ with open ("../world.json" ) as f :
209
+ world = json .load (f )
210
+ with open ("../robot_program.json" ) as f :
211
+ robot_program = json .load (f )
212
+
213
+ walls_vec = jnp .array (world ["wall_verts" ])
214
+ clutters_vec = jnp .array (world ["clutter_vert_groups" ])
215
215
start = Pose (
216
- jnp .array (data ["start_pose" ]["p" ], dtype = float ),
217
- jnp .array (data ["start_pose" ]["hd" ], dtype = float ),
216
+ jnp .array (robot_program ["start_pose" ]["p" ], dtype = float ),
217
+ jnp .array (robot_program ["start_pose" ]["hd" ], dtype = float ),
218
218
)
219
219
220
- cs = jnp .array ([[c ["ds" ], c ["dhd" ]] for c in data ["program_controls" ]])
220
+ cs = jnp .array ([[c ["ds" ], c ["dhd" ]] for c in robot_program ["program_controls" ]])
221
221
controls = Control (cs [:, 0 ], cs [:, 1 ])
222
222
223
223
return make_world (walls_vec , clutters_vec , start , controls )
@@ -675,8 +675,6 @@ def confidence_circle(pose: Pose, p_noise: float):
675
675
676
676
# %% [hide-input]
677
677
678
- # jax.random.split(jax.random.PRNGKey(3333), N_samples).shape
679
-
680
678
ps0 = jax .tree .map (lambda v : v [0 ], pose_samples )
681
679
(
682
680
ps0 .project (jax .random .PRNGKey (2 ), S [()]),
@@ -1079,10 +1077,10 @@ def animate_full_trace(trace, frame_key=None):
1079
1077
"p_noise" : 0.05 ,
1080
1078
"hd_noise" : (1 / 10.0 ) * 2 * jnp .pi / 360 ,
1081
1079
}
1082
- key , k_low , k_high = jax . random . split ( key , 3 )
1080
+ motion_settings_high_deviation = { "p_noise" : 0.25 , "hd_noise" : 2 * jnp . pi / 360 }
1083
1081
1082
+ key , k_low , k_high = jax .random .split (key , 3 )
1084
1083
trace_low_deviation = full_model .simulate (k_low , (motion_settings_low_deviation ,))
1085
- motion_settings_high_deviation = {"p_noise" : 0.25 , "hd_noise" : 2 * jnp .pi / 360 }
1086
1084
trace_high_deviation = full_model .simulate (k_high , (motion_settings_high_deviation ,))
1087
1085
1088
1086
animate_full_trace (trace_low_deviation )
@@ -1108,9 +1106,6 @@ def animate_full_trace(trace, frame_key=None):
1108
1106
1109
1107
def constraint_from_sensors (readings , t : int = T ):
1110
1108
return C ["steps" , jnp .arange (t + 1 ), "sensor" , :, "distance" ].set (readings [: t + 1 ])
1111
- # return jax.vmap(
1112
- # lambda v: C["steps", :, "sensor", :, "distance"].set(v)
1113
- # )(readings[:t])
1114
1109
1115
1110
1116
1111
constraints_low_deviation = constraint_from_sensors (observations_low_deviation )
@@ -1222,7 +1217,7 @@ def constraint_from_path(path):
1222
1217
c_hds = jax .vmap (lambda ix , hd : C ["steps" , ix , "pose" , "hd" ].set (hd ))(
1223
1218
jnp .arange (T ), path .hd
1224
1219
)
1225
- return c_ps + c_hds # + c_p + c_hd
1220
+ return c_ps + c_hds
1226
1221
1227
1222
1228
1223
constraints_path_integrated = constraint_from_path (path_integrated )
@@ -1662,6 +1657,16 @@ def localization_sis(motion_settings, observations):
1662
1657
for p in smc_result .flood_fill ()
1663
1658
]
1664
1659
)
1660
+
1661
+ # jay/colin: the inference is pretty good. We could:
1662
+ # - add grid search to refine each step, or
1663
+ # - let the robot adjust its next control input to make use of
1664
+ # the updated information about its actual pose that the
1665
+ # inference over the sensor data has revealed.
1666
+ # the point being:
1667
+ # We can say "using the inference, watch the robot succeed
1668
+ # in entering the room. Without that, the robot's mission
1669
+ # was bound to fail in the high deviation scenario."
1665
1670
# %%
1666
1671
# Try it in the low deviation setting
1667
1672
key , sub_key = jax .random .split (key )
@@ -1676,3 +1681,16 @@ def localization_sis(motion_settings, observations):
1676
1681
for p in low_smc_result .flood_fill ()
1677
1682
]
1678
1683
)
1684
+
1685
+ # %%
1686
+ # demo: recycle traces
1687
+ key , k_low , k_high = jax .random .split (key , 3 )
1688
+ trace_low_deviation = full_model .simulate (k_low , (motion_settings_low_deviation ,))
1689
+ trace_high_deviation = full_model .simulate (k_high , (motion_settings_high_deviation ,))
1690
+ path_low_deviation = get_path (trace_low_deviation )
1691
+ path_high_deviation = get_path (trace_high_deviation )
1692
+ # ...using these data.
1693
+ observations_low_deviation = get_sensors (trace_low_deviation )
1694
+ observations_high_deviation = get_sensors (trace_high_deviation )
1695
+
1696
+ # %%
0 commit comments