Skip to content

Commit fc69c95

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent c1436f5 commit fc69c95

File tree

91 files changed

+12163
-9676
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+12163
-9676
lines changed

bayes3d/colmap/colmap_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ def read_intrinsics_text(path):
196196
elems = line.split()
197197
camera_id = int(elems[0])
198198
model = elems[1]
199-
assert (
200-
model == "PINHOLE"
201-
), "While the loader support other types, the rest of the code assumes PINHOLE"
199+
assert model == "PINHOLE", (
200+
"While the loader support other types, the rest of the code assumes PINHOLE"
201+
)
202202
width = int(elems[2])
203203
height = int(elems[3])
204204
params = np.array(tuple(map(float, elems[4:])))

bayes3d/colmap/dataset_loader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
9595
FovY = focal2fov(focal_length_y, height)
9696
FovX = focal2fov(focal_length_x, width)
9797
else:
98-
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
98+
assert False, (
99+
"Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
100+
)
99101

100102
image_path = os.path.join(images_folder, os.path.basename(extr.name))
101103
image_name = os.path.basename(image_path).split(".")[0]

bayes3d/genjax/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ def get_far_plane(trace):
129129
def add_object(trace, key, obj_id, parent, face_parent, face_child):
130130
N = b.get_indices(trace).shape[0] + 1
131131
choices = trace.get_choices()
132-
choices[f"parent_{N-1}"] = parent
133-
choices[f"id_{N-1}"] = obj_id
134-
choices[f"face_parent_{N-1}"] = face_parent
135-
choices[f"face_child_{N-1}"] = face_child
136-
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
132+
choices[f"parent_{N - 1}"] = parent
133+
choices[f"id_{N - 1}"] = obj_id
134+
choices[f"face_parent_{N - 1}"] = face_parent
135+
choices[f"face_child_{N - 1}"] = face_child
136+
choices[f"contact_params_{N - 1}"] = jnp.zeros(3)
137137
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]
138138

139139

bayes3d/neural/cosypose_baseline/cosypose_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def cosypose_interface(rgb_imgs, camera_k):
206206
all_scores = []
207207
for i, rgb_img in enumerate(rgb_imgs):
208208
pred = COSYPOSE_MODEL.inference(rgb_img, camera_k)
209-
print(f"{i+1}/{num_imgs} inference done")
209+
print(f"{i + 1}/{num_imgs} inference done")
210210

211211
pred_poses = np.asarray(pred.poses.cpu())
212212
pred_ids = [

bayes3d/neural/dino.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
189189
return model
190190

191191
stride = nn_utils._pair(stride)
192-
assert all(
193-
[(patch_size // s_) * s_ == patch_size for s_ in stride]
194-
), f"stride {stride} should divide patch_size {patch_size}"
192+
assert all([(patch_size // s_) * s_ == patch_size for s_ in stride]), (
193+
f"stride {stride} should divide patch_size {patch_size}"
194+
)
195195

196196
# fix the stride
197197
model.patch_embed.proj.stride = stride
@@ -415,7 +415,9 @@ def extract_descriptors(
415415
if not include_cls:
416416
x = x[:, :, 1:, :] # remove cls token
417417
else:
418-
assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
418+
assert not bin, (
419+
"bin = True and include_cls = True are not supported together, set one of them False."
420+
)
419421
if not bin:
420422
desc = (
421423
x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)
@@ -431,9 +433,9 @@ def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
431433
:param batch: batch to extract saliency maps for. Has shape BxCxHxW.
432434
:return: a tensor of saliency maps. has shape Bxt-1
433435
"""
434-
assert (
435-
self.model_type == "dino_vits8"
436-
), "saliency maps are supported only for dino_vits model_type."
436+
assert self.model_type == "dino_vits8", (
437+
"saliency maps are supported only for dino_vits model_type."
438+
)
437439
self._extract_features(batch, [11], "attn")
438440
head_idxs = [0, 2, 4, 5]
439441
curr_feats = self._feats[0] # Bxhxtxt

bayes3d/renderer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def add_mesh(self, mesh, mesh_name=None, scaling_factor=1.0, center_mesh=True):
126126
bounding_box_dims, bounding_box_pose = bayes3d.utils.aabb(mesh.vertices)
127127
if center_mesh:
128128
if not jnp.isclose(bounding_box_pose[:3, 3], 0.0).all():
129-
print(f"Centering mesh with translation {bounding_box_pose[:3,3]}")
129+
print(f"Centering mesh with translation {bounding_box_pose[:3, 3]}")
130130
mesh.vertices = mesh.vertices - bounding_box_pose[:3, 3]
131131

132132
self.meshes.append(mesh)

bayes3d/rendering/nvdiffrast_jax/renderer_matching_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def xfm_points(points, matrix):
178178
pos_grads.min().item(),
179179
pos_grads.max().item(),
180180
)
181-
print(f"JAX rasterization (eval + grad): {(end_time - start_time)*1000} ms")
181+
print(f"JAX rasterization (eval + grad): {(end_time - start_time) * 1000} ms")
182182

183183
# save viz
184184
b.viz.get_depth_image(rast_out[0][:, :, 2]).save("img_jax.png")
@@ -229,7 +229,7 @@ def xfm_points(points, matrix):
229229
pos_grads.min().item(),
230230
pos_grads.max().item(),
231231
)
232-
print(f"Torch rasterization (eval + grad): {(end_time - start_time)*1000} ms")
232+
print(f"Torch rasterization (eval + grad): {(end_time - start_time) * 1000} ms")
233233

234234
# save viz
235235
b.viz.get_depth_image(jnp.array(rast_out[0][:, :, 2].cpu())).save("img_torch.png")
@@ -278,7 +278,7 @@ def xfm_points(points, matrix):
278278
print(
279279
f"JAX BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}"
280280
)
281-
print(f"JAX interpolation: {(end_time - start_time)*1000} ms")
281+
print(f"JAX interpolation: {(end_time - start_time) * 1000} ms")
282282

283283
# save viz
284284
b.viz.get_depth_image(gb_pos[0][:, :, 2]).save("interpolate_jax.png")
@@ -316,7 +316,7 @@ def xfm_points(points, matrix):
316316
print(
317317
f"TORCH BWD (sum, min, max): g_attr={g_attr.sum().item(), g_attr.min().item(), g_attr.max().item()}\ng_rast={g_rast.sum().item(), g_rast.min().item(), g_rast.max().item()}"
318318
)
319-
print(f"Torch interpolation: {(end_time - start_time)*1000} ms")
319+
print(f"Torch interpolation: {(end_time - start_time) * 1000} ms")
320320

321321
# save viz
322322
b.viz.get_depth_image(jnp.array(gb_pos[0][:, :, 2].cpu())).save(

scripts/_mkl/notebooks/00a - Types.ipynb

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"metadata": {},
77
"outputs": [],
88
"source": [
9-
"#|default_exp types"
9+
"# |default_exp types"
1010
]
1111
},
1212
{
@@ -15,7 +15,7 @@
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18-
"#|export\n",
18+
"# |export\n",
1919
"from typing import Any, NamedTuple\n",
2020
"import numpy as np\n",
2121
"import jax\n",
@@ -29,18 +29,18 @@
2929
"Int = Array\n",
3030
"FaceIndex = int\n",
3131
"FaceIndices = Array\n",
32-
"ArrayN = Array\n",
33-
"Array3 = Array\n",
34-
"Array2 = Array\n",
35-
"ArrayNx2 = Array\n",
36-
"ArrayNx3 = Array\n",
37-
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
38-
"PrecisionMatrix = Matrix\n",
32+
"ArrayN = Array\n",
33+
"Array3 = Array\n",
34+
"Array2 = Array\n",
35+
"ArrayNx2 = Array\n",
36+
"ArrayNx3 = Array\n",
37+
"Matrix = jaxlib.xla_extension.ArrayImpl\n",
38+
"PrecisionMatrix = Matrix\n",
3939
"CovarianceMatrix = Matrix\n",
40-
"CholeskyMatrix = Matrix\n",
41-
"SquareMatrix = Matrix\n",
42-
"Vector = Array\n",
43-
"Direction = Vector\n",
40+
"CholeskyMatrix = Matrix\n",
41+
"SquareMatrix = Matrix\n",
42+
"Vector = Array\n",
43+
"Direction = Vector\n",
4444
"BaseVector = Vector"
4545
]
4646
},

0 commit comments

Comments
 (0)