Skip to content

Instantiating ScribblePrompt-UNet #9

@chenyuanjiao342

Description

@chenyuanjiao342

Hi, when I try to instantiate ScribblePrompt-UNet and make predictions, the segmentation results I get are not accurate. Part of my code is as follows. Is there anything wrong? Thank you very much!

def create_scribble_tensor(positive_coords, negative_coords, H, W):
    scribbles = torch.zeros((1, 2, H, W), dtype=torch.float32)
    for coord in positive_coords:
        x, y = coord
        if 0 <= x < W and 0 <= y < H:
            scribbles[0, 0, y, x] = 1

    for coord in negative_coords:
        x, y = coord
        if 0 <= x < W and 0 <= y < H:
            scribbles[0, 1, y, x] = 1
    return scribbles

def binary_mask_to_polygon(binary_mask, tolerance=0):
    padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
    contours = measure.find_contours(padded_binary_mask, 0.5)
    return contours

def parse_coords(coords_str):
    coords = list(map(int, coords_str.split(',')))
    coords_list = [coords[i:i+2] for i in range(0, len(coords), 2)]
    return coords_list

def main(args: argparse.Namespace) -> None:
    targets = [args.input_dir]
    
    for index, t in enumerate(targets):
        name = os.path.basename(t)
        image = pydicom.dcmread(t)
        image_w = image.Rows
        image_h = image.Columns
        image = image.pixel_array
        image = (image - image.min()) / (image.max() - image.min())
        image = np.clip(image, 0, 1)
        image = torch.tensor(image, dtype=torch.float32)
        image = image.unsqueeze(0).unsqueeze(0)
        image = image.permute(0, 1, 3, 2)
        image = F.interpolate(image, size=(128,128), mode='bilinear')
        
        pos_coords = parse_coords(args.pos_box)
        neg_coords = parse_coords(args.neg_box)
        positive_coords = [tuple(point) for point in pos_coords]
        negative_coords = [tuple(point) for point in neg_coords]
        scribbles = create_scribble_tensor(positive_coords, negative_coords, image_h, image_w)
        scribbles = F.interpolate(scribbles, size=(128,128), mode='bilinear') 

        sp_unet = ScribblePromptUNet()
        mask = sp_unet.predict(image, None, None, scribbles, None, None)
        mask = F.interpolate(mask, size=(image_h, image_w), mode='bilinear').squeeze()   
        mask = mask.cpu().numpy()     

        binary_mask = (mask > 0.5).astype(int)
        contours = binary_mask_to_polygon(binary_mask)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions