Skip to content

Commit dd9e8fa

Browse files
authored
fix(VisualReplayStrategy): increase default NMS thresh/conf (#758)
* repair VisualReplayStrategy * black/flake8
1 parent 177ea25 commit dd9e8fa

File tree

6 files changed

+53
-12
lines changed

6 files changed

+53
-12
lines changed

experiments/fastsamsom.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ def main() -> None:
2525
image_contrasted.show()
2626

2727
segmentation_adapter = adapters.get_default_segmentation_adapter()
28-
segmented_image = segmentation_adapter.fetch_segmented_image(image)
28+
segmented_image = segmentation_adapter.fetch_segmented_image(
29+
image,
30+
# threshold below which boxes will be filtered out
31+
conf=0,
32+
# discards all overlapping boxes with IoU > iou_threshold
33+
iou=0.05,
34+
)
2935
if DEBUG:
3036
segmented_image.show()
3137

openadapt/adapters/openai.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def prompt(
195195
"""
196196
logger.info(f"system_prompt=\n{system_prompt}")
197197
logger.info(f"prompt=\n{prompt}")
198+
images = images or []
198199
logger.info(f"{len(images)=}")
199200
payload = create_payload(
200201
prompt,

openadapt/adapters/ultralytics.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,23 @@
4949
def fetch_segmented_image(
5050
image: Image.Image,
5151
model_name: str = DEFAULT_MODEL_NAME,
52+
**kwargs,
5253
) -> Image.Image:
5354
"""Segment a PIL.Image using ultralytics.
5455
5556
Args:
5657
image: The input image to be segmented.
5758
model_name: The name of the model to use.
59+
kwargs: Arguments to pass to segmentation function.
5860
5961
Returns:
6062
The segmented image as a PIL Image.
6163
"""
6264
assert model_name in MODEL_NAMES, "{model_name=} must be in {MODEL_NAMES=}"
6365
if model_name in FASTSAM_MODEL_NAMES:
64-
return do_fastsam(image, model_name)
66+
return do_fastsam(image, model_name, **kwargs)
6567
else:
66-
return do_sam(image, model_name)
68+
return do_sam(image, model_name, **kwargs)
6769

6870

6971
@cache.cache()
@@ -75,9 +77,9 @@ def do_fastsam(
7577
retina_masks: bool = True,
7678
imgsz: int | tuple[int, int] | None = 1024,
7779
# threshold below which boxes will be filtered out
78-
conf: float = 0,
80+
conf: float = 0.4,
7981
# discards all overlapping boxes with IoU > iou_threshold
80-
iou: float = 0.05,
82+
iou: float = 0.9,
8183
) -> Image:
8284
model = FastSAM(model_name)
8385

openadapt/strategies/visual.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"""
4646

4747
from dataclasses import dataclass
48+
from pprint import pformat
4849
import time
4950

5051
from loguru import logger
@@ -485,6 +486,7 @@ def prompt_for_descriptions(
485486
exceptions=exceptions,
486487
)
487488
logger.info(f"prompt=\n{prompt}")
489+
logger.info(f"{len(images)=}")
488490
descriptions_json = prompt_adapter.prompt(
489491
prompt,
490492
system_prompt,
@@ -498,8 +500,16 @@ def prompt_for_descriptions(
498500
len(masked_images),
499501
)
500502
except Exception as exc:
501-
# TODO XXX
502-
raise exc
503+
exceptions = exceptions or []
504+
exceptions.append(exc)
505+
logger.info(f"exceptions=\n{pformat(exceptions)}")
506+
return prompt_for_descriptions(
507+
original_image,
508+
masked_images,
509+
active_segment_description,
510+
exceptions,
511+
)
512+
503513
# remove indexes
504514
descriptions = [desc for idx, desc in descriptions]
505515
return descriptions

openadapt/window/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@
1919
raise Exception(f"Unsupported platform: {sys.platform}")
2020

2121

22-
def get_active_window_data() -> dict[str, Any] | None:
22+
def get_active_window_data(
23+
include_window_data: bool = config.RECORD_WINDOW_DATA,
24+
) -> dict[str, Any] | None:
2325
"""Get data of the active window.
2426
2527
Returns:
2628
dict or None: A dictionary containing information about the active window,
2729
or None if the state is not available.
2830
"""
29-
state = get_active_window_state(config.RECORD_WINDOW_DATA)
31+
state = get_active_window_state(include_window_data)
3032
if not state:
3133
return None
3234
title = state["title"]

scripts/downloads.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import requests
44
import matplotlib.pyplot as plt
5-
from datetime import datetime
5+
from datetime import datetime, timedelta
66
from pprint import pformat
77

88
import numpy as np
@@ -58,7 +58,8 @@ def fetch_download_data(api_url: str) -> dict:
5858
def plot_downloads(data: dict) -> None:
5959
"""Plots number of downloads and cumulative downloads over time using matplotlib.
6060
61-
Includes total cumulative in the title.
61+
Includes total cumulative in the title and annotates a specific event date with
62+
styled text.
6263
6364
Args:
6465
data (dict): A dictionary with dates as keys and download counts as values.
@@ -82,9 +83,28 @@ def plot_downloads(data: dict) -> None:
8283
color="r",
8384
label="Cumulative Downloads",
8485
)
86+
87+
# Annotation for the release download button addition
88+
event_date = datetime(2024, 5, 9, 2, 46) # Year, Month, Day, Hour, Minute
89+
plt.axvline(x=event_date, color="g", linestyle=":", label="Download Buttons Added")
90+
plt.annotate(
91+
"Download Buttons Added at\nwww.openadapt.ai",
92+
xy=(event_date, plt.ylim()[0] + 100),
93+
xytext=(
94+
event_date - timedelta(days=10),
95+
plt.ylim()[1] * 0.85,
96+
), # Shift left by 10 days
97+
horizontalalignment="center",
98+
fontsize=10,
99+
bbox=dict(
100+
boxstyle="round,pad=0.5", edgecolor="green", facecolor="#ffffcc", alpha=0.9
101+
),
102+
)
103+
85104
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
86105
plt.title(
87-
"Downloads Over Time"
106+
"github.com/OpenAdaptAI/OpenAdapt"
107+
"\nRelease Downloads Over Time"
88108
f"\n(Total Cumulative: {total_cumulative_downloads}) "
89109
f"\n{current_time}"
90110
)

0 commit comments

Comments
 (0)