Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from genjax.incremental import Diff, NoChange, UnknownChange

import bayes3d as b
import bayes3d.scene_graph

from .genjax_distributions import (
contact_params_uniform,
Expand Down Expand Up @@ -127,14 +128,14 @@ def get_far_plane(trace):


def add_object(trace, key, obj_id, parent, face_parent, face_child):
N = b.get_indices(trace).shape[0] + 1
N = get_indices(trace).shape[0] + 1
choices = trace.get_choices()
choices[f"parent_{N-1}"] = parent
choices[f"id_{N-1}"] = obj_id
choices[f"face_parent_{N-1}"] = face_parent
choices[f"face_child_{N-1}"] = face_child
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[0]


add_object_jit = jax.jit(add_object)
Expand All @@ -151,7 +152,7 @@ def print_trace(trace):


def viz_trace_meshcat(trace, colors=None):
b.clear()
b.clear_visualizer()
b.show_cloud(
"1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"])
)
Expand Down Expand Up @@ -223,14 +224,14 @@ def enumerator(trace, key, *args):
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[2]
)[0]

def enumerator_with_weight(trace, key, *args):
return trace.update(
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[1:3]
)[0:2]

def enumerator_score(trace, key, *args):
return enumerator(trace, key, *args).get_score()
Expand Down Expand Up @@ -301,4 +302,4 @@ def update_address(trace, key, address, value):
key,
genjax.choice_map({address: value}),
tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),
)[2]
)[0]
62 changes: 62 additions & 0 deletions bayes3d/viser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import random

import viser

server.add_frame(
"/tree",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(random.random() * 2.0, 2.0, 0.2),
)
server.add_frame(
"/tree/branch",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(random.random() * 2.0, 2.0, 0.2),
)

client_handle = list(server.get_clients().values())[0]

p, q = client_handle.camera.position, client_handle.camera.wxyz

client_handle.camera.position = p
client_handle.camera.wxyz = q

img = client_handle.camera.get_render(100, 100)


server = viser.ViserServer()

import os

import trimesh

i = 9
model_dir = os.path.join(b.utils.get_assets_dir(), "ycb_video_models/models")
mesh_path = os.path.join(model_dir, b.utils.ycb_loader.MODEL_NAMES[i], "textured.obj")
mesh = trimesh.load(mesh_path)

server.add_mesh_trimesh(
name="/trimesh",
mesh=mesh,
)

server.reset_scene()


server.add_mesh(
name="/trimesh",
vertices=mesh.vertices,
faces=mesh.faces,
)

sphere = trimesh.creation.uv_sphere(
0.1,
(
10,
10,
),
)
server.add_mesh(
name="/trimesh2",
vertices=sphere.vertices * np.array([1.0, 2.0, 3.0]),
faces=sphere.faces,
)
23 changes: 9 additions & 14 deletions bayes3d/viz/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def preprocess_for_viz(img):
cmap.set_bad(color=(1.0, 1.0, 1.0, 1.0))


def get_depth_image(image, min_val=None, max_val=None, remove_max=True):
def get_depth_image(image, max=None):
"""Convert a depth image to a PIL image.

Args:
Expand All @@ -60,21 +60,16 @@ def get_depth_image(image, min_val=None, max_val=None, remove_max=True):
Returns:
PIL.Image: Depth image visualized as a PIL image.
"""
if len(image.shape) > 2:
depth = np.array(image[:, :, -1])
depth = np.array(image)
if max is None:
maxim = depth.max()
else:
depth = np.array(image)

if max_val is None:
max_val = depth.max()
if not remove_max:
max_val += 1
if min_val is None:
min_val = depth.min()

mask = (depth < max_val) * (depth > min_val)
maxim = max
mask = depth < maxim
depth[np.logical_not(mask)] = np.nan
depth = (depth - min_val) / (max_val - min_val + 1e-10)
vmin = depth[mask].min()
vmax = depth[mask].max()
depth = (depth - vmin) / (vmax - vmin)

img = Image.fromarray(
np.rint(cmap(depth) * 255.0).astype(np.int8), mode="RGBA"
Expand Down
373 changes: 373 additions & 0 deletions demo_c2f.ipynb

Large diffs are not rendered by default.

339 changes: 339 additions & 0 deletions likelihood_debug.ipynb

Large diffs are not rendered by default.

Loading