Skip to content

Commit 8fdd3c6

Browse files
authored
robot viz in genjax (GEN-884) (#20)
1 parent 4762e31 commit 8fdd3c6

File tree

6 files changed

+775
-333
lines changed

6 files changed

+775
-333
lines changed

poetry.lock

Lines changed: 19 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ packages = [
1212
[tool.poetry.dependencies]
1313
python = ">=3.11,<3.13"
1414
jupytext = "^1.16.1"
15-
genjax = {version = "0.7.0.post4.dev0+eacb241e", source = "gcp" }
16-
# genstudio = {version = "2024.12.003", source = "gcp"}
17-
genstudio = {path = "../genstudio", develop = true}
15+
genjax = {version = "0.8.0", source = "gcp" }
16+
genstudio = {version = "2024.12.004", source = "gcp"}
1817
ipykernel = "^6.29.3"
1918
matplotlib = "^3.8.3"
2019
anywidget = "^0.9.7"

robot_2/bench.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
# %%
2+
import jax.numpy as jnp
3+
import time
4+
from robot_2.where_am_i import (
5+
RobotCapabilities,
6+
simulate_robot,
7+
World,
8+
walls_to_jax,
9+
PRNGKey,
10+
)
11+
import jax.random
12+
13+
# Benchmark State:
14+
# Walls:
15+
walls = [
16+
[0, 0, 0],
17+
[10, 0, 0],
18+
[10, 0, 0],
19+
[10, 10, 0],
20+
[10, 10, 0],
21+
[0, 10, 0],
22+
[0, 10, 0],
23+
[0, 0, 0],
24+
[8.125, 9.03125, 1974574552],
25+
[7.875, 9.0625, 1974574552],
26+
[7.53125, 9.0625, 1974574552],
27+
[7.09375, 9.09375, 1974574552],
28+
[6.8125, 9.09375, 1974574552],
29+
[6.5, 9.125, 1974574552],
30+
[6.1875, 9.125, 1974574552],
31+
[5.71875, 9.15625, 1974574552],
32+
[5.40625, 9.15625, 1974574552],
33+
[5.09375, 9.15625, 1974574552],
34+
[4.75, 9.15625, 1974574552],
35+
[4.4375, 9.125, 1974574552],
36+
[4, 9.09375, 1974574552],
37+
[3.71875, 9.0625, 1974574552],
38+
[3.46875, 9.03125, 1974574552],
39+
[3.21875, 9, 1974574552],
40+
[2.96875, 8.9375, 1974574552],
41+
[2.5625, 8.84375, 1974574552],
42+
[2.09375, 8.75, 1974574552],
43+
[1.71875, 8.625, 1974574552],
44+
[1.75, 8.34375, 1974574552],
45+
[2.03125, 8.25, 1974574552],
46+
[2.34375, 8.1875, 1974574552],
47+
[2.71875, 8.15625, 1974574552],
48+
[3.0625, 8.125, 1974574552],
49+
[3.40625, 8.09375, 1974574552],
50+
[3.96875, 8.0625, 1974574552],
51+
[4.3125, 8.03125, 1974574552],
52+
[4.625, 8.03125, 1974574552],
53+
[4.90625, 8, 1974574552],
54+
[5.3125, 8, 1974574552],
55+
[5.59375, 8, 1974574552],
56+
[5.875, 7.9375, 1974574552],
57+
[6.15625, 7.875, 1974574552],
58+
[6.4375, 7.75, 1974574552],
59+
[6.6875, 7.59375, 1974574552],
60+
[6.8125, 7.375, 1974574552],
61+
[6.6875, 7.15625, 1974574552],
62+
[6.40625, 7, 1974574552],
63+
[5.78125, 6.875, 1974574552],
64+
[5.4375, 6.8125, 1974574552],
65+
[5.09375, 6.75, 1974574552],
66+
[4.34375, 6.625, 1974574552],
67+
[3.90625, 6.5625, 1974574552],
68+
[3.65625, 6.5, 1974574552],
69+
[3.1875, 6.4375, 1974574552],
70+
[2.625, 6.28125, 1974574552],
71+
[2.65625, 5.96875, 1974574552],
72+
[2.9375, 5.84375, 1974574552],
73+
[3.25, 5.78125, 1974574552],
74+
[3.53125, 5.75, 1974574552],
75+
[4.03125, 5.71875, 1974574552],
76+
[4.46875, 5.6875, 1974574552],
77+
[5.21875, 5.6875, 1974574552],
78+
[5.75, 5.6875, 1974574552],
79+
[6.0625, 5.6875, 1974574552],
80+
[6.3125, 5.6875, 1974574552],
81+
[6.5625, 5.65625, 1974574552],
82+
[6.59375, 5.3125, 1974574552],
83+
[6.375, 5.1875, 1974574552],
84+
[6, 5.0625, 1974574552],
85+
[5.3125, 4.9375, 1974574552],
86+
[4.9375, 4.875, 1974574552],
87+
[4.1875, 4.8125, 1974574552],
88+
[3.875, 4.75, 1974574552],
89+
[3.46875, 4.71875, 1974574552],
90+
[3.0625, 4.65625, 1974574552],
91+
[2.75, 4.59375, 1974574552],
92+
[2.46875, 4.5, 1974574552],
93+
[2.78125, 4.1875, 1974574552],
94+
[3.15625, 4.125, 1974574552],
95+
[3.53125, 4.0625, 1974574552],
96+
[4, 4, 1974574552],
97+
[4.46875, 3.9375000000000004, 1974574552],
98+
[5.25, 3.8749999999999996, 1974574552],
99+
[5.75, 3.84375, 1974574552],
100+
[6, 3.84375, 1974574552],
101+
[6.28125, 3.84375, 1974574552],
102+
[6.75, 3.84375, 1974574552],
103+
[6.5, 3.53125, 1974574552],
104+
[6.25, 3.5, 1974574552],
105+
[5.875, 3.4375, 1974574552],
106+
[5.59375, 3.4062499999999996, 1974574552],
107+
[5.0625, 3.375, 1974574552],
108+
[4.625, 3.375, 1974574552],
109+
[3.9375, 3.34375, 1974574552],
110+
[2.875, 3.3125000000000004, 1974574552],
111+
[1.75, 3.1562500000000004, 1974574552],
112+
[1.5, 3.0625, 1974574552],
113+
[1.53125, 2.7812499999999996, 1974574552],
114+
[1.8125, 2.65625, 1974574552],
115+
[2.0625, 2.6249999999999996, 1974574552],
116+
[2.625, 2.5, 1974574552],
117+
[2.875, 2.4687499999999996, 1974574552],
118+
[3.375, 2.40625, 1974574552],
119+
[3.9375, 2.3750000000000004, 1974574552],
120+
[4.8125, 2.34375, 1974574552],
121+
[5.34375, 2.34375, 1974574552],
122+
[5.875, 2.34375, 1974574552],
123+
[6.34375, 2.34375, 1974574552],
124+
[6.59375, 2.34375, 1974574552],
125+
[7.15625, 2.34375, 1974574552],
126+
[7.4375, 2.28125, 1974574552],
127+
[7.1875, 2.03125, 1974574552],
128+
[6.75, 1.9374999999999998, 1974574552],
129+
[6.4375, 1.875, 1974574552],
130+
[5.59375, 1.7812499999999998, 1974574552],
131+
[5.125, 1.7812499999999998, 1974574552],
132+
[4.625, 1.7812499999999998, 1974574552],
133+
[3.96875, 1.71875, 1974574552],
134+
[3.4375, 1.6562500000000002, 1974574552],
135+
[3.125, 1.5937500000000004, 1974574552],
136+
[2.71875, 1.4687499999999998, 1974574552],
137+
[2.40625, 1.3437500000000002, 1974574552],
138+
[2.0625, 1.2187499999999996, 1974574552],
139+
[2.5, 1.0312500000000002, 1974574552],
140+
[2.84375, 1.0312500000000002, 1974574552],
141+
[3.375, 0.9999999999999998, 1974574552],
142+
[3.96875, 0.9687500000000004, 1974574552],
143+
[4.6875, 0.9375, 1974574552],
144+
[5.21875, 0.9375, 1974574552],
145+
[5.84375, 0.9687500000000004, 1974574552],
146+
[6.4375, 1.0312500000000002, 1974574552],
147+
[7, 1.1562499999999998, 1974574552],
148+
[7.34375, 1.2187499999999996, 1974574552],
149+
[7.40625, 1.25, 1974574552],
150+
]
151+
152+
# Robot Path:
153+
robot_path = [
154+
[2.5, 0.2500000000000002, 1097479840],
155+
[3.03125, 0.2500000000000002, 1097479840],
156+
[3.625, 0.2500000000000002, 1097479840],
157+
[4.125, 0.2500000000000002, 1097479840],
158+
[4.625, 0.2500000000000002, 1097479840],
159+
[5.375, 0.21874999999999978, 1097479840],
160+
[5.90625, 0.18750000000000044, 1097479840],
161+
[6.4375, 0.18750000000000044, 1097479840],
162+
[6.96875, 0.18750000000000044, 1097479840],
163+
[7.46875, 0.28124999999999956, 1097479840],
164+
[7.9375, 0.46875, 1097479840],
165+
[8, 0.9999999999999998, 1097479840],
166+
[7.65625, 1.40625, 1097479840],
167+
[7.15625, 1.5625, 1097479840],
168+
[6.5625, 1.5625, 1097479840],
169+
[6, 1.5312499999999996, 1097479840],
170+
[5.5, 1.5000000000000002, 1097479840],
171+
[4.9375, 1.40625, 1097479840],
172+
[4.40625, 1.3749999999999996, 1097479840],
173+
[4.9375, 1.3437500000000002, 1097479840],
174+
[5.46875, 1.4375000000000004, 1097479840],
175+
[6.0625, 1.4375000000000004, 1097479840],
176+
[6.625, 1.4687499999999998, 1097479840],
177+
[7.1875, 1.5000000000000002, 1097479840],
178+
[7.65625, 1.71875, 1097479840],
179+
[7.8125, 2.2187500000000004, 1097479840],
180+
[7.71875, 2.71875, 1097479840],
181+
[7.25, 2.90625, 1097479840],
182+
[6.6875, 3.03125, 1097479840],
183+
[6.03125, 3.125, 1097479840],
184+
[5.4375, 3.125, 1097479840],
185+
[4.875, 3.0625, 1097479840],
186+
[4.28125, 3.03125, 1097479840],
187+
[3.71875, 3.0000000000000004, 1097479840],
188+
[4.34375, 2.96875, 1097479840],
189+
[4.96875, 2.96875, 1097479840],
190+
[5.65625, 3.0000000000000004, 1097479840],
191+
[6.25, 3.0000000000000004, 1097479840],
192+
[7.09375, 3.0000000000000004, 1097479840],
193+
[7.75, 3.03125, 1097479840],
194+
[8.25, 3.125, 1097479840],
195+
[8.1875, 3.6875, 1097479840],
196+
[7.59375, 4, 1097479840],
197+
[6.96875, 4.21875, 1097479840],
198+
[6.375, 4.375, 1097479840],
199+
[5.875, 4.4375, 1097479840],
200+
[5.375, 4.5, 1097479840],
201+
[6, 4.5625, 1097479840],
202+
[6.53125, 4.6875, 1097479840],
203+
[7.0625, 4.78125, 1097479840],
204+
[7.4375, 5.1875, 1097479840],
205+
[7.1875, 5.6875, 1097479840],
206+
[6.75, 6.03125, 1097479840],
207+
[6.125, 6.34375, 1097479840],
208+
[5.59375, 6.46875, 1097479840],
209+
[5.0625, 6.46875, 1097479840],
210+
[5.59375, 6.46875, 1097479840],
211+
[6.78125, 6.5625, 1097479840],
212+
[7.3125, 6.65625, 1097479840],
213+
[7.9375, 6.90625, 1097479840],
214+
[8.15625, 7.40625, 1097479840],
215+
[7.6875, 7.71875, 1097479840],
216+
[7.09375, 7.90625, 1097479840],
217+
[6.46875, 8.0625, 1097479840],
218+
[5.96875, 8.15625, 1097479840],
219+
[5.375, 8.1875, 1097479840],
220+
[4.6875, 8.3125, 1097479840],
221+
[4.1875, 8.34375, 1097479840],
222+
[4.6875, 8.4375, 1097479840],
223+
[5.46875, 8.46875, 1097479840],
224+
[5.96875, 8.46875, 1097479840],
225+
[6.53125, 8.46875, 1097479840],
226+
[7.09375, 8.46875, 1097479840],
227+
[7.59375, 8.53125, 1097479840],
228+
[8.15625, 8.625, 1097479840],
229+
[8.65625, 8.8125, 1097479840],
230+
[8.28125, 9.34375, 1097479840],
231+
[7.75, 9.4375, 1097479840],
232+
[7.15625, 9.5, 1097479840],
233+
[6.625, 9.5, 1097479840],
234+
[6, 9.5, 1097479840],
235+
[5.375, 9.53125, 1097479840],
236+
[4.8125, 9.53125, 1097479840],
237+
[4.1875, 9.53125, 1097479840],
238+
[3.5625, 9.4375, 1097479840],
239+
[3, 9.4375, 1097479840],
240+
[2.46875, 9.40625, 1097479840],
241+
[1.84375, 9.34375, 1097479840],
242+
[1.4375, 9.34375, 1097479840],
243+
]
244+
245+
246+
def perturb_walls(w, idx, amount=0.1):
247+
# Add small random offset to one wall endpoint
248+
w[idx][1] = w[idx][1] + amount
249+
return w
250+
251+
252+
def get_robot(p_noise):
253+
return RobotCapabilities(
254+
p_noise=jnp.array(p_noise),
255+
hd_noise=jnp.array(0.03),
256+
sensor_noise=jnp.array(0.1),
257+
n_sensors=jnp.array(8),
258+
sensor_range=jnp.array(10.0),
259+
)
260+
261+
262+
# Test 2: Perturb robot capabilities
263+
264+
265+
# Create random keys
266+
keys = jax.random.split(PRNGKey(0), 100)
267+
268+
# Test 1: Run simulation with perturbed walls
269+
print("Test 1: Perturbing walls")
270+
start = time.time()
271+
all_paths_1, all_readings_1 = jax.vmap(
272+
lambda k: simulate_robot(
273+
World(*walls_to_jax(walls)),
274+
get_robot(0.1),
275+
jnp.array(robot_path),
276+
k,
277+
)
278+
)(keys)
279+
print(f"Wall perturbation test took {time.time() - start:.3f} seconds")
280+
281+
# Test 2: Run simulation with perturbed robot capabilities
282+
print("\nTest 2: Perturbing robot capabilities")
283+
start = time.time()
284+
p_noises = jnp.linspace(0.05, 0.15, 100) # Range of p_noise values
285+
all_paths_2, all_readings_2 = jax.vmap(
286+
lambda k, p: simulate_robot(
287+
World(*walls_to_jax(walls)), get_robot(p), jnp.array(robot_path), k
288+
)
289+
)(keys, p_noises)
290+
print(f"Robot capability perturbation test took {time.time() - start:.3f} seconds")

0 commit comments

Comments
 (0)