Skip to content

Commit 37a0dc5

Browse files
mhuebertsritchie
authored andcommitted
formatting
1 parent 82474b6 commit 37a0dc5

File tree

4 files changed

+106
-624
lines changed

4 files changed

+106
-624
lines changed

bayes3d/genjax/model.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from genjax.incremental import Diff, NoChange, UnknownChange
88

99
import bayes3d as b
10-
import bayes3d.scene_graph
1110

1211
from .genjax_distributions import (
1312
contact_params_uniform,
@@ -128,14 +127,14 @@ def get_far_plane(trace):
128127

129128

130129
def add_object(trace, key, obj_id, parent, face_parent, face_child):
131-
N = get_indices(trace).shape[0] + 1
130+
N = b.get_indices(trace).shape[0] + 1
132131
choices = trace.get_choices()
133132
choices[f"parent_{N-1}"] = parent
134133
choices[f"id_{N-1}"] = obj_id
135134
choices[f"face_parent_{N-1}"] = face_parent
136135
choices[f"face_child_{N-1}"] = face_child
137136
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
138-
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[0]
137+
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]
139138

140139

141140
add_object_jit = jax.jit(add_object)
@@ -152,7 +151,7 @@ def print_trace(trace):
152151

153152

154153
def viz_trace_meshcat(trace, colors=None):
155-
b.clear_visualizer()
154+
b.clear()
156155
b.show_cloud(
157156
"1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"])
158157
)
@@ -224,14 +223,14 @@ def enumerator(trace, key, *args):
224223
key,
225224
chm_builder(addresses, args, chm_args),
226225
argdiff_f(trace),
227-
)[0]
226+
)[2]
228227

229228
def enumerator_with_weight(trace, key, *args):
230229
return trace.update(
231230
key,
232231
chm_builder(addresses, args, chm_args),
233232
argdiff_f(trace),
234-
)[0:2]
233+
)[1:3]
235234

236235
def enumerator_score(trace, key, *args):
237236
return enumerator(trace, key, *args).get_score()
@@ -302,4 +301,4 @@ def update_address(trace, key, address, value):
302301
key,
303302
genjax.choice_map({address: value}),
304303
tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),
305-
)[0]
304+
)[2]

bayes3d/viser.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

demo_c2f.ipynb

Lines changed: 0 additions & 373 deletions
This file was deleted.

0 commit comments

Comments
 (0)