diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py index ae8d051cb1..ee650b9418 100644 --- a/src/anomalib/data/datamodules/base/image.py +++ b/src/anomalib/data/datamodules/base/image.py @@ -165,9 +165,15 @@ def _update_augmentations(self) -> None: for subset_name in ["train", "val", "test"]: subset = getattr(self, f"{subset_name}_data", None) augmentations = getattr(self, f"{subset_name}_augmentations", None) - model_transform = get_nested_attr(self, "trainer.model.pre_processor.transform") - if subset and model_transform: - self._update_subset_augmentations(subset, augmentations, model_transform) + model_transform = get_nested_attr(self, "trainer.model.pre_processor.transform", None) + + if subset: + if model_transform: + # If model transform exists, update augmentations with model-specific transforms + self._update_subset_augmentations(subset, augmentations, model_transform) + else: + # If no model transform, just apply the user-specified augmentations + subset.augmentations = augmentations @staticmethod def _update_subset_augmentations(