-
Notifications
You must be signed in to change notification settings - Fork 20
Open
Description
Hi, when I try to instantiate ScribblePrompt-UNet and make predictions, the segmentation results I get are not accurate. Part of my code is as follows. Is there anything wrong? Thank you very much!
def create_scribble_tensor(positive_coords, negative_coords, H, W):
scribbles = torch.zeros((1, 2, H, W), dtype=torch.float32)
for coord in positive_coords:
x, y = coord
if 0 <= x < W and 0 <= y < H:
scribbles[0, 0, y, x] = 1
for coord in negative_coords:
x, y = coord
if 0 <= x < W and 0 <= y < H:
scribbles[0, 1, y, x] = 1
return scribbles
def binary_mask_to_polygon(binary_mask, tolerance=0):
padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
contours = measure.find_contours(padded_binary_mask, 0.5)
return contours
def parse_coords(coords_str):
coords = list(map(int, coords_str.split(',')))
coords_list = [coords[i:i+2] for i in range(0, len(coords), 2)]
return coords_list
def main(args: argparse.Namespace) -> None:
targets = [args.input_dir]
for index, t in enumerate(targets):
name = os.path.basename(t)
image = pydicom.dcmread(t)
image_w = image.Rows
image_h = image.Columns
image = image.pixel_array
image = (image - image.min()) / (image.max() - image.min())
image = np.clip(image, 0, 1)
image = torch.tensor(image, dtype=torch.float32)
image = image.unsqueeze(0).unsqueeze(0)
image = image.permute(0, 1, 3, 2)
image = F.interpolate(image, size=(128,128), mode='bilinear')
pos_coords = parse_coords(args.pos_box)
neg_coords = parse_coords(args.neg_box)
positive_coords = [tuple(point) for point in pos_coords]
negative_coords = [tuple(point) for point in neg_coords]
scribbles = create_scribble_tensor(positive_coords, negative_coords, image_h, image_w)
scribbles = F.interpolate(scribbles, size=(128,128), mode='bilinear')
sp_unet = ScribblePromptUNet()
mask = sp_unet.predict(image, None, None, scribbles, None, None)
mask = F.interpolate(mask, size=(image_h, image_w), mode='bilinear').squeeze()
mask = mask.cpu().numpy()
binary_mask = (mask > 0.5).astype(int)
contours = binary_mask_to_polygon(binary_mask)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels