Skip to content

Commit 277eee4

Browse files
committed
Remove shape checks
1 parent 86e2c88 commit 277eee4

File tree

2 files changed

+0
-8
lines changed

2 files changed

+0
-8
lines changed

imgx/task/diffusion_segmentation/experiment.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,6 @@ def train_init(
387387
aug_rng = jax.random.PRNGKey(self.config["seed"])
388388
batch = aug_fn(aug_rng, batch)
389389

390-
# check image size
391-
image_shape = self.dataset_info.image_spatial_shape
392-
chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape)
393-
394390
# init train state on cpu first
395391
dtype = get_half_precision_dtype(self.config.half_precision)
396392
model = instantiate(self.config.task.model, dtype=dtype)

imgx/task/segmentation/experiment.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,6 @@ def train_init(
261261
aug_rng = jax.random.PRNGKey(self.config["seed"])
262262
batch = aug_fn(aug_rng, batch)
263263

264-
# check image size
265-
image_shape = self.dataset_info.image_spatial_shape
266-
chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape)
267-
268264
# init train state on cpu first
269265
dtype = get_half_precision_dtype(self.config.half_precision)
270266
model = instantiate(self.config.task.model, dtype=dtype)

0 commit comments

Comments
 (0)