Skip to content

Commit 62460e9

Browse files
Merge branch 'main' into feature/try-to-beat-the-limitation-of-ee-in-terms-of-singular-elements-pushed-into-batch-inputs
2 parents 831583a + 53012e0 commit 62460e9

File tree

7 files changed

+84
-3
lines changed

7 files changed

+84
-3
lines changed

inference/enterprise/workflows/enterprise_blocks/sinks/PLC_modbus/v1.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ class ModbusTCPBlockManifest(WorkflowBlockManifest):
4646
"long_description": LONG_DESCRIPTION,
4747
"license": "Apache-2.0",
4848
"block_type": "analytics",
49+
"ui_manifest": {
50+
"section": "industrial",
51+
"icon": "fal fa-network-wired",
52+
"blockPriority": 14,
53+
"enterprise_only": True,
54+
"local_only": True,
55+
},
4956
}
5057
)
5158

inference/enterprise/workflows/enterprise_blocks/sinks/PLCethernetIP/v1.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ class PLCBlockManifest(WorkflowBlockManifest):
5656
"long_description": LONG_DESCRIPTION,
5757
"license": "Roboflow Enterprise License",
5858
"block_type": "sinks",
59+
"ui_manifest": {
60+
"section": "industrial",
61+
"icon": "fal fa-microchip",
62+
"blockPriority": 13,
63+
"enterprise_only": True,
64+
"local_only": True,
65+
},
5966
}
6067
)
6168

inference/enterprise/workflows/enterprise_blocks/sinks/microsoft_sql_server/v1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ class BlockManifest(WorkflowBlockManifest):
130130
"icon": "fal fa-database",
131131
"blockPriority": 3,
132132
"popular": True,
133+
"enterprise_only": True,
134+
"local_only": True,
133135
},
134136
}
135137
)

inference/enterprise/workflows/enterprise_blocks/sinks/mqtt_writer/v1.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ class BlockManifest(WorkflowBlockManifest):
4242
"long_description": LONG_DESCRIPTION,
4343
"license": "Roboflow Enterprise License",
4444
"block_type": "sink",
45+
"ui_manifest": {
46+
"section": "industrial",
47+
"icon": "fal fa-network-wired",
48+
"blockPriority": 10,
49+
"enterprise_only": True,
50+
"local_only": True,
51+
},
4552
}
4653
)
4754
type: Literal["mqtt_writer_sink@v1"]

inference/enterprise/workflows/enterprise_blocks/sinks/opc_writer/v1.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ class BlockManifest(WorkflowBlockManifest):
106106
"long_description": LONG_DESCRIPTION,
107107
"license": "Roboflow Enterprise License",
108108
"block_type": "sink",
109+
"ui_manifest": {
110+
"section": "industrial",
111+
"icon": "fal fa-industry",
112+
"blockPriority": 11,
113+
"enterprise_only": True,
114+
"local_only": True,
115+
},
109116
}
110117
)
111118
type: Literal[BLOCK_TYPE]

inference/models/owlv2/owlv2.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ def infer_from_embed(
564564
query_embeddings: Dict[str, PosNegDictType],
565565
confidence: float,
566566
iou_threshold: float,
567+
max_detections: int = MAX_DETECTIONS,
567568
) -> List[Dict]:
568569
image_embeds = self.get_image_embeds(image_hash)
569570
if image_embeds is None:
@@ -601,6 +602,11 @@ def infer_from_embed(
601602
all_predicted_classes = all_predicted_classes[survival_indices]
602603
all_predicted_scores = all_predicted_scores[survival_indices]
603604

605+
if len(all_predicted_boxes) > max_detections:
606+
all_predicted_boxes = all_predicted_boxes[:max_detections]
607+
all_predicted_classes = all_predicted_classes[:max_detections]
608+
all_predicted_scores = all_predicted_scores[:max_detections]
609+
604610
# move tensors to numpy before returning
605611
all_predicted_boxes = all_predicted_boxes.cpu().numpy()
606612
all_predicted_classes = all_predicted_classes.cpu().numpy()
@@ -626,13 +632,18 @@ def infer(
626632
training_data: Dict,
627633
confidence: float = 0.99,
628634
iou_threshold: float = 0.3,
635+
max_detections: int = MAX_DETECTIONS,
629636
**kwargs,
630637
):
631638
class_embeddings_dict = self.make_class_embeddings_dict(
632639
training_data, iou_threshold
633640
)
634641
return self.infer_from_embedding_dict(
635-
image, class_embeddings_dict, confidence, iou_threshold
642+
image,
643+
class_embeddings_dict,
644+
confidence,
645+
iou_threshold,
646+
max_detections=max_detections,
636647
)
637648

638649
def infer_from_embedding_dict(
@@ -641,6 +652,7 @@ def infer_from_embedding_dict(
641652
class_embeddings_dict: Dict[str, PosNegDictType],
642653
confidence: float,
643654
iou_threshold: float,
655+
max_detections: int = MAX_DETECTIONS,
644656
**kwargs,
645657
):
646658
if not isinstance(image, list):
@@ -660,7 +672,11 @@ def infer_from_embedding_dict(
660672
image_hash = self.embed_image(image_wrapper)
661673
image_wrapper.unload_numpy_image()
662674
result = self.infer_from_embed(
663-
image_hash, class_embeddings_dict, confidence, iou_threshold
675+
image_hash,
676+
class_embeddings_dict,
677+
confidence,
678+
iou_threshold,
679+
max_detections=max_detections,
664680
)
665681
results.append(result)
666682
return self.make_response(
@@ -944,14 +960,20 @@ def weights_file(self):
944960
return self.weights_file_path
945961

946962
def infer(
947-
self, image, confidence: float = 0.99, iou_threshold: float = 0.3, **kwargs
963+
self,
964+
image,
965+
confidence: float = 0.99,
966+
iou_threshold: float = 0.3,
967+
max_detections: int = MAX_DETECTIONS,
968+
**kwargs,
948969
):
949970
logger.info(f"Inferring OWLv2 model")
950971
result = self.owlv2.infer_from_embedding_dict(
951972
image,
952973
self.train_data_dict,
953974
confidence=confidence,
954975
iou_threshold=iou_threshold,
976+
max_detections=max_detections,
955977
**kwargs,
956978
)
957979
logger.info(f"OWLv2 model inference complete")
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
from unittest.mock import MagicMock
3+
4+
from inference.models.owlv2 import owlv2
5+
6+
7+
def test_infer_from_embed_respects_max_detections(monkeypatch):
8+
model = owlv2.OwlV2.__new__(owlv2.OwlV2)
9+
image_boxes = torch.tensor(
10+
[[0, 0, 1, 1], [0, 0, 2, 2], [0, 0, 3, 3], [0, 0, 4, 4]],
11+
dtype=torch.float32,
12+
)
13+
image_class_embeds = torch.zeros((4, 2))
14+
model.get_image_embeds = MagicMock(return_value=(None, image_boxes, image_class_embeds, None, None))
15+
16+
def fake_get_class_preds_from_embeds(*args, **kwargs):
17+
boxes = image_boxes
18+
classes = torch.zeros(4, dtype=torch.int64)
19+
scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
20+
return boxes, classes, scores
21+
22+
monkeypatch.setattr(owlv2, "get_class_preds_from_embeds", fake_get_class_preds_from_embeds)
23+
monkeypatch.setattr(owlv2.torchvision.ops, "nms", lambda boxes, scores, iou: torch.arange(boxes.shape[0]))
24+
25+
query_embeddings = {"a": {"positive": torch.zeros((1, 2)), "negative": None}}
26+
predictions = model.infer_from_embed(
27+
"hash", query_embeddings, confidence=0.5, iou_threshold=0.5, max_detections=2
28+
)
29+
assert len(predictions) == 2

0 commit comments

Comments
 (0)