Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions swanlab/data/modules/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ def __init__(self, data_or_path: InputType, sample_rate: int = 44100, caption: s
"you can install them by `pip install soundfile numpy`"
)
super().__init__()
# Support swanlab.Audio as input (e.g., swanlab.Audio(swanlab.Audio(data)))
if isinstance(data_or_path, Audio):
self.audio_data = data_or_path.audio_data.copy()
self.sample_rate = sample_rate if sample_rate is not None else data_or_path.sample_rate
self.buffer = MediaBuffer()
sf.write(self.buffer, self.audio_data.T, self.sample_rate, format="wav")
self.caption = D.check_caption(caption) if caption is not None else data_or_path.caption
return

if isinstance(data_or_path, str):
# 如果输入为路径字符串
try:
Expand Down
31 changes: 31 additions & 0 deletions swanlab/data/modules/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,37 @@ def __init__(
self.format = self.__convert_file_type(file_type)
self.size = convert_size(size)

# Support swanlab.Image as input (e.g., swanlab.Image(swanlab.Image(fig)))
if isinstance(data_or_path, Image):
self.format = data_or_path.format if file_type is None else self.__convert_file_type(file_type)
self.size = convert_size(size) if size is not None else data_or_path.size

# Handle GIF images (which don't have image_data, only buffer)
if hasattr(data_or_path, 'image_data') and data_or_path.image_data is not None:
base_image = data_or_path.image_data.copy()
self.image_data = self.__resize(base_image, self.size)
self.buffer = MediaBuffer()
self.image_data.save(self.buffer, format=self.format if self.format != "jpg" else "jpeg")
else:
# For images without image_data (e.g., GIF loaded from path), load from buffer
data_or_path.buffer.seek(0) # Reset buffer position
base_image = PILImage.open(data_or_path.buffer)
# If it's a GIF and no size change, just copy the buffer
if self.format == "gif" and (self.size is None or self.size == data_or_path.size):
self.buffer = MediaBuffer()
data_or_path.buffer.seek(0) # Reset again before copying
self.buffer.write(data_or_path.buffer.read())
# Set image_data for size access, even though we use buffer for saving
self.image_data = base_image
else:
# Process and resize the image
self.image_data = self.__resize(base_image, self.size)
self.buffer = MediaBuffer()
self.image_data.save(self.buffer, format=self.format if self.format != "jpg" else "jpeg")

self.caption = D.check_caption(caption) if caption is not None else data_or_path.caption
return

# 如果输入为路径字符串
if isinstance(data_or_path, str):
# 如果文件后缀为gif,则将format设置为gif
Expand Down
23 changes: 19 additions & 4 deletions swanlab/data/modules/object3d/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,25 @@ class Molecule(MediaType):
pdb_data: str
caption: Optional[str] = None

def __post_init__(self):
    """Check that ``pdb_data`` holds a string once the instance is initialized.

    Raises:
        TypeError: If ``self.pdb_data`` is not a ``str``.
    """
    if isinstance(self.pdb_data, str):
        return
    raise TypeError("pdb_data must be a string, use RDKit.Chem.MolToPDBBlock to convert.")
def __init__(self, pdb_data: Union[str, "Molecule"], caption: Optional[str] = None):
    """Build a Molecule from a PDB string or from an existing Molecule.

    Args:
        pdb_data: Either the PDB text itself or another ``Molecule`` whose
            ``pdb_data`` is reused.
        caption: Optional descriptive text. When ``pdb_data`` is a
            ``Molecule`` and ``caption`` is ``None``, the source instance's
            caption is kept.

    Raises:
        TypeError: If ``pdb_data`` is neither a ``Molecule`` nor a ``str``.
    """
    # The parent initializer runs first, before any argument validation.
    super().__init__()

    # Nested construction, e.g. swanlab.Molecule(swanlab.Molecule(...)).
    if isinstance(pdb_data, Molecule):
        source = pdb_data
        self.pdb_data = source.pdb_data
        self.caption = source.caption if caption is None else caption
        return

    # Anything else has to be raw PDB text.
    if not isinstance(pdb_data, str):
        raise TypeError("pdb_data must be a string, use RDKit.Chem.MolToPDBBlock to convert.")
    self.pdb_data = pdb_data
    self.caption = caption

@staticmethod
def check_is_available():
Expand Down
29 changes: 28 additions & 1 deletion swanlab/data/modules/object3d/object3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Object3D:

def __new__(
cls,
data: Union[ndarray, str, Path, Dict, Mol],
data: Union[ndarray, str, Path, Dict, Mol, MediaType],
*,
caption: Optional[str] = None,
**kwargs,
Expand All @@ -102,6 +102,33 @@ def __new__(

kwargs['caption'] = caption

# Support MediaType instances as input (e.g., Object3D(Object3D(...)))
# This includes PointCloud, Model3D, Molecule, etc.
if isinstance(data, MediaType):
# If caption is provided, create a new instance with the new caption
# Otherwise, return the original instance
if caption is not None:
# Handle different MediaType subclasses
if isinstance(data, Molecule):
return Molecule(data.pdb_data, caption=caption)
elif isinstance(data, PointCloud):
# Create a new PointCloud with the same points and boxes, but new caption
return PointCloud(
points=data.points.copy(),
boxes=list(data.boxes) if data.boxes else [],
caption=caption,
step=data.step,
key=data.key,
)
elif isinstance(data, Model3D):
# Create a new Model3D with the same glb_path, but new caption
return Model3D(
glb_path=data.glb_path,
caption=caption,
step=data.step,
)
return data

if isinstance(data, ndarray):
return cls._handle_ndarray(data, **kwargs)

Expand Down
6 changes: 6 additions & 0 deletions swanlab/data/modules/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class Text(MediaType):

def __init__(self, data: Union[str, int, float], caption: str = None):
super().__init__()
# Support swanlab.Text as input (e.g., swanlab.Text(swanlab.Text("hello")))
if isinstance(data, Text):
self.text_data = data.text_data
self.caption = D.check_caption(caption) if caption is not None else data.caption
return

# 处理文本数据

if isinstance(data, str):
Expand Down
12 changes: 11 additions & 1 deletion swanlab/data/modules/video/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,18 @@ class Video(MediaType):
caption: The caption of the video.
"""

def __init__(self, data_or_path: str, caption: str = None):
def __init__(self, data_or_path: Union[str, "Video"], caption: str = None):
super().__init__()
# Support swanlab.Video as input (e.g., swanlab.Video(swanlab.Video(path)))
if isinstance(data_or_path, Video):
# Use Image's nested support to create a new Image from the existing one
self._image = Image(
data_or_path._image,
caption=caption if caption is not None else data_or_path._image.caption,
file_type="gif"
)
return

if not data_or_path.endswith(".gif"):
raise ValueError("swanlab.Video only supports gif format file paths")

Expand Down
24 changes: 24 additions & 0 deletions test/unit/data/modules/audio/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,27 @@ def test_audio_fail():
mock = np.random.randn(3, 100000)
with pytest.raises(TypeError):
Audio(data_or_path=mock, sample_rate=44100)


def test_audio_nested():
    """Audio accepts another Audio instance as its input (nested construction)."""
    # Source instance built from a random two-row float32 array.
    samples = np.random.randn(2, 100000).astype(np.float32)
    original = Audio(data_or_path=samples, sample_rate=44100, caption="original")

    # Wrap the existing instance without overriding anything.
    wrapped = Audio(original)

    # The data, sample rate and caption carry over, and a buffer is written.
    assert np.array_equal(wrapped.audio_data, original.audio_data)
    assert wrapped.sample_rate == original.sample_rate
    assert wrapped.caption == original.caption
    assert wrapped.buffer.getbuffer().nbytes > 0

    # An explicit caption replaces the inherited one.
    recaptioned = Audio(original, caption="new caption")
    assert recaptioned.caption == "new caption"
    assert recaptioned.sample_rate == original.sample_rate

    # The source instance keeps its own caption.
    assert original.caption == "original"
17 changes: 17 additions & 0 deletions test/unit/data/modules/image/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,20 @@ def test_image_size():

image = Image(mock, size=128)
assert image.image_size == (128, 64)


def test_image_accepts_image_instance():
    """Image accepts another Image instance as its input (nested construction)."""
    # Source instance: a random 4x4 RGB array.
    pixels = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)
    original = Image(pixels, caption="first")

    # Wrapping without overrides keeps format, caption and pixel dimensions.
    wrapped = Image(original)

    assert wrapped.format == original.format
    assert wrapped.caption == original.caption
    assert wrapped.image_data.size == original.image_data.size
    assert wrapped.buffer.getbuffer().nbytes > 0

    # Overriding size and caption affects only the new instance.
    shrunk = Image(original, size=2, caption="second")

    assert shrunk.caption == "second"
    assert max(shrunk.image_data.size) <= 2
    assert original.image_data.size == (4, 4)  # the source image is left untouched
28 changes: 28 additions & 0 deletions test/unit/data/modules/object3d/test_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,34 @@ def test_get_more(self, mol):
molecule = Molecule.from_mol(mol)
assert molecule.get_more() is None

def test_nested_molecule(self, mol):
    """Molecule accepts another Molecule instance as its input (nested construction)."""
    # Report a skip instead of passing vacuously when the fixture yields no molecule.
    if not mol:
        pytest.skip("mol fixture did not provide a molecule")

    # Source instance built from the fixture molecule.
    base_molecule = Molecule.from_mol(mol, caption="original")

    # Wrap the existing instance without overriding anything.
    nested_molecule = Molecule(base_molecule)

    # PDB data and caption carry over.
    assert nested_molecule.pdb_data == base_molecule.pdb_data
    assert nested_molecule.caption == base_molecule.caption

    # An explicit caption replaces the inherited one.
    nested_with_new_caption = Molecule(base_molecule, caption="new caption")
    assert nested_with_new_caption.pdb_data == base_molecule.pdb_data
    assert nested_with_new_caption.caption == "new caption"

    # The source instance keeps its own caption.
    assert base_molecule.caption == "original"

    # Nesting also works when the source was created from a raw PDB string.
    pdb_string = base_molecule.pdb_data
    molecule_from_pdb = Molecule(pdb_string, caption="from pdb")
    nested_from_pdb = Molecule(molecule_from_pdb)
    assert nested_from_pdb.pdb_data == pdb_string
    assert nested_from_pdb.caption == "from pdb"


class TestObject3DWithMolecule:
@pytest.fixture
Expand Down
37 changes: 37 additions & 0 deletions test/unit/data/modules/object3d/test_object3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,40 @@ def test_metadata_handling(self, xyz_points):
obj = Object3D(xyz_points, caption="Test")
assert obj.step is None
assert obj.caption == "Test"

def test_nested_mediatype(self, xyz_points):
    """Object3D accepts an existing MediaType instance as its input (nested construction)."""
    # Source instance: Object3D dispatches the point array to a PointCloud.
    base_obj = Object3D(xyz_points, caption="original")
    assert isinstance(base_obj, PointCloud)

    # Without a new caption the very same instance is handed back.
    nested_obj = Object3D(base_obj)
    assert nested_obj is base_obj

    # With a new caption a fresh instance is created and the source is untouched.
    nested_with_new_caption = Object3D(base_obj, caption="new caption")
    assert nested_with_new_caption.caption == "new caption"
    assert nested_with_new_caption is not base_obj
    assert base_obj.caption == "original"
    # The point data carries over to the new instance.
    assert np.array_equal(nested_with_new_caption.points, base_obj.points)

    # The Molecule part needs RDKit. Only the import is guarded, so an
    # ImportError raised by the code under test is not silently swallowed.
    try:
        from rdkit.Chem import MolFromSmiles
    except ImportError:
        return

    mol = MolFromSmiles("CCO")
    # Fail loudly rather than skipping the checks if parsing yields nothing.
    assert mol is not None, "RDKit failed to parse SMILES 'CCO'"

    base_molecule = Object3D(mol, caption="molecule original")
    nested_molecule = Object3D(base_molecule)
    assert nested_molecule.caption == "molecule original"
    assert nested_molecule is base_molecule  # same instance when no new caption is given

    nested_molecule_new_caption = Object3D(base_molecule, caption="molecule new")
    assert nested_molecule_new_caption.caption == "molecule new"
    assert nested_molecule_new_caption is not base_molecule  # a new instance is created
    assert base_molecule.caption == "molecule original"  # the source is left untouched
55 changes: 27 additions & 28 deletions test/unit/data/modules/text/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,31 +63,30 @@ def test_text_caption():
assert data == "1"
assert buffer is None
assert text.get_more()["caption"] == "test"
# ---------------------------------- int输入 ----------------------------------
mock = 1
text = Text(data=mock, caption="test")
data, buffer = text.parse()
assert data == "1"
assert buffer is None
assert text.get_more()["caption"] == "test"
# ---------------------------------- int输入 ----------------------------------
mock = 1
text = Text(data=mock, caption="test")
data, buffer = text.parse()
assert data == "1"
assert buffer is None
assert text.get_more()["caption"] == "test"
# ---------------------------------- int输入 ----------------------------------
mock = 1
text = Text(data=mock, caption="test")
data, buffer = text.parse()
assert data == "1"
assert buffer is None
assert text.get_more()["caption"] == "test"
# ---------------------------------- int输入 ----------------------------------
mock = 1
text = Text(data=mock, caption="test")
data, buffer = text.parse()
assert data == "1"
assert buffer is None
assert text.get_more()["caption"] == "test"


def test_text_nested():
    """Text accepts another Text instance as its input (nested construction)."""
    # Source instance built from a plain string.
    original = Text(data="Hello World", caption="original")

    # Wrap the existing instance without overriding anything.
    wrapped = Text(original)

    # Text data and caption carry over.
    assert wrapped.text_data == original.text_data
    assert wrapped.caption == original.caption

    # An explicit caption replaces the inherited one; the text stays the same.
    recaptioned = Text(original, caption="new caption")
    assert recaptioned.text_data == original.text_data
    assert recaptioned.caption == "new caption"

    # The source instance keeps its own caption.
    assert original.caption == "original"

    # Nesting a Text that was created from a number.
    numeric = Text(data=42, caption="number")
    wrapped_numeric = Text(numeric)
    assert wrapped_numeric.text_data == "42"
    assert wrapped_numeric.caption == "number"
29 changes: 29 additions & 0 deletions test/unit/data/modules/video/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,32 @@ def test_video_inheritance():
# 验证Video有Image的所有必要方法
assert hasattr(video, 'parse')
assert hasattr(video, 'get_more')


def test_video_nested():
    """Video accepts another Video instance as its input (nested construction)."""
    # Save a random 100x100 RGB frame as a GIF and build the source Video from that path.
    frame = PILImage.fromarray(np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8))
    gif_path = os.path.join(TEMP_PATH, f"{generate()}.gif")
    frame.save(gif_path, format="GIF")

    original = Video(data_or_path=gif_path, caption="original")

    # Wrap the existing instance without overriding anything.
    wrapped = Video(original)

    # The wrapped video still parses to a ".gif" name plus a buffer.
    parsed_name, parsed_buffer = wrapped.parse()
    assert isinstance(parsed_name, str)
    assert parsed_name.endswith(".gif")
    assert parsed_buffer is not None

    # The caption is inherited from the source instance.
    assert wrapped.get_more()["caption"] == "original"

    # An explicit caption replaces the inherited one.
    recaptioned = Video(original, caption="new caption")
    assert recaptioned.get_more()["caption"] == "new caption"

    # The source instance keeps its own caption.
    assert original.get_more()["caption"] == "original"