From e33d10c13011ac1acef9060de8b2707e6b22773b Mon Sep 17 00:00:00 2001 From: tackhwa <55059307+tackhwa@users.noreply.github.com> Date: Sun, 23 Jun 2024 20:32:57 +0800 Subject: [PATCH 1/4] add new config for mobilenetv3 --- mmseg/configs/_base_/datasets/cityscapes.py | 79 +++++++++++++ mmseg/configs/_base_/models/lraspp_m_v3_d8.py | 43 +++++++ mmseg/configs/mobilenet_v3/README.md | 50 ++++++++ mmseg/configs/mobilenet_v3/metafile.yaml | 109 ++++++++++++++++++ ...d8_lraspp_4xb4_320k_cityscapes_512x1024.py | 20 ++++ ..._s_lraspp_4xb4_320k_cityscapes_512x1024.py | 30 +++++ ...ch_lraspp_4xb4_320k_cityscapes_512x1024.py | 16 +++ ..._s_lraspp_4xb4_320k_cityscapes_512x1024.py | 27 +++++ 8 files changed, 374 insertions(+) create mode 100644 mmseg/configs/_base_/datasets/cityscapes.py create mode 100644 mmseg/configs/_base_/models/lraspp_m_v3_d8.py create mode 100644 mmseg/configs/mobilenet_v3/README.md create mode 100644 mmseg/configs/mobilenet_v3/metafile.yaml create mode 100644 mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py create mode 100644 mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py create mode 100644 mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py create mode 100644 mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py diff --git a/mmseg/configs/_base_/datasets/cityscapes.py b/mmseg/configs/_base_/datasets/cityscapes.py new file mode 100644 index 0000000000..03ddc229a0 --- /dev/null +++ b/mmseg/configs/_base_/datasets/cityscapes.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.transforms.loading import LoadImageFromFile +from mmcv.transforms.processing import (RandomFlip, RandomResize, Resize, + TestTimeAug) +from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler + +from mmseg.datasets.cityscapes import CityscapesDataset +from mmseg.datasets.transforms.formatting import PackSegInputs +from mmseg.datasets.transforms.loading import LoadAnnotations +from mmseg.datasets.transforms.transforms import (PhotoMetricDistortion, + RandomCrop) +from mmseg.evaluation import IoUMetric + +# dataset settings +dataset_type = CityscapesDataset +data_root = 'data/cityscapes/' +crop_size = (512, 1024) +train_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=LoadAnnotations), + dict( + type=RandomResize, + scale=(2048, 1024), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type=RandomCrop, crop_size=crop_size, cat_max_ratio=0.75), + dict(type=RandomFlip, prob=0.5), + dict(type=PhotoMetricDistortion), + dict(type=PackSegInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile), + dict(type=Resize, scale=(2048, 1024), keep_ratio=True), + # add loading annotation after ``Resize`` because ground truth + # does not need to do resize data transform + dict(type=LoadAnnotations), + dict(type=PackSegInputs) +] +img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] +tta_pipeline = [ + dict(type=LoadImageFromFile, backend_args=None), + dict( + type=TestTimeAug, + transforms=[[ + dict(type=Resize, scale_factor=r, keep_ratio=True) + for r in img_ratios + ], + [ + dict(type=RandomFlip, prob=0., direction='horizontal'), + dict(type=RandomFlip, prob=1., direction='horizontal') + ], [dict(type=LoadAnnotations)], + [dict(type=PackSegInputs)]]) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=InfiniteSampler, shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='leftImg8bit/train', seg_map_path='gtFine/train'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict( + img_path='leftImg8bit/val', seg_map_path='gtFine/val'), + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type=IoUMetric, iou_metrics=['mIoU']) +test_evaluator = val_evaluator diff --git a/mmseg/configs/_base_/models/lraspp_m_v3_d8.py b/mmseg/configs/_base_/models/lraspp_m_v3_d8.py new file mode 100644 index 0000000000..22feb75a25 --- /dev/null +++ b/mmseg/configs/_base_/models/lraspp_m_v3_d8.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.nn.modules.activation import ReLU +from torch.nn.modules.batchnorm import SyncBatchNorm as SyncBN + +from mmseg.models.backbones import MobileNetV3 +from mmseg.models.data_preprocessor import SegDataPreProcessor +from mmseg.models.decode_heads import LRASPPHead +from mmseg.models.losses import CrossEntropyLoss +from mmseg.models.segmentors import EncoderDecoder + +# model settings +norm_cfg = dict(type=SyncBN, eps=0.001, requires_grad=True) +data_preprocessor = dict( + type=SegDataPreProcessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255) +model = dict( + type=EncoderDecoder, + data_preprocessor=data_preprocessor, + backbone=dict( + type=MobileNetV3, + arch='large', + out_indices=(1, 3, 16), + norm_cfg=norm_cfg), + decode_head=dict( + type=LRASPPHead, + in_channels=(16, 24, 960), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type=ReLU), + align_corners=False, + loss_decode=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/mmseg/configs/mobilenet_v3/README.md b/mmseg/configs/mobilenet_v3/README.md new file mode 100644 index 0000000000..8ed0a5692a --- /dev/null +++ b/mmseg/configs/mobilenet_v3/README.md @@ -0,0 +1,50 @@ +# MobileNetV3 + +> [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244) + +## Introduction + + + + + +Official Repo + +Code Snippet + +## Abstract + + + +We present the next generation of MobileNets based on a combination of complementary search techniques as well as a novel architecture design. MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware-aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances. This paper starts the exploration of how automated search algorithms and network design can work together to harness complementary approaches improving the overall state of the art. Through this process we create two new MobileNet models for release: MobileNetV3-Large and MobileNetV3-Small which are targeted for high and low resource use cases. These models are then adapted and applied to the tasks of object detection and semantic segmentation. For the task of semantic segmentation (or any dense pixel prediction), we propose a new efficient segmentation decoder Lite Reduced Atrous Spatial Pyramid Pooling (LR-ASPP). We achieve new state of the art results for mobile classification, detection and segmentation. MobileNetV3-Large is 3.2% more accurate on ImageNet classification while reducing latency by 15% compared to MobileNetV2. MobileNetV3-Small is 4.6% more accurate while reducing latency by 5% compared to MobileNetV2. MobileNetV3-Large detection is 25% faster at roughly the same accuracy as MobileNetV2 on COCO detection. MobileNetV3-Large LR-ASPP is 30% faster than MobileNetV2 R-ASPP at similar accuracy for Cityscapes segmentation. + + + +
+ +
+ +## Results and models + +### Cityscapes + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download | +| ------ | ------------------ | --------- | ------: | -------: | -------------- | ------ | ----: | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| LRASPP | M-V3-D8 | 512x1024 | 320000 | 8.9 | 15.22 | V100 | 69.54 | 70.89 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes_20201224_220337-cfe8fb07.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes-20201224_220337.log.json) | +| LRASPP | M-V3-D8 (scratch) | 512x1024 | 320000 | 8.9 | 14.77 | V100 | 67.87 | 69.78 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8-scratch_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes_20201224_220337-9f29cd72.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes-20201224_220337.log.json) | +| LRASPP | M-V3s-D8 | 512x1024 | 320000 | 5.3 | 23.64 | V100 | 64.11 | 66.42 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8-s_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes_20201224_223935-61565b34.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes-20201224_223935.log.json) | +| LRASPP | M-V3s-D8 (scratch) | 512x1024 | 320000 | 5.3 | 24.50 | V100 | 62.74 | 65.01 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8-scratch-s_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes_20201224_223935-03daeabb.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes-20201224_223935.log.json) | + +## Citation + +```bibtex +@inproceedings{Howard_2019_ICCV, + title={Searching for MobileNetV3}, + author={Howard, Andrew and Sandler, Mark and Chu, Grace and Chen, Liang-Chieh and Chen, Bo and Tan, Mingxing and Wang, Weijun and Zhu, Yukun and Pang, Ruoming and Vasudevan, Vijay and Le, Quoc V. and Adam, Hartwig}, + booktitle={The IEEE International Conference on Computer Vision (ICCV)}, + pages={1314-1324}, + month={October}, + year={2019}, + doi={10.1109/ICCV.2019.00140}} +} +``` diff --git a/mmseg/configs/mobilenet_v3/metafile.yaml b/mmseg/configs/mobilenet_v3/metafile.yaml new file mode 100644 index 0000000000..0351d3b8e4 --- /dev/null +++ b/mmseg/configs/mobilenet_v3/metafile.yaml @@ -0,0 +1,109 @@ +Collections: +- Name: LRASPP + License: Apache License 2.0 + Metadata: + Training Data: + - Cityscapes + Paper: + Title: Searching for MobileNetV3 + URL: https://arxiv.org/abs/1905.02244 + README: configs/mobilenet_v3/README.md + Frameworks: + - PyTorch +Models: +- Name: mobilenet-v3-d8_lraspp_4xb4-320k_cityscapes-512x1024 + In Collection: LRASPP + Results: + Task: Semantic Segmentation + Dataset: Cityscapes + Metrics: + mIoU: 69.54 + mIoU(ms+flip): 70.89 + Config: configs/mobilenet_v3/mobilenet-v3-d8_lraspp_4xb4-320k_cityscapes-512x1024.py + Metadata: + Training Data: Cityscapes + Batch Size: 16 + Architecture: + - M-V3-D8 + - LRASPP + Training Resources: 4x V100 GPUS + Memory (GB): 8.9 + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes_20201224_220337-cfe8fb07.pth + Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes-20201224_220337.log.json + Paper: + Title: Searching for MobileNetV3 + URL: https://arxiv.org/abs/1905.02244 + Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 + Framework: PyTorch +- Name: mobilenet-v3-d8-scratch_lraspp_4xb4-320k_cityscapes-512x1024 + In Collection: LRASPP + Results: + Task: Semantic Segmentation + Dataset: Cityscapes + Metrics: + mIoU: 67.87 + mIoU(ms+flip): 69.78 + Config: configs/mobilenet_v3/mobilenet-v3-d8-scratch_lraspp_4xb4-320k_cityscapes-512x1024.py + Metadata: + Training Data: Cityscapes + Batch Size: 16 + Architecture: + - M-V3-D8 + - LRASPP + Training Resources: 4x V100 GPUS + Memory (GB): 8.9 + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes_20201224_220337-9f29cd72.pth + Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes-20201224_220337.log.json + Paper: + Title: Searching for MobileNetV3 + URL: https://arxiv.org/abs/1905.02244 + Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 + Framework: PyTorch +- Name: mobilenet-v3-d8-s_lraspp_4xb4-320k_cityscapes-512x1024 + In Collection: LRASPP + Results: + Task: Semantic Segmentation + Dataset: Cityscapes + Metrics: + mIoU: 64.11 + mIoU(ms+flip): 66.42 + Config: configs/mobilenet_v3/mobilenet-v3-d8-s_lraspp_4xb4-320k_cityscapes-512x1024.py + Metadata: + Training Data: Cityscapes + Batch Size: 16 + Architecture: + - M-V3s-D8 + - LRASPP + Training Resources: 4x V100 GPUS + Memory (GB): 5.3 + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes_20201224_223935-61565b34.pth + Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes-20201224_223935.log.json + Paper: + Title: Searching for MobileNetV3 + URL: https://arxiv.org/abs/1905.02244 + Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 + Framework: PyTorch +- Name: mobilenet-v3-d8-scratch-s_lraspp_4xb4-320k_cityscapes-512x1024 + In Collection: LRASPP + Results: + Task: Semantic Segmentation + Dataset: Cityscapes + Metrics: + mIoU: 62.74 + mIoU(ms+flip): 65.01 + Config: configs/mobilenet_v3/mobilenet-v3-d8-scratch-s_lraspp_4xb4-320k_cityscapes-512x1024.py + Metadata: + Training Data: Cityscapes + Batch Size: 16 + Architecture: + - M-V3s-D8 + - LRASPP + Training Resources: 4x V100 GPUS + Memory (GB): 5.3 + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes_20201224_223935-03daeabb.pth + Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes-20201224_223935.log.json + Paper: + Title: Searching for MobileNetV3 + URL: https://arxiv.org/abs/1905.02244 + Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 + Framework: PyTorch diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py new file mode 100644 index 0000000000..dc0df8b30a --- /dev/null +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base +from mmengine.model.weight_init import PretrainedInit + +with read_base(): + from .._base_.datasets.cityscapes import * + from .._base_.default_runtime import * + from .._base_.models.lraspp_m_v3_d8 import * + from .._base_.schedules.schedule_320k import * + +checkpoint = 'open-mmlab://contrib/mobilenet_v3_large' +crop_size = (512, 1024) +data_preprocessor.update(size=crop_size) +model.update( + data_preprocessor=data_preprocessor, + backbone=dict(init_cfg=dict(type=PretrainedInit, checkpoint=checkpoint))) +# Re-config the data sampler. +train_dataloader.update(batch_size=4, num_workers=4) +val_dataloader.update(batch_size=1, num_workers=4) +test_dataloader = val_dataloader diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py new file mode 100644 index 0000000000..997a21658c --- /dev/null +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base +from mmengine.model.weight_init import PretrainedInit + +with read_base(): + from .mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024 import * + +checkpoint = 'open-mmlab://contrib/mobilenet_v3_small' +norm_cfg.update(type=SyncBN, eps=0.001, requires_grad=True) +model.update( + type=EncoderDecoder, + backbone=dict( + type=MobileNetV3, + init_cfg=dict(type=PretrainedInit, checkpoint=checkpoint), + arch='small', + out_indices=(0, 1, 12), + norm_cfg=norm_cfg), + decode_head=dict( + type=LRASPPHead, + in_channels=(16, 16, 576), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type=ReLU), + align_corners=False, + loss_decode=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0))) diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py new file mode 100644 index 0000000000..308b7ff83c --- /dev/null +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .._base_.datasets.cityscapes import * + from .._base_.default_runtime import * + from .._base_.models.lraspp_m_v3_d8 import * + from .._base_.schedules.schedule_320k import * + +crop_size = (512, 1024) +data_preprocessor = dict(size=crop_size) +# Re-config the data sampler. +model.update(data_preprocessor=data_preprocessor) +train_dataloader.update(batch_size=4, num_workers=4) +val_dataloader.update(batch_size=1, num_workers=4) +test_dataloader = val_dataloader diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py new file mode 100644 index 0000000000..9a98853638 --- /dev/null +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.config import read_base + +with read_base(): + from .mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024 import * + +norm_cfg.update(type=SyncBN, eps=0.001, requires_grad=True) +model.update( + type=EncoderDecoder, + backbone=dict( + type=MobileNetV3, + arch='small', + out_indices=(0, 1, 12), + norm_cfg=norm_cfg), + decode_head=dict( + type=LRASPPHead, + in_channels=(16, 16, 576), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type=ReLU), + align_corners=False, + loss_decode=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0))) From 6aec5627b0f9d863b1c68f122579332b935aa3e9 Mon Sep 17 00:00:00 2001 From: tackhwa <55059307+tackhwa@users.noreply.github.com> Date: Sun, 23 Jun 2024 20:36:16 +0800 Subject: [PATCH 2/4] delete readme --- mmseg/configs/mobilenet_v3/README.md | 50 ----------- mmseg/configs/mobilenet_v3/metafile.yaml | 109 ----------------------- 2 files changed, 159 deletions(-) delete mode 100644 mmseg/configs/mobilenet_v3/README.md delete mode 100644 mmseg/configs/mobilenet_v3/metafile.yaml diff --git a/mmseg/configs/mobilenet_v3/README.md b/mmseg/configs/mobilenet_v3/README.md deleted file mode 100644 index 8ed0a5692a..0000000000 --- a/mmseg/configs/mobilenet_v3/README.md +++ /dev/null @@ -1,50 +0,0 @@ -# MobileNetV3 - -> [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244) - -## Introduction - - - - - -Official Repo - -Code Snippet - -## Abstract - - - -We present the next generation of MobileNets based on a combination of complementary search techniques as well as a novel architecture design. MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware-aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances. This paper starts the exploration of how automated search algorithms and network design can work together to harness complementary approaches improving the overall state of the art. Through this process we create two new MobileNet models for release: MobileNetV3-Large and MobileNetV3-Small which are targeted for high and low resource use cases. These models are then adapted and applied to the tasks of object detection and semantic segmentation. For the task of semantic segmentation (or any dense pixel prediction), we propose a new efficient segmentation decoder Lite Reduced Atrous Spatial Pyramid Pooling (LR-ASPP). We achieve new state of the art results for mobile classification, detection and segmentation. MobileNetV3-Large is 3.2% more accurate on ImageNet classification while reducing latency by 15% compared to MobileNetV2. MobileNetV3-Small is 4.6% more accurate while reducing latency by 5% compared to MobileNetV2. MobileNetV3-Large detection is 25% faster at roughly the same accuracy as MobileNetV2 on COCO detection. MobileNetV3-Large LR-ASPP is 30% faster than MobileNetV2 R-ASPP at similar accuracy for Cityscapes segmentation. - - - -
- -
- -## Results and models - -### Cityscapes - -| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download | -| ------ | ------------------ | --------- | ------: | -------: | -------------- | ------ | ----: | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| LRASPP | M-V3-D8 | 512x1024 | 320000 | 8.9 | 15.22 | V100 | 69.54 | 70.89 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes_20201224_220337-cfe8fb07.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes-20201224_220337.log.json) | -| LRASPP | M-V3-D8 (scratch) | 512x1024 | 320000 | 8.9 | 14.77 | V100 | 67.87 | 69.78 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8-scratch_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes_20201224_220337-9f29cd72.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes-20201224_220337.log.json) | -| LRASPP | M-V3s-D8 | 512x1024 | 320000 | 5.3 | 23.64 | V100 | 64.11 | 66.42 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8-s_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes_20201224_223935-61565b34.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes-20201224_223935.log.json) | -| LRASPP | M-V3s-D8 (scratch) | 512x1024 | 320000 | 5.3 | 24.50 | V100 | 62.74 | 65.01 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/mobilenet_v3/mobilenet-v3-d8-scratch-s_lraspp_4xb4-320k_cityscapes-512x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes_20201224_223935-03daeabb.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes-20201224_223935.log.json) | - -## Citation - -```bibtex -@inproceedings{Howard_2019_ICCV, - title={Searching for MobileNetV3}, - author={Howard, Andrew and Sandler, Mark and Chu, Grace and Chen, Liang-Chieh and Chen, Bo and Tan, Mingxing and Wang, Weijun and Zhu, Yukun and Pang, Ruoming and Vasudevan, Vijay and Le, Quoc V. and Adam, Hartwig}, - booktitle={The IEEE International Conference on Computer Vision (ICCV)}, - pages={1314-1324}, - month={October}, - year={2019}, - doi={10.1109/ICCV.2019.00140}} -} -``` diff --git a/mmseg/configs/mobilenet_v3/metafile.yaml b/mmseg/configs/mobilenet_v3/metafile.yaml deleted file mode 100644 index 0351d3b8e4..0000000000 --- a/mmseg/configs/mobilenet_v3/metafile.yaml +++ /dev/null @@ -1,109 +0,0 @@ -Collections: -- Name: LRASPP - License: Apache License 2.0 - Metadata: - Training Data: - - Cityscapes - Paper: - Title: Searching for MobileNetV3 - URL: https://arxiv.org/abs/1905.02244 - README: configs/mobilenet_v3/README.md - Frameworks: - - PyTorch -Models: -- Name: mobilenet-v3-d8_lraspp_4xb4-320k_cityscapes-512x1024 - In Collection: LRASPP - Results: - Task: Semantic Segmentation - Dataset: Cityscapes - Metrics: - mIoU: 69.54 - mIoU(ms+flip): 70.89 - Config: configs/mobilenet_v3/mobilenet-v3-d8_lraspp_4xb4-320k_cityscapes-512x1024.py - Metadata: - Training Data: Cityscapes - Batch Size: 16 - Architecture: - - M-V3-D8 - - LRASPP - Training Resources: 4x V100 GPUS - Memory (GB): 8.9 - Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes_20201224_220337-cfe8fb07.pth - Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_512x1024_320k_cityscapes/lraspp_m-v3-d8_512x1024_320k_cityscapes-20201224_220337.log.json - Paper: - Title: Searching for MobileNetV3 - URL: https://arxiv.org/abs/1905.02244 - Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 - Framework: PyTorch -- Name: mobilenet-v3-d8-scratch_lraspp_4xb4-320k_cityscapes-512x1024 - In Collection: LRASPP - Results: - Task: Semantic Segmentation - Dataset: Cityscapes - Metrics: - mIoU: 67.87 - mIoU(ms+flip): 69.78 - Config: configs/mobilenet_v3/mobilenet-v3-d8-scratch_lraspp_4xb4-320k_cityscapes-512x1024.py - Metadata: - Training Data: Cityscapes - Batch Size: 16 - Architecture: - - M-V3-D8 - - LRASPP - Training Resources: 4x V100 GPUS - Memory (GB): 8.9 - Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes_20201224_220337-9f29cd72.pth - Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3-d8_scratch_512x1024_320k_cityscapes-20201224_220337.log.json - Paper: - Title: Searching for MobileNetV3 - URL: https://arxiv.org/abs/1905.02244 - Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 - Framework: PyTorch -- Name: mobilenet-v3-d8-s_lraspp_4xb4-320k_cityscapes-512x1024 - In Collection: LRASPP - Results: - Task: Semantic Segmentation - Dataset: Cityscapes - Metrics: - mIoU: 64.11 - mIoU(ms+flip): 66.42 - Config: configs/mobilenet_v3/mobilenet-v3-d8-s_lraspp_4xb4-320k_cityscapes-512x1024.py - Metadata: - Training Data: Cityscapes - Batch Size: 16 - Architecture: - - M-V3s-D8 - - LRASPP - Training Resources: 4x V100 GPUS - Memory (GB): 5.3 - Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes_20201224_223935-61565b34.pth - Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_512x1024_320k_cityscapes/lraspp_m-v3s-d8_512x1024_320k_cityscapes-20201224_223935.log.json - Paper: - Title: Searching for MobileNetV3 - URL: https://arxiv.org/abs/1905.02244 - Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 - Framework: PyTorch -- Name: mobilenet-v3-d8-scratch-s_lraspp_4xb4-320k_cityscapes-512x1024 - In Collection: LRASPP - Results: - Task: Semantic Segmentation - Dataset: Cityscapes - Metrics: - mIoU: 62.74 - mIoU(ms+flip): 65.01 - Config: configs/mobilenet_v3/mobilenet-v3-d8-scratch-s_lraspp_4xb4-320k_cityscapes-512x1024.py - Metadata: - Training Data: Cityscapes - Batch Size: 16 - Architecture: - - M-V3s-D8 - - LRASPP - Training Resources: 4x V100 GPUS - Memory (GB): 5.3 - Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes_20201224_223935-03daeabb.pth - Training log: https://download.openmmlab.com/mmsegmentation/v0.5/mobilenet_v3/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes/lraspp_m-v3s-d8_scratch_512x1024_320k_cityscapes-20201224_223935.log.json - Paper: - Title: Searching for MobileNetV3 - URL: https://arxiv.org/abs/1905.02244 - Code: https://github.com/open-mmlab/mmsegmentation/blob/v0.17.0/mmseg/models/backbones/mobilenet_v3.py#L15 - Framework: PyTorch From b92518e1f05d7e0ba1b6dfd37c56886411b831d8 Mon Sep 17 00:00:00 2001 From: tackhwa <55059307+tackhwa@users.noreply.github.com> Date: Tue, 25 Jun 2024 22:46:21 +0800 Subject: [PATCH 3/4] Update mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py --- .../mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py index 997a21658c..5d0d5e25db 100644 --- a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.config import read_base -from mmengine.model.weight_init import PretrainedInit with read_base(): from .mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024 import * From 9db35f4ce3a059dee81f8071be64013beaa00aa6 Mon Sep 17 00:00:00 2001 From: tackhwa <55059307+tackhwa@users.noreply.github.com> Date: Thu, 27 Jun 2024 02:54:25 +0800 Subject: [PATCH 4/4] update --- ...d8_lraspp_4xb4_320k_cityscapes_512x1024.py | 12 +++--- ..._s_lraspp_4xb4_320k_cityscapes_512x1024.py | 43 ++++++++++--------- ...ch_lraspp_4xb4_320k_cityscapes_512x1024.py | 6 +-- ..._s_lraspp_4xb4_320k_cityscapes_512x1024.py | 41 +++++++++--------- 4 files changed, 53 insertions(+), 49 deletions(-) diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py index dc0df8b30a..ca888a5d3c 100644 --- a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -10,11 +10,13 @@ checkpoint = 'open-mmlab://contrib/mobilenet_v3_large' crop_size = (512, 1024) -data_preprocessor.update(size=crop_size) +data_preprocessor.update(dict(size=crop_size)) model.update( - data_preprocessor=data_preprocessor, - backbone=dict(init_cfg=dict(type=PretrainedInit, checkpoint=checkpoint))) + dict( + data_preprocessor=data_preprocessor, + backbone=dict( + init_cfg=dict(type=PretrainedInit, checkpoint=checkpoint)))) # Re-config the data sampler. -train_dataloader.update(batch_size=4, num_workers=4) -val_dataloader.update(batch_size=1, num_workers=4) +train_dataloader.update(dict(batch_size=4, num_workers=4)) +val_dataloader.update(dict(batch_size=1, num_workers=4)) test_dataloader = val_dataloader diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py index 5d0d5e25db..c7054f9e4a 100644 --- a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_s_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -5,25 +5,26 @@ from .mobilenet_v3_d8_lraspp_4xb4_320k_cityscapes_512x1024 import * checkpoint = 'open-mmlab://contrib/mobilenet_v3_small' -norm_cfg.update(type=SyncBN, eps=0.001, requires_grad=True) +norm_cfg.update(dict(type=SyncBN, eps=0.001, requires_grad=True)) model.update( - type=EncoderDecoder, - backbone=dict( - type=MobileNetV3, - init_cfg=dict(type=PretrainedInit, checkpoint=checkpoint), - arch='small', - out_indices=(0, 1, 12), - norm_cfg=norm_cfg), - decode_head=dict( - type=LRASPPHead, - in_channels=(16, 16, 576), - in_index=(0, 1, 2), - channels=128, - input_transform='multiple_select', - dropout_ratio=0.1, - num_classes=19, - norm_cfg=norm_cfg, - act_cfg=dict(type=ReLU), - align_corners=False, - loss_decode=dict( - type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0))) + dict( + type=EncoderDecoder, + backbone=dict( + type=MobileNetV3, + init_cfg=dict(type=PretrainedInit, checkpoint=checkpoint), + arch='small', + out_indices=(0, 1, 12), + norm_cfg=norm_cfg), + decode_head=dict( + type=LRASPPHead, + in_channels=(16, 16, 576), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type=ReLU), + align_corners=False, + loss_decode=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0)))) diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py index 308b7ff83c..8e8c1a7cfc 100644 --- a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -10,7 +10,7 @@ crop_size = (512, 1024) data_preprocessor = dict(size=crop_size) # Re-config the data sampler. -model.update(data_preprocessor=data_preprocessor) -train_dataloader.update(batch_size=4, num_workers=4) -val_dataloader.update(batch_size=1, num_workers=4) +model.update(dict(data_preprocessor=data_preprocessor)) +train_dataloader.update(dict(batch_size=4, num_workers=4)) +val_dataloader.update(dict(batch_size=1, num_workers=4)) test_dataloader = val_dataloader diff --git a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py index 9a98853638..080400480b 100644 --- a/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py +++ b/mmseg/configs/mobilenet_v3/mobilenet_v3_d8_scratch_s_lraspp_4xb4_320k_cityscapes_512x1024.py @@ -4,24 +4,25 @@ with read_base(): from .mobilenet_v3_d8_scratch_lraspp_4xb4_320k_cityscapes_512x1024 import * -norm_cfg.update(type=SyncBN, eps=0.001, requires_grad=True) +norm_cfg.update(dict(type=SyncBN, eps=0.001, requires_grad=True)) model.update( - type=EncoderDecoder, - backbone=dict( - type=MobileNetV3, - arch='small', - out_indices=(0, 1, 12), - norm_cfg=norm_cfg), - decode_head=dict( - type=LRASPPHead, - in_channels=(16, 16, 576), - in_index=(0, 1, 2), - channels=128, - input_transform='multiple_select', - dropout_ratio=0.1, - num_classes=19, - norm_cfg=norm_cfg, - act_cfg=dict(type=ReLU), - align_corners=False, - loss_decode=dict( - type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0))) + dict( + type=EncoderDecoder, + backbone=dict( + type=MobileNetV3, + arch='small', + out_indices=(0, 1, 12), + norm_cfg=norm_cfg), + decode_head=dict( + type=LRASPPHead, + in_channels=(16, 16, 576), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type=ReLU), + align_corners=False, + loss_decode=dict( + type=CrossEntropyLoss, use_sigmoid=False, loss_weight=1.0))))