11 changes: 11 additions & 0 deletions anylabeling/configs/auto_labeling/models.yaml
@@ -1,3 +1,14 @@
- name: "sam2_1_coreml_large"
display_name: Segment Anything 2.1 (Large) CoreML
download_url: https://huggingface.co/apple/coreml-sam2.1-large
encoder_model_path: SAM2_1LargeImageEncoderFLOAT16.mlpackage
decoder_model_path: SAM2_1LargeMaskDecoderFLOAT16.mlpackage
image_encoder_model_path: SAM2_1LargeImageEncoderFLOAT16.mlpackage
prompt_encoder_model_path: SAM2_1LargePromptEncoderFLOAT16.mlpackage
input_size: 1024
max_height: 1024
max_width: 1024
type: segment_anything
- name: "sam2_hiera_tiny_20240803"
display_name: Segment Anything 2 (Hiera-Tiny)
download_url: https://huggingface.co/vietanhdev/segment-anything-2-onnx-models/resolve/main/sam2_hiera_tiny.zip
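For reference, a minimal sketch of how this registry entry can be read with PyYAML; the file path and the lookup by name are illustrative rather than the loader AnyLabeling actually uses.

import yaml

# Load the model registry and pick out the new CoreML entry (illustrative lookup).
with open("anylabeling/configs/auto_labeling/models.yaml") as f:
    models = yaml.safe_load(f)

coreml_entry = next(m for m in models if m["name"] == "sam2_1_coreml_large")
print(coreml_entry["download_url"])        # https://huggingface.co/apple/coreml-sam2.1-large
print(coreml_entry["decoder_model_path"])  # SAM2_1LargeMaskDecoderFLOAT16.mlpackage
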
60 changes: 40 additions & 20 deletions anylabeling/services/auto_labeling/model_manager.py
@@ -21,6 +21,7 @@
from anylabeling.config import get_config, save_config

import ssl
from huggingface_hub import snapshot_download

ssl._create_default_https_context = (
ssl._create_unverified_context
@@ -267,22 +268,7 @@ def load_model(self, config_file):
self.model_download_thread.started.connect(self.model_download_worker.run)
self.model_download_thread.start()

def _download_and_extract_model(self, model_config):
"""Download and extract a model from model config"""
config_file = model_config["config_file"]
# Check if model is already downloaded
if not os.path.exists(config_file):
raise ValueError(self.tr("Error in loading config file."))
with open(config_file, "r") as f:
model_config = yaml.safe_load(f)
if model_config.get("has_downloaded", False):
return

# Download model
download_url = model_config.get("download_url", None)
if not download_url:
raise ValueError(self.tr("Missing download_url in config file."))
tmp_dir = tempfile.mkdtemp()
def download_zip(self, tmp_dir, download_url):
zip_model_path = os.path.join(tmp_dir, "model.zip")

# Download url
@@ -307,21 +293,55 @@ def _progress(count, block_size, total_size):
print(f"Could not download {download_url}: {e}")
self.new_model_status.emit(f"Could not download {download_url}")
return None

# Extract model
tmp_extract_dir = os.path.join(tmp_dir, "extract")
with zipfile.ZipFile(zip_model_path, "r") as zip_ref:
zip_ref.extractall(tmp_extract_dir)

# Find model folder (containing config.yaml)
model_folder = None
for root, _, files in os.walk(tmp_extract_dir):
if "config.yaml" in files:
model_folder = root
break
if model_folder is None:
raise ValueError(self.tr("Could not find config.yaml in zip file."))
return model_folder

def download_hf(self, tmp_dir, download_url, model_config):
"""Download a model snapshot from the Hugging Face Hub into tmp_dir."""
# e.g. "https://huggingface.co/apple/coreml-sam2.1-large" -> "apple/coreml-sam2.1-large"
repo_id = download_url.split("https://huggingface.co/")[-1].strip("/")
tmp_extract_dir = os.path.join(tmp_dir, "extract")
snapshot_download(repo_id=repo_id, local_dir=tmp_extract_dir)
# Write the model config next to the downloaded files so the folder
# can be recognized later by its config.yaml.
with open(os.path.join(tmp_extract_dir, "config.yaml"), "w") as f:
yaml.dump(model_config, f, default_flow_style=False)
return tmp_extract_dir

def _download_and_extract_model(self, model_config):
"""Download and extract a model from model config"""
config_file = model_config["config_file"]
extract_dir = os.path.dirname(config_file)
# Check if model is already downloaded
if not os.path.exists(config_file):
raise ValueError(self.tr("Error in loading config file."))
with open(config_file, "r") as f:
model_config = yaml.safe_load(f)
if model_config.get("has_downloaded", False):
return

# Download model
download_url = model_config.get("download_url", None)
if not download_url:
raise ValueError(self.tr("Missing download_url in config file."))

tmp_dir = tempfile.mkdtemp()
if download_url.endswith('.zip'):
model_folder = self.download_zip(tmp_dir, download_url)
elif download_url.startswith('https://huggingface.co'):
model_folder = self.download_hf(tmp_dir, download_url, model_config)
else:
raise ValueError(self.tr("Unsupported download_url in config file."))
if model_folder is None:
return

# Move model folder to correct location
shutil.rmtree(extract_dir)
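A standalone sketch of the new Hugging Face download path added above: the repo id is derived from download_url and fetched with snapshot_download, which returns the local folder. The zip path and temporary-directory cleanup are omitted, and the function name is illustrative.

import tempfile
from huggingface_hub import snapshot_download

def fetch_hf_model_folder(download_url: str) -> str:
    """Download a Hugging Face model repo and return its local folder (sketch)."""
    # "https://huggingface.co/apple/coreml-sam2.1-large" -> "apple/coreml-sam2.1-large"
    repo_id = download_url.split("https://huggingface.co/")[-1].strip("/")
    tmp_dir = tempfile.mkdtemp()
    return snapshot_download(repo_id=repo_id, local_dir=tmp_dir)

# Example: fetch_hf_model_folder("https://huggingface.co/apple/coreml-sam2.1-large")
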
115 changes: 115 additions & 0 deletions anylabeling/services/auto_labeling/sam2_coreml.py
@@ -0,0 +1,115 @@
import os
import cv2
import numpy as np
import coremltools as ct
from pathlib import Path
from PIL import Image


class SegmentAnything2CoreML:
def __init__(self, model_path: str) -> None:
print("using CoreML", model_path)
image_encoder_path = os.path.join(
model_path, "SAM2_1LargeImageEncoderFLOAT16.mlpackage"
)
mask_decoder_path = os.path.join(
model_path, "SAM2_1LargeMaskDecoderFLOAT16.mlpackage"
)
prompt_encoder_path = os.path.join(
model_path, "SAM2_1LargePromptEncoderFLOAT16.mlpackage"
)
self.image_encoder = ct.models.MLModel(image_encoder_path)
self.mask_decoder = ct.models.MLModel(mask_decoder_path)
self.prompt_encoder = ct.models.MLModel(prompt_encoder_path)
self.input_size = (1024, 1024)

def encode(self, cv_image: np.ndarray) -> dict:
"""Encodes the input image using the image encoder."""
# Convert OpenCV image to PIL Image
pil_image = Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))

# Resize image to input_size
original_size = pil_image.size
resized_image = pil_image.resize(self.input_size, Image.Resampling.LANCZOS)

# Predict image embeddings
embeddings = self.image_encoder.predict({"image": resized_image})

return {
"high_res_feats_0": embeddings["feats_s0"],
"high_res_feats_1": embeddings["feats_s1"],
"image_embedding": embeddings["image_embedding"],
"original_size": original_size,
}

def predict_masks(self, embedding: dict, prompt: list) -> np.ndarray:
"""Predicts masks based on image embedding and prompt."""
points = []
labels = []
for mark in prompt:
if mark["type"] == "point":
# Scale point coordinates to match the model's input size
x_scaled = mark["data"][0] * (
self.input_size[0] / embedding["original_size"][0]
)
y_scaled = mark["data"][1] * (
self.input_size[1] / embedding["original_size"][1]
)
points.append([x_scaled, y_scaled])
labels.append(mark["label"])
elif mark["type"] == "rectangle":
# Scale rectangle coordinates
x1_scaled = mark["data"][0] * (
self.input_size[0] / embedding["original_size"][0]
)
y1_scaled = mark["data"][1] * (
self.input_size[1] / embedding["original_size"][1]
)
x2_scaled = mark["data"][2] * (
self.input_size[0] / embedding["original_size"][0]
)
y2_scaled = mark["data"][3] * (
self.input_size[1] / embedding["original_size"][1]
)
points.append([x1_scaled, y1_scaled])
points.append([x2_scaled, y2_scaled])
labels.append(2) # Label for top-left of box
labels.append(3) # Label for bottom-right of box

points_array = np.array(points, dtype=np.float32).reshape(1, len(points), 2)
labels_array = np.array(labels, dtype=np.int32).reshape(1, len(labels))

# Get prompt embeddings
prompt_embeddings = self.prompt_encoder.predict(
{"points": points_array, "labels": labels_array}
)

# Predict masks
mask_output = self.mask_decoder.predict(
{
"image_embedding": embedding["image_embedding"],
"sparse_embedding": prompt_embeddings["sparse_embeddings"],
"dense_embedding": prompt_embeddings["dense_embeddings"],
"feats_s0": embedding["high_res_feats_0"],
"feats_s1": embedding["high_res_feats_1"],
}
)

# The model returns low_res_masks, which need to be upscaled and thresholded
low_res_masks = mask_output["low_res_masks"]

# Select the best mask based on score
scores = mask_output["scores"]
best_mask_idx = np.argmax(scores)
mask = low_res_masks[0, best_mask_idx] # Assuming batch size of 1

# Resize the mask back to the original image size
original_width, original_height = embedding["original_size"]
mask = cv2.resize(
mask, (original_width, original_height), interpolation=cv2.INTER_LINEAR
)

# Apply threshold to get a binary mask
mask = (mask > 0).astype(np.uint8) * 255 # Convert to 0 or 255

return np.array([mask])  # Return a batch containing the single best mask
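A minimal usage sketch of the class above, assuming a local folder containing the three .mlpackage bundles named in __init__ and a local test image; both paths are placeholders, and the point-prompt dictionary mirrors the format predict_masks parses.

import cv2
from anylabeling.services.auto_labeling.sam2_coreml import SegmentAnything2CoreML

model = SegmentAnything2CoreML("path/to/coreml-sam2.1-large")  # placeholder path
image = cv2.imread("image.jpg")  # placeholder image

embedding = model.encode(image)
masks = model.predict_masks(
    embedding,
    [{"type": "point", "data": [320, 240], "label": 1}],  # one foreground click
)
print(masks.shape)  # (1, H, W) binary mask at the original image resolution
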
11 changes: 7 additions & 4 deletions anylabeling/services/auto_labeling/segment_anything.py
@@ -18,7 +18,7 @@
from .types import AutoLabelingResult
from .sam_onnx import SegmentAnythingONNX
from .sam2_onnx import SegmentAnything2ONNX

from .sam2_coreml import SegmentAnything2CoreML

class SegmentAnything(Model):
"""Segmentation model using SegmentAnything"""
@@ -57,7 +57,7 @@ def __init__(self, config_path, on_message) -> None:
encoder_model_abs_path = self.get_model_abs_path(
self.config, "encoder_model_path"
)
if not encoder_model_abs_path or not os.path.isfile(encoder_model_abs_path):
if not encoder_model_abs_path or not (os.path.isfile(encoder_model_abs_path) or os.path.isdir(encoder_model_abs_path)):
raise FileNotFoundError(
QCoreApplication.translate(
"Model",
@@ -67,7 +67,7 @@ def __init__(self, config_path, on_message) -> None:
decoder_model_abs_path = self.get_model_abs_path(
self.config, "decoder_model_path"
)
if not decoder_model_abs_path or not os.path.isfile(decoder_model_abs_path):
if not decoder_model_abs_path or not (os.path.isfile(decoder_model_abs_path) or os.path.isdir(decoder_model_abs_path)):
raise FileNotFoundError(
QCoreApplication.translate(
"Model",
@@ -76,7 +76,10 @@ def __init__(self, config_path, on_message) -> None:
)

# Load models
if self.detect_model_variant(decoder_model_abs_path) == "sam2":
if "coreml" in decoder_model_abs_path:
config_folder = os.path.dirname(decoder_model_abs_path)
self.model = SegmentAnything2CoreML(config_folder)
elif self.detect_model_variant(decoder_model_abs_path) == "sam2":
self.model = SegmentAnything2ONNX(
encoder_model_abs_path, decoder_model_abs_path
)
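The isfile-or-isdir change above is needed because Core ML .mlpackage models are directory bundles rather than single files like the .onnx models checked before; a quick sketch of the relaxed existence check, with hypothetical paths:

import os

def model_artifact_exists(path: str) -> bool:
    # .onnx models are single files; CoreML .mlpackage bundles are directories.
    return os.path.isfile(path) or os.path.isdir(path)

print(model_artifact_exists("models/sam2_hiera_tiny.decoder.onnx"))
print(model_artifact_exists("models/SAM2_1LargeMaskDecoderFLOAT16.mlpackage"))
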
1 change: 1 addition & 0 deletions requirements-macos.txt
@@ -8,3 +8,4 @@ onnx==1.16.1
onnxruntime==1.18.1
qimage2ndarray==1.10.0
darkdetect==0.8.0
coremltools==8.3.0
1 change: 1 addition & 0 deletions setup.py
@@ -51,6 +51,7 @@ def get_install_requires():
"onnx==1.16.1",
"qimage2ndarray==1.10.0",
"darkdetect==0.8.0",
'coremltools==8.3.0; platform_system == "Darwin"',
]

# Add onnxruntime-gpu if GPU is preferred