
Commit 7808b8e

Mark-ZhouWX committed: update readme and inference demo
1 parent ff74566 commit 7808b8e

File tree

10 files changed: +70 -25 lines


official/cv/segment-anything/README.md

Lines changed: 53 additions & 9 deletions
@@ -30,12 +30,12 @@ Beside fine-tuning our code on COCO2017 dataset which contains common seen objec
 The following shows the mask quality before and after fine-tuning.


-| pretrained_model | dataset | epochs | mIOU |
-|:----------------:| -------- |:-------------:|------|
-| sam-vit-b | COCO2017 | 0 (zero-shot) | 77.4 |
-| sam-vit-b | COCO2017 | 20 | 83.5 |
-| sam-vit-b | FLARE22 | 0 (zero-shot) | 79.5 |
-| sam-vit-b | FLARE22 | 10 | 88.1 |
+| pretrained_model | dataset | epochs | mIOU | ckpt |
+|:----------------:| -------- |:-------------:|------|------|
+| sam-vit-b | COCO2017 | 0 (zero-shot) | 74.5 | |
+| sam-vit-b | COCO2017 | 20 | 80.2 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_coco-a9b75828.ckpt) |
+| sam-vit-b | FLARE22 | 0 (zero-shot) | 78.6 | |
+| sam-vit-b | FLARE22 | 10 | 87.4 | [link](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_box_finetune_flare-ace06cc2.ckpt) |

 A machine with **32G Ascend memory** is required for box-prompt fine-tuning.
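For context, the mIOU column above is mean intersection-over-union between predicted and ground-truth masks. A minimal NumPy sketch of an assumed formulation follows; the repository's actual evaluation code may differ.

```python
import numpy as np

def mean_iou(pred_masks, gt_masks, eps=1e-6):
    """Mean IoU over (prediction, ground-truth) boolean mask pairs.
    Assumed formulation for illustration, not necessarily the repo's metric code."""
    ious = []
    for pred, gt in zip(pred_masks, gt_masks):
        intersection = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        ious.append(intersection / (union + eps))
    return float(np.mean(ious))
```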

@@ -82,6 +82,38 @@ Here are the examples of segmentation result predicted by box-prompt fine-tuned
 <em> FLARE22 image example </em>
 </p>

+### Finetune with point-prompt
+The point, together with the mask output from the previous step, is used as the prompt input to predict a mask.
+We follow the iterative interactive training schedule described in the official SAM paper. First, a foreground point is sampled uniformly from the ground-truth mask. After making a prediction,
+subsequent points are selected uniformly from the error region between the previous mask prediction and the ground-truth mask. Each new point is labeled foreground or background depending on whether the error region is a false negative or a false positive, as sketched below.
+The mask prediction from the previous iteration is used as an additional prompt. To encourage the model to benefit from the supplied mask, several more iterations are used in which no additional points are sampled.
+The total number of iterations and the positions where the mask-only iterations are inserted are configurable.
+
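To make the sampling rule above concrete, here is a minimal NumPy sketch. The helper name `sample_corrective_point` and the boolean-mask convention are assumptions for illustration, not the repository's actual training code.

```python
import numpy as np

def sample_corrective_point(pred_mask, gt_mask, rng=None):
    """Sample one corrective point uniformly from the error region between a
    predicted mask and the ground-truth mask (both boolean HxW arrays).
    Returns (point_xy, label): label 1 (foreground) in false-negative areas,
    label 0 (background) in false-positive areas, or (None, None) when the
    prediction is already perfect. Hypothetical sketch, not the repo's code."""
    rng = rng or np.random.default_rng()
    false_neg = gt_mask & ~pred_mask   # missed foreground -> positive click
    false_pos = pred_mask & ~gt_mask   # spurious foreground -> negative click
    ys, xs = np.nonzero(false_neg | false_pos)
    if len(ys) == 0:
        return None, None
    i = rng.integers(len(ys))
    y, x = ys[i], xs[i]
    label = 1 if false_neg[y, x] else 0
    return np.array([x, y]), label
```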
+Since the original training dataset (SA-1B) consists almost entirely of common objects, we use a medical imaging segmentation dataset, [FLARE22](https://flare22.grand-challenge.org/Dataset/) (preprocess the raw dataset as described in the previous chapter), for the fine-tuning experiment.
+We note that the SAM model exhibits strong zero-shot ability, so for most downstream datasets the fine-tuning process may mainly learn the labelling bias.
+
+For standalone fine-tuning on the FLARE22 dataset, please run:
+```shell
+python train.py -c configs/sa1b_point_finetune.yaml
+```
+
+For distributed fine-tuning on the FLARE22 dataset, please run:
+```shell
+mpirun --allow-run-as-root -n 4 python train.py -c configs/sa1b_point_finetune.yaml
+```
+
+The fine-tuned model will be saved at the work_root specified in `configs/sa1b_point_finetune.yaml`. For fast single-image inference, please run:
+
+```shell
+python point_inference.py --checkpoint=your/path/to/ckpt
+```
+
+Below is an experimental result batch-prompted with 5 points; the model is trained at scale `vit_b`. The checkpoint can be downloaded [here](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_point_finetune_flare-898ae8f6.ckpt).
+<div align="center">
+<img alt="img.png" src="images/tumor2_5point.png" width="600"/>
+</div>
+
+Explore more interesting applications, such as iterative positive- and negative-point prompting, described in the Demo chapter below.

 ### Finetune with text-prompt
 *Note again that text-to-mask finetune is exploratory and not robust, and the official pytorch code is not released yet.*
@@ -111,14 +143,26 @@ mpirun --allow-run-as-root -n 8 python train.py -c configs/sa1b_text_finetune_bl
 The fine-tuned model will be saved at the work_root specified in `configs/sa1b_text_finetune.yaml`. For fast single-image inference, please run:

 ```shell
-python text_inference.py --checkpoint=your/path/to/ckpt
+python text_inference.py --checkpoint=your/path/to/ckpt --text-prompt your_prompt
 ```

-Below is an experimental result prompted with `wheels`. _Note that the model is trained with limited data and the smallest SAM type `vit_b`._
+Below are some zero-shot experimental results prompted with `floor` and `buildings`. The checkpoint can be downloaded [here](https://download-mindspore.osinfra.cn/toolkits/mindone/sam/sam_vitb_text_finetune_sa1b_10k-972de39e.ckpt). _Note that the model is trained with limited data and the smallest SAM type `vit_b`._
+
 <div align="center">
-<img alt="img.png" src="images/blip2-text-prompt-wheel.png" width="600"/>
+<img src="images/dengta-floor.png" height="350" />
+
+<img src="images/dengta-buildings.png" height="350" />
 </div>

+<p align="center">
+<em> prompt: floor</em>
+
+
+<em> prompt: buildings </em>
+</p>
+
+Try more prompts like `sky` or `trees`, etc.
+
 ## Demo

 First download the weights ([sam_vit_b](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_b-35e4849c.ckpt), [sam_vit_l](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_l-1b460f38.ckpt), [sam_vit_h](https://download.mindspore.cn/toolkits/mindone/sam/sam_vit_h-c72f8ba1.ckpt)) and put them under `${project_root}/models` directory.

official/cv/segment-anything/configs/cloud/sa1b_text_finetune_blip2.yaml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ optimizer:
   group_param:

 lr_scheduler:
-  type: segment_anything.optim.scheduler.SAMDynamicDecayLR
+  type: segment_anything.optim.scheduler.sam_dynamic_decay_lr
   learning_rate: 8e-6
   warmup_steps: 250
   decay_steps: [ 60000, 86666 ]
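For intuition, here is a minimal sketch of a warmup-then-step-decay schedule consistent with the fields above (`learning_rate`, `warmup_steps`, `decay_steps`). The decay factor and exact shape are assumptions; the actual `sam_dynamic_decay_lr` in `segment_anything.optim.scheduler` may differ.

```python
def warmup_step_decay_lr(step, base_lr=8e-6, warmup_steps=250,
                         decay_steps=(60000, 86666), decay_factor=0.1):
    """Illustrative schedule: linear warmup, then multiply the learning rate
    by decay_factor at each boundary in decay_steps. decay_factor is an
    assumption; the repo's sam_dynamic_decay_lr may use different values."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    passed = sum(step >= boundary for boundary in decay_steps)
    return base_lr * decay_factor ** passed
```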

official/cv/segment-anything/configs/cloud/sa1b_text_finetune_clip.yaml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ optimizer:
   group_param:

 lr_scheduler:
-  type: segment_anything.optim.scheduler.SAMDynamicDecayLR
+  type: segment_anything.optim.scheduler.sam_dynamic_decay_lr
   learning_rate: 8e-6
   warmup_steps: 250
   decay_steps: [ 60000, 86666 ]

official/cv/segment-anything/demo/inference_with_prompts.py

Lines changed: 11 additions & 10 deletions
@@ -56,13 +56,13 @@ def main(args: argparse.Namespace):

 def predict_with_point(predictor, image, args: argparse.Namespace):
     # predict the first point
-    input_point = np.array([[500, 375]])
-    input_label = np.array([1])
+    input_point1 = np.array([[500, 375]])
+    input_label1 = np.array([1])

     s1 = time.time()
     masks, scores, logits = predictor.predict(
-        point_coords=input_point,
-        point_labels=input_label,
+        point_coords=input_point1,
+        point_labels=input_label1,
         multimask_output=True,
     )
     s2 = time.time()
@@ -73,7 +73,7 @@ def predict_with_point(predictor, image, args: argparse.Namespace):
     plt.figure(figsize=(10, 10))
     plt.imshow(image)
     show_mask(mask, plt.gca())
-    show_points(input_point, input_label, plt.gca())
+    show_points(input_point1, input_label1, plt.gca())
     plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
     plt.axis('off')
     path = os.path.join(args.output_dir, f'mask_{i+1}.jpg')
@@ -83,15 +83,15 @@ def predict_with_point(predictor, image, args: argparse.Namespace):
     plt.show()

     # predict the second and third points
-    input_point = np.array([[500, 375], [1125, 625]])
-    input_label = np.array([1, 0])
+    input_point2 = np.array([[500, 375], [1125, 625]])
+    input_label2 = np.array([1, 0])

     mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
     print(f'mask input shape {mask_input.shape}')
     s3 = time.time()
     masks, _, _ = predictor.predict(
-        point_coords=input_point,
-        point_labels=input_label,
+        point_coords=input_point2,
+        point_labels=input_label2,
         mask_input=mask_input[None, :, :],
         multimask_output=False,
     )
@@ -101,7 +101,8 @@ def predict_with_point(predictor, image, args: argparse.Namespace):
     plt.figure(figsize=(10, 10))
     plt.imshow(image)
     show_mask(masks, plt.gca())
-    show_points(input_point, input_label, plt.gca())
+    show_points(input_point1, input_label1, plt.gca())
+    show_points(input_point2, input_label2, plt.gca())
     plt.axis('off')
     path = os.path.join(args.output_dir, f'two_point.jpg')
     print(f'saving mask at {path}')
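Condensing the diff above, here is a minimal sketch of the two-stage flow it implements, assuming a segment-anything-style `SamPredictor` with `set_image`/`predict`. The wrapper function and surrounding setup are assumptions; the prompt arguments are taken from the diff.

```python
import numpy as np

def two_stage_point_prompt(predictor, image):
    """Sketch of predict_with_point above: a first foreground click, then
    refinement with an added background click plus the best mask logits."""
    predictor.set_image(image)  # HxWx3 uint8 RGB image
    # Stage 1: one foreground point, return multiple candidate masks.
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True,
    )
    # Stage 2: add a background point and feed back the best low-res mask.
    best_mask_logits = logits[np.argmax(scores)]
    masks, _, _ = predictor.predict(
        point_coords=np.array([[500, 375], [1125, 625]]),
        point_labels=np.array([1, 0]),
        mask_input=best_mask_logits[None, :, :],
        multimask_output=False,
    )
    return masks
```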
4 binary image files changed (350 KB, 347 KB, 338 KB, 251 KB)

official/cv/segment-anything/point_inference.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def infer(args):
     parser.add_argument(
         "--checkpoint",
         type=str,
-        default='./models/sam_vit_b-35e4849c.ckpt',
+        default='./models/sam_vitb_point_finetune_flare-898ae8f6.ckpt',
         help="Path to the checkpoint file.",
     )

official/cv/segment-anything/text_inference.py

Lines changed: 3 additions & 3 deletions
@@ -83,7 +83,7 @@ def infer(args):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description=("Runs inference on one image"))
-    parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
+    parser.add_argument("--image_path", type=str, default='./images/dengta.jpg', help="Path to an input image.")
     parser.add_argument(
         "--model-type",
         type=str,
@@ -100,14 +100,14 @@ def infer(args):
     parser.add_argument(
         "--text-prompt",
         type=str,
-        default='wheels',
+        default='floor',
         help="Text prompt",
     )

     parser.add_argument(
         "--checkpoint",
         type=str,
-        default='./models/sam_vit_b-35e4849c.ckpt',
+        default='./models/sam_vitb_text_finetune_sa1b_10k-972de39e.ckpt',
         help="Path to the checkpoint file.",
     )