7
7
from genjax .incremental import Diff , NoChange , UnknownChange
8
8
9
9
import bayes3d as b
10
- import bayes3d .scene_graph
11
10
12
11
from .genjax_distributions import (
13
12
contact_params_uniform ,
@@ -128,14 +127,14 @@ def get_far_plane(trace):
128
127
129
128
130
129
def add_object (trace , key , obj_id , parent , face_parent , face_child ):
131
- N = get_indices (trace ).shape [0 ] + 1
130
+ N = b . get_indices (trace ).shape [0 ] + 1
132
131
choices = trace .get_choices ()
133
132
choices [f"parent_{ N - 1 } " ] = parent
134
133
choices [f"id_{ N - 1 } " ] = obj_id
135
134
choices [f"face_parent_{ N - 1 } " ] = face_parent
136
135
choices [f"face_child_{ N - 1 } " ] = face_child
137
136
choices [f"contact_params_{ N - 1 } " ] = jnp .zeros (3 )
138
- return model .importance (key , choices , (jnp .arange (N ), * trace .get_args ()[1 :]))[0 ]
137
+ return model .importance (key , choices , (jnp .arange (N ), * trace .get_args ()[1 :]))[1 ]
139
138
140
139
141
140
add_object_jit = jax .jit (add_object )
@@ -152,7 +151,7 @@ def print_trace(trace):
152
151
153
152
154
153
def viz_trace_meshcat (trace , colors = None ):
155
- b .clear_visualizer ()
154
+ b .clear ()
156
155
b .show_cloud (
157
156
"1" , b .apply_transform_jit (trace ["image" ].reshape (- 1 , 3 ), trace ["camera_pose" ])
158
157
)
@@ -224,14 +223,14 @@ def enumerator(trace, key, *args):
224
223
key ,
225
224
chm_builder (addresses , args , chm_args ),
226
225
argdiff_f (trace ),
227
- )[0 ]
226
+ )[2 ]
228
227
229
228
def enumerator_with_weight (trace , key , * args ):
230
229
return trace .update (
231
230
key ,
232
231
chm_builder (addresses , args , chm_args ),
233
232
argdiff_f (trace ),
234
- )[0 : 2 ]
233
+ )[1 : 3 ]
235
234
236
235
def enumerator_score (trace , key , * args ):
237
236
return enumerator (trace , key , * args ).get_score ()
@@ -302,4 +301,4 @@ def update_address(trace, key, address, value):
302
301
key ,
303
302
genjax .choice_map ({address : value }),
304
303
tuple (map (lambda v : Diff (v , UnknownChange ), trace .args )),
305
- )[0 ]
304
+ )[2 ]
0 commit comments