diff --git a/mbodied/types/sense/vision.py b/mbodied/types/sense/vision.py
index 288d43f..c371a14 100644
--- a/mbodied/types/sense/vision.py
+++ b/mbodied/types/sense/vision.py
@@ -26,8 +26,6 @@
 ```python
 image = Image("path/to/image.png", size=new_size_tuple).save("path/to/new/image.jpg")
 image.save("path/to/new/image.jpg", quality=5)
-
-TODO: Implement Lazy attribute loading for the image data.
 """
 
 import base64 as base64lib
@@ -52,6 +50,7 @@
     InstanceOf,
     model_serializer,
     model_validator,
+    PrivateAttr,
 )
 from typing_extensions import Literal
 
@@ -90,18 +89,12 @@ class Image(Sample):
 
     model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True, extras="forbid", validate_assignment=False)
 
-    array: NumpyArray
-    size: tuple[int, int]
-
-    pil: InstanceOf[PILImage] | None = Field(
-        None,
-        repr=False,
-        exclude=True,
-        description="The image represented as a PIL Image object.",
-    )
+    _array: NumpyArray | None = PrivateAttr(default=None)
+    _base64: InstanceOf[Base64Str] | None = PrivateAttr(default=None)
+    _pil: InstanceOf[PILImage] | None = PrivateAttr(default=None)
+    _url: InstanceOf[AnyUrl] | str | None = PrivateAttr(default=None)
+    _size: tuple[int, int] | None = PrivateAttr(default=None)
     encoding: Literal["png", "jpeg", "jpg", "bmp", "gif"]
-    base64: InstanceOf[Base64Str] | None = None
-    url: InstanceOf[AnyUrl] | str | None = None
     path: FilePath | None = None
 
     @classmethod
@@ -176,6 +169,64 @@ def __init__(
             kwargs["bytes"] = bytes_obj
         super().__init__(**kwargs)
 
+        self._array = kwargs.get("array", None)
+        self._base64 = kwargs.get("base64", None)
+        self._pil = kwargs.get("pil", None)
+        self._url = kwargs.get("url", None)
+        self._size = kwargs.get("size", None)
+
+        if self._size is not None and self.pil is not None:
+            self._pil = self._pil.resize(self._size)
+            self._array = np.array(self._pil)
+
+    @property
+    def array(self) -> np.ndarray | None:
+        """Lazily computes and returns the NumPy array."""
+        if self._array is None and self.pil is not None:
+            # Convert the PIL image to a NumPy array
+            self._array = np.array(self._pil)
+        return self._array
+
+    @property
+    def base64(self) -> Base64Str | None:
+        """Lazily computes and returns the base64-encoded string."""
+        if self._base64 is None and self.pil is not None:
+            buffer = io.BytesIO()
+            # Save the PIL image to a buffer in the specified encoding
+            self._pil.convert("RGB").save(buffer, format=self.encoding.upper())
+            self._base64 = base64lib.b64encode(buffer.getvalue()).decode("utf-8")
+        return self._base64
+
+    @property
+    def pil(self) -> PILImage | None:
+        """Lazily loads and returns the PIL image."""
+        if self._pil is None:
+            if self._array is not None:
+                self._pil = PILModule.fromarray(self._array).convert("RGB")
+            elif self._base64 is not None:
+                image_data = base64lib.b64decode(self._base64)
+                self._pil = PILModule.open(io.BytesIO(image_data)).convert("RGB")
+            elif self.path is not None:
+                self._pil = PILModule.open(self.path).convert("RGB")
+            elif self._url is not None:
+                self._pil = Image.load_url(self._url)
+        return self._pil
+
+    @property
+    def url(self) -> AnyUrl | str | None:
+        """Lazily computes and returns the data URL."""
+        # Go through the base64 property so the PIL-to-base64 encoding lives in one place.
+        if self._url is None and self.base64 is not None:
+            self._url = f"data:image/{self.encoding};base64,{self._base64}"
+        return self._url
+
+    @property
+    def size(self) -> tuple[int, int] | None:
+        """Lazily computes and returns the image size."""
+        if self._size is None and self.pil is not None:
+            self._size = self._pil.size
+        return self._size
+
     def __repr__(self):
         """Return a string representation of the image."""
         if self.base64 is None:
@@ -217,37 +268,6 @@ def open(path: str, encoding: str = "jpeg", size=None) -> "Image":
         image = PILModule.open(path).convert("RGB")
         return Image(image, encoding, size)
 
-    @staticmethod
-    def pil_to_data(image: PILImage, encoding: str, size=None) -> dict:
-        """Creates an Image instance from a PIL image.
-
-        Args:
-            image (PIL.Image.Image): The source PIL image from which to create the Image instance.
-            encoding (str): The format used for encoding the image when converting to base64.
-            size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.
-
-        Returns:
-            Image: An instance of the Image class with populated fields.
-        """
-        if encoding.lower() == "jpg":
-            encoding = "jpeg"
-        buffer = io.BytesIO()
-        image.convert("RGB").save(buffer, format=encoding.upper())
-        base64_encoded = base64lib.b64encode(buffer.getvalue()).decode("utf-8")
-        data_url = f"data:image/{encoding};base64,{base64_encoded}"
-        if size is not None:
-            image = image.resize(size)
-        else:
-            size = image.size
-        return {
-            "array": np.array(image),
-            "base64": base64_encoded,
-            "pil": image,
-            "size": size,
-            "url": data_url,
-            "encoding": encoding.lower(),
-        }
-
     @staticmethod
     def load_url(url: str, download=False) -> PILImage | None:
         """Downloads an image from a URL or decodes it from a base64 data URI.
@@ -302,21 +322,6 @@ def from_bytes(cls, bytes_data: bytes, encoding: str = "jpeg", size=None) -> "Im
         image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB")
         return cls(image, encoding, size)
 
-    @staticmethod
-    def bytes_to_data(bytes_data: bytes, encoding: str = "jpeg", size=None) -> dict:
-        """Creates an Image instance from a bytes object.
-
-        Args:
-            bytes_data (bytes): The bytes object to convert to an image.
-            encoding (str): The format used for encoding the image when converting to base64.
-            size (Optional[Tuple[int, int]]): The size of the image as a (width, height) tuple.
-
-        Returns:
-            Image: An instance of the Image class with populated fields.
-        """
-        image = PILModule.open(io.BytesIO(bytes_data)).convert("RGB")
-        return Image.pil_to_data(image, encoding, size)
-
     @model_validator(mode="before")
     @classmethod
     def validate_kwargs(cls, values) -> dict:
@@ -327,74 +332,33 @@ def validate_kwargs(cls, values) -> dict:
         if len(provided_fields) > 1:
             raise ValueError(f"Multiple image sources provided; only one is allowed but got: {provided_fields}")
 
-        # Initialize all fields to None or their default values
+        # Initialize all fields to their input values or None
         validated_values = {
-            "array": None,
-            "base64": None,
+            "array": values.get("array", None),
+            "base64": values.get("base64", None),
             "encoding": values.get("encoding", "jpeg").lower(),
-            "path": None,
-            "pil": None,
-            "url": None,
+            "path": values.get("path", None),
+            "pil": values.get("pil", None),
+            "url": values.get("url", None),
             "size": values.get("size", None),
         }
 
         # Validate the encoding first
+        if validated_values["encoding"] == "jpg":
+            validated_values["encoding"] = "jpeg"
+
         if validated_values["encoding"] not in ["png", "jpeg", "jpg", "bmp", "gif"]:
             raise ValueError("The 'encoding' must be a valid image format (png, jpeg, jpg, bmp, gif).")
 
-        if "bytes" in values and values["bytes"] is not None:
-            validated_values.update(cls.bytes_to_data(values["bytes"], values["encoding"], values["size"]))
-            return validated_values
-
-        if "pil" in values and values["pil"] is not None:
-            validated_values.update(
-                cls.pil_to_data(values["pil"], values["encoding"], values["size"]),
-            )
-            return validated_values
-        # Process the provided image source
-        if "path" in provided_fields:
-            image = PILModule.open(values["path"]).convert("RGB")
-            validated_values["path"] = values["path"]
-            validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"]))
-
-        elif "array" in provided_fields:
-            image = PILModule.fromarray(values["array"]).convert("RGB")
-            validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"]))
-
-        elif "pil" in provided_fields:
-            validated_values.update(
-                cls.pil_to_data(values["pil"], validated_values["encoding"], validated_values["size"]),
-            )
-
-        elif "base64" in provided_fields:
-            validated_values.update(
-                cls.from_base64(values["base64"], validated_values["encoding"], validated_values["size"]),
-            )
-
-        elif "url" in provided_fields:
+        if "url" in provided_fields:
             url_path = urlparse(values["url"]).path
             file_extension = (
                 Path(url_path).suffix[1:].lower() if Path(url_path).suffix else validated_values["encoding"]
             )
             validated_values["encoding"] = file_extension
-            validated_values["url"] = values["url"]
-            image = cls.load_url(values["url"])
-            if image is None:
-                validated_values["array"] = np.zeros((224, 224, 3), dtype=np.uint8)
-                validated_values["size"] = (224, 224)
-                return validated_values
-
-            validated_values.update(cls.pil_to_data(image, file_extension, validated_values["size"]))
-            validated_values["url"] = values["url"]
-
-        elif "size" in values and values["size"] is not None:
-            array = np.zeros((values["size"][0], values["size"][1], 3), dtype=np.uint8)
-            image = PILModule.fromarray(array).convert("RGB")
-            validated_values.update(cls.pil_to_data(image, validated_values["encoding"], validated_values["size"]))
-        if any(validated_values[k] is None for k in ["array", "base64", "pil", "url"]):
-            logging.warning(
-                f"Failed to validate image data. Could only fetch {[k for k in validated_values if validated_values[k] is not None]}",
-            )
+            # A ".jpg" URL re-introduces the alias after the check above; PIL only knows "JPEG".
+            if validated_values["encoding"] == "jpg":
+                validated_values["encoding"] = "jpeg"
         return validated_values
 
     def save(self, path: str, encoding: str | None = None, quality: int = 10) -> None:
diff --git a/tests/test_senses.py b/tests/test_senses.py
index 41248a6..50e5e91 100644
--- a/tests/test_senses.py
+++ b/tests/test_senses.py
@@ -139,6 +139,30 @@ def test_image_model_dump_load_with_base64():
     reconstructed_img = Image.model_validate_json(json)
     assert np.array_equal(reconstructed_img.array, array)
 
+
+def test_lazy_loading():
+    # Create an image with a path
+    image_path = "resources/bridge_example.jpeg"
+    img = Image(image_path)
+
+    # Test that attributes are lazily loaded
+    assert img._array is None
+    assert img._base64 is None
+    assert img._size is None
+    assert img._url is None
+
+    # Access size, which should trigger lazy loading
+    assert img.size is not None
+    assert img._size is not None
+
+    # Access array and base64 to ensure they are also lazily loaded
+    assert img.array is not None
+    assert img._array is not None
+    assert img.base64 is not None
+    assert img._base64 is not None
+    assert img.url is not None
+    assert img._url is not None
+
 
 if __name__ == "__main__":
     pytest.main([__file__, "-vv"])