Skip to content

Commit e2ab42b

Browse files
authored
Merge pull request #1647 from roboflow/seg-preview-workflow-block
Seg preview workflow block
2 parents bb59e28 + 8bff929 commit e2ab42b

File tree

3 files changed

+314
-0
lines changed

3 files changed

+314
-0
lines changed

inference/core/workflows/core_steps/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@
221221
from inference.core.workflows.core_steps.models.foundation.qwen.v1 import (
222222
Qwen25VLBlockV1,
223223
)
224+
from inference.core.workflows.core_steps.models.foundation.seg_preview.v1 import (
225+
SegPreviewBlockV1,
226+
)
224227
from inference.core.workflows.core_steps.models.foundation.segment_anything2.v1 import (
225228
SegmentAnything2BlockV1,
226229
)
@@ -644,6 +647,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
644647
SIFTComparisonBlockV1,
645648
SIFTComparisonBlockV2,
646649
SegmentAnything2BlockV1,
650+
SegPreviewBlockV1,
647651
StabilityAIInpaintingBlockV1,
648652
StabilityAIImageGenBlockV1,
649653
StabilityAIOutpaintingBlockV1,

inference/core/workflows/core_steps/models/foundation/seg_preview/__init__.py

Whitespace-only changes.
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
from types import SimpleNamespace
2+
from typing import Any, List, Literal, Optional, Type, Union
3+
4+
import numpy as np
5+
import requests
6+
from pydantic import ConfigDict, Field
7+
8+
from inference.core.entities.responses.inference import (
9+
InferenceResponseImage,
10+
InstanceSegmentationInferenceResponse,
11+
InstanceSegmentationPrediction,
12+
Point,
13+
)
14+
from inference.core.env import (
15+
API_BASE_URL,
16+
ROBOFLOW_INTERNAL_SERVICE_NAME,
17+
ROBOFLOW_INTERNAL_SERVICE_SECRET,
18+
)
19+
from inference.core.managers.base import ModelManager
20+
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
21+
from inference.core.workflows.core_steps.common.utils import (
22+
attach_parents_coordinates_to_batch_of_sv_detections,
23+
attach_prediction_type_info_to_sv_detections_batch,
24+
convert_inference_detections_batch_to_sv_detections,
25+
)
26+
from inference.core.workflows.execution_engine.entities.base import (
27+
Batch,
28+
OutputDefinition,
29+
WorkflowImageData,
30+
)
31+
from inference.core.workflows.execution_engine.entities.types import (
32+
FLOAT_KIND,
33+
IMAGE_KIND,
34+
INSTANCE_SEGMENTATION_PREDICTION_KIND,
35+
LIST_OF_VALUES_KIND,
36+
ImageInputField,
37+
Selector,
38+
)
39+
from inference.core.workflows.prototypes.block import (
40+
BlockResult,
41+
WorkflowBlock,
42+
WorkflowBlockManifest,
43+
)
44+
45+
DETECTIONS_CLASS_NAME_FIELD = "class_name"
46+
DETECTION_ID_FIELD = "detection_id"
47+
48+
49+
LONG_DESCRIPTION = "Seg Preview"
50+
51+
52+
class BlockManifest(WorkflowBlockManifest):
53+
model_config = ConfigDict(
54+
json_schema_extra={
55+
"name": "Seg Preview",
56+
"version": "v1",
57+
"short_description": "Seg Preview",
58+
"long_description": LONG_DESCRIPTION,
59+
"license": "Apache-2.0",
60+
"block_type": "model",
61+
"search_keywords": ["Seg Preview"],
62+
"ui_manifest": {
63+
"section": "model",
64+
"icon": "fa-solid fa-eye",
65+
"blockPriority": 9.49,
66+
"needsGPU": True,
67+
"inference": True,
68+
},
69+
},
70+
protected_namespaces=(),
71+
)
72+
73+
type: Literal["roboflow_core/seg-preview@v1"]
74+
75+
images: Selector(kind=[IMAGE_KIND]) = ImageInputField
76+
77+
class_names: Optional[Union[List[str], Selector(kind=[LIST_OF_VALUES_KIND])]] = (
78+
Field(
79+
title="Class Names",
80+
default=None,
81+
description="List of classes to recognise",
82+
examples=[["car", "person"], "$inputs.classes"],
83+
)
84+
)
85+
threshold: Union[Selector(kind=[FLOAT_KIND]), float] = Field(
86+
default=0.5, description="Threshold for predicted mask scores", examples=[0.3]
87+
)
88+
89+
@classmethod
90+
def get_parameters_accepting_batches(cls) -> List[str]:
91+
return ["images", "boxes"]
92+
93+
@classmethod
94+
def describe_outputs(cls) -> List[OutputDefinition]:
95+
return [
96+
OutputDefinition(
97+
name="predictions",
98+
kind=[INSTANCE_SEGMENTATION_PREDICTION_KIND],
99+
),
100+
]
101+
102+
@classmethod
103+
def get_execution_engine_compatibility(cls) -> Optional[str]:
104+
return ">=1.3.0,<2.0.0"
105+
106+
107+
class SegPreviewBlockV1(WorkflowBlock):
108+
109+
def __init__(
110+
self,
111+
model_manager: ModelManager,
112+
api_key: Optional[str],
113+
step_execution_mode: StepExecutionMode,
114+
):
115+
self._model_manager = model_manager
116+
self._api_key = api_key
117+
self._step_execution_mode = step_execution_mode
118+
119+
@classmethod
120+
def get_init_parameters(cls) -> List[str]:
121+
return ["model_manager", "api_key", "step_execution_mode"]
122+
123+
@classmethod
124+
def get_manifest(cls) -> Type[WorkflowBlockManifest]:
125+
return BlockManifest
126+
127+
def run(
128+
self,
129+
images: Batch[WorkflowImageData],
130+
class_names: Optional[List[str]],
131+
threshold: float,
132+
) -> BlockResult:
133+
134+
return self.run_via_request(
135+
images=images,
136+
class_names=class_names,
137+
threshold=threshold,
138+
)
139+
140+
def run_via_request(
141+
self,
142+
images: Batch[WorkflowImageData],
143+
class_names: Optional[List[str]],
144+
threshold: float,
145+
) -> BlockResult:
146+
predictions = []
147+
if class_names is None:
148+
class_names = []
149+
if len(class_names) == 0:
150+
class_names.append(None)
151+
152+
endpoint = f"{API_BASE_URL}/inferenceproxy/seg-preview"
153+
api_key = self._api_key
154+
155+
for single_image in images:
156+
prompt_class_ids: List[Optional[int]] = []
157+
prompt_class_names: List[Optional[str]] = []
158+
prompt_detection_ids: List[Optional[str]] = []
159+
160+
# Build unified prompt list payloads for HTTP
161+
http_prompts: List[dict] = []
162+
for class_name in class_names:
163+
http_prompts.append({"type": "text", "text": class_name})
164+
165+
# Prepare image for remote API (base64)
166+
http_image = {"type": "base64", "value": single_image.base64_image}
167+
168+
payload = {
169+
"image": http_image,
170+
"prompts": http_prompts,
171+
"output_prob_thresh": threshold,
172+
}
173+
174+
try:
175+
headers = {"Content-Type": "application/json"}
176+
if ROBOFLOW_INTERNAL_SERVICE_NAME:
177+
headers["X-Roboflow-Internal-Service-Name"] = (
178+
ROBOFLOW_INTERNAL_SERVICE_NAME
179+
)
180+
if ROBOFLOW_INTERNAL_SERVICE_SECRET:
181+
headers["X-Roboflow-Internal-Service-Secret"] = (
182+
ROBOFLOW_INTERNAL_SERVICE_SECRET
183+
)
184+
185+
response = requests.post(
186+
f"{endpoint}?api_key={api_key}",
187+
json=payload,
188+
headers=headers,
189+
timeout=60,
190+
)
191+
response.raise_for_status()
192+
resp_json = response.json()
193+
except Exception:
194+
resp_json = {"prompt_results": []}
195+
196+
class_predictions: List[InstanceSegmentationPrediction] = []
197+
for prompt_result in resp_json.get("prompt_results", []):
198+
idx = prompt_result.get("prompt_index", 0)
199+
class_name = class_names[idx] if idx < len(class_names) else None
200+
raw_predictions = prompt_result.get("predictions", [])
201+
# Adapt JSON dicts to objects with attribute-style access
202+
adapted_predictions = [SimpleNamespace(**p) for p in raw_predictions]
203+
class_pred = convert_segmentation_response_to_inference_instances_seg_response(
204+
segmentation_predictions=adapted_predictions, # type: ignore[arg-type]
205+
image=single_image,
206+
prompt_class_ids=prompt_class_ids,
207+
prompt_class_names=prompt_class_names,
208+
prompt_detection_ids=prompt_detection_ids,
209+
threshold=threshold,
210+
text_prompt=class_name,
211+
specific_class_id=idx,
212+
)
213+
class_predictions.extend(class_pred.predictions)
214+
215+
image_width = single_image.numpy_image.shape[1]
216+
image_height = single_image.numpy_image.shape[0]
217+
final_inference_prediction = InstanceSegmentationInferenceResponse(
218+
predictions=class_predictions,
219+
image=InferenceResponseImage(width=image_width, height=image_height),
220+
)
221+
predictions.append(final_inference_prediction)
222+
223+
predictions = [
224+
e.model_dump(by_alias=True, exclude_none=True) for e in predictions
225+
]
226+
return self._post_process_result(
227+
images=images,
228+
predictions=predictions,
229+
)
230+
231+
def _post_process_result(
232+
self,
233+
images: Batch[WorkflowImageData],
234+
predictions: List[dict],
235+
) -> BlockResult:
236+
predictions = convert_inference_detections_batch_to_sv_detections(predictions)
237+
predictions = attach_prediction_type_info_to_sv_detections_batch(
238+
predictions=predictions,
239+
prediction_type="instance-segmentation",
240+
)
241+
predictions = attach_parents_coordinates_to_batch_of_sv_detections(
242+
images=images,
243+
predictions=predictions,
244+
)
245+
return [{"predictions": prediction} for prediction in predictions]
246+
247+
248+
def convert_segmentation_response_to_inference_instances_seg_response(
249+
segmentation_predictions: List[Any],
250+
image: WorkflowImageData,
251+
prompt_class_ids: List[Optional[int]],
252+
prompt_class_names: List[Optional[str]],
253+
prompt_detection_ids: List[Optional[str]],
254+
threshold: float,
255+
text_prompt: Optional[str] = None,
256+
specific_class_id: Optional[int] = None,
257+
) -> InstanceSegmentationInferenceResponse:
258+
image_width = image.numpy_image.shape[1]
259+
image_height = image.numpy_image.shape[0]
260+
predictions = []
261+
if len(prompt_class_ids) == 0:
262+
prompt_class_ids = [
263+
specific_class_id if specific_class_id else 0
264+
for _ in range(len(segmentation_predictions))
265+
]
266+
prompt_class_names = [
267+
text_prompt if text_prompt else "foreground"
268+
for _ in range(len(segmentation_predictions))
269+
]
270+
prompt_detection_ids = [None for _ in range(len(segmentation_predictions))]
271+
for prediction, class_id, class_name, detection_id in zip(
272+
segmentation_predictions,
273+
prompt_class_ids,
274+
prompt_class_names,
275+
prompt_detection_ids,
276+
):
277+
for mask in prediction.masks:
278+
if len(mask) < 3:
279+
# skipping empty masks
280+
continue
281+
if prediction.confidence < threshold:
282+
# skipping masks below threshold
283+
continue
284+
x_coords = [coord[0] for coord in mask]
285+
y_coords = [coord[1] for coord in mask]
286+
min_x = np.min(x_coords)
287+
max_x = np.max(x_coords)
288+
min_y = np.min(y_coords)
289+
max_y = np.max(y_coords)
290+
center_x = (min_x + max_x) / 2
291+
center_y = (min_y + max_y) / 2
292+
predictions.append(
293+
InstanceSegmentationPrediction(
294+
**{
295+
"x": center_x,
296+
"y": center_y,
297+
"width": max_x - min_x,
298+
"height": max_y - min_y,
299+
"points": [Point(x=point[0], y=point[1]) for point in mask],
300+
"confidence": prediction.confidence,
301+
"class": class_name,
302+
"class_id": class_id,
303+
"parent_id": detection_id,
304+
}
305+
)
306+
)
307+
return InstanceSegmentationInferenceResponse(
308+
predictions=predictions,
309+
image=InferenceResponseImage(width=image_width, height=image_height),
310+
)

0 commit comments

Comments
 (0)