Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions decimer_segmentation/decimer_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,22 @@ def download_trained_weights(model_url: str, model_path: str, verbose=1):


def segment_chemical_structures_from_file(
file_path: str, expand: bool = True, **kwargs
) -> List[np.ndarray]:
file_path: str,
expand: bool = True,
return_bboxes: bool = False,
**kwargs,
) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]]:
"""
Segment chemical structures from a PDF or image file.

Args:
file_path: Path to input file (PDF or image)
expand: Whether to expand masks to capture complete structures
return_bboxes: Whether to return bounding boxes along with segments

Returns:
List of segmented chemical structure images as numpy arrays
List of segmented chemical structure images as numpy arrays.
If return_bboxes is True, returns a tuple of (segments, bboxes).
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"Input file not found: {file_path}")
Expand All @@ -203,15 +208,19 @@ def segment_chemical_structures_from_file(

if not images:
logger.warning(f"No images could be extracted from {file_path}")
return []
return ([], []) if return_bboxes else []

# Process all images sequentially (model can't parallelize)
all_segments = []
all_bboxes = []
for image in images:
segments = segment_chemical_structures(image, expand)
segments, bboxes = segment_chemical_structures(
image, expand, return_bboxes=True
)
all_segments.extend(segments)
all_bboxes.extend(bboxes)

return all_segments
return (all_segments, all_bboxes) if return_bboxes else all_segments


def _load_images_from_file(file_path: str) -> List[np.ndarray]:
Expand Down Expand Up @@ -288,7 +297,8 @@ def segment_chemical_structures(
return_bboxes: Whether to return bounding boxes along with segments

Returns:
List of segmented structure images, optionally with bounding boxes
List of segmented structure images, optionally with bounding boxes.
If return_bboxes is True, returns a tuple of (segments, bboxes).
"""
if image is None or image.size == 0:
return ([], []) if return_bboxes else []
Expand All @@ -314,11 +324,16 @@ def segment_chemical_structures(
# Sort in reading order and filter empty
if segments:
segments, bboxes = _sort_segments_bboxes(segments, bboxes)
segments = [
s
for s in segments
filtered = [
(s, b)
for s, b in zip(segments, bboxes)
if s is not None and s.size > 0 and s.shape[0] > 0 and s.shape[1] > 0
]
if filtered:
segments, bboxes = zip(*filtered)
segments, bboxes = list(segments), list(bboxes)
else:
segments, bboxes = [], []

return (segments, bboxes) if return_bboxes else segments

Expand Down
27 changes: 27 additions & 0 deletions tests/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,33 @@ def test_full_pipeline_empty_image(self):
# Expected if model not available
pass

def test_segment_from_file_with_return_bboxes(self):
    """Test segment_chemical_structures_from_file returns bboxes when return_bboxes=True.

    Writes a synthetic 500x500 white image containing one filled black
    square to a temporary PNG, calls the file-level segmentation entry
    point with ``return_bboxes=True`` and checks the shape of the result:
    a 2-tuple of two equally long lists, where every bounding box has
    four elements.

    If the segmentation call itself raises (e.g. the model is not
    available), the test returns early without asserting anything.
    """
    from decimer_segmentation import segment_chemical_structures_from_file

    image = np.ones((500, 500, 3), dtype=np.uint8) * 255
    cv2.rectangle(image, (50, 50), (150, 150), (0, 0, 0), -1)

    # delete=False: the file is re-opened by path below and removed in `finally`.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        cv2.imwrite(f.name, image)
    try:
        try:
            result = segment_chemical_structures_from_file(
                f.name, expand=False, return_bboxes=True
            )
        except Exception:
            # Expected if model not available
            return

        # The assertions live outside the broad `except Exception` above:
        # AssertionError is an Exception subclass, so asserting inside that
        # guard would silently swallow every failure.
        assert isinstance(result, tuple)
        assert len(result) == 2
        segments, bboxes = result
        assert isinstance(segments, list)
        assert isinstance(bboxes, list)
        assert len(segments) == len(bboxes)
        for bbox in bboxes:
            assert len(bbox) == 4
    finally:
        os.unlink(f.name)


# Benchmark tests (optional, for performance monitoring)
class TestPerformance:
Expand Down
Loading