diff --git a/.gitignore b/.gitignore
index df0976d7e..8dce41f55 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,7 +3,7 @@ __pycache__/
*.py[cod]
*$py.class
**/*.pyc
-
+*.out
# C extensions
*.so
diff --git a/configs/selfsup/_base_/datasets/coco_orl_stage1.py b/configs/selfsup/_base_/datasets/coco_orl_stage1.py
new file mode 100644
index 000000000..6f90e77f4
--- /dev/null
+++ b/configs/selfsup/_base_/datasets/coco_orl_stage1.py
@@ -0,0 +1,57 @@
+import copy
+
+# dataset settings
+dataset_type = 'mmdet.CocoDataset'
+# data_root = 'data/coco/'
+data_root = '../data/coco/'
+file_client_args = dict(backend='disk')
+view_pipeline = [
+ dict(
+ type='RandomResizedCrop',
+ size=224,
+ interpolation='bicubic',
+ backend='pillow'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='RandomApply',
+ transforms=[
+ dict(
+ type='ColorJitter',
+ brightness=0.4,
+ contrast=0.4,
+ saturation=0.2,
+ hue=0.1)
+ ],
+ prob=0.8),
+ dict(
+ type='RandomGrayscale',
+ prob=0.2,
+ keep_channels=True,
+ channel_weights=(0.114, 0.587, 0.2989)),
+ dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1),
+ dict(type='RandomSolarize', prob=0)
+]
+view_pipeline1 = copy.deepcopy(view_pipeline)
+view_pipeline2 = copy.deepcopy(view_pipeline)
+view_pipeline2[4]['prob'] = 0.1 # gaussian blur
+view_pipeline2[5]['prob'] = 0.2 # solarization
+train_pipeline = [
+ dict(type='LoadImageFromFile', file_client_args=file_client_args),
+ dict(
+ type='MultiView',
+ num_views=[1, 1],
+ transforms=[view_pipeline1, view_pipeline2]),
+ dict(type='PackSelfSupInputs', meta_keys=['img_path'])
+]
+train_dataloader = dict(
+ batch_size=64,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ collate_fn=dict(type='default_collate'),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ ann_file='annotations/instances_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=train_pipeline))
diff --git a/configs/selfsup/_base_/datasets/coco_orl_stage3.py b/configs/selfsup/_base_/datasets/coco_orl_stage3.py
new file mode 100644
index 000000000..b51015f61
--- /dev/null
+++ b/configs/selfsup/_base_/datasets/coco_orl_stage3.py
@@ -0,0 +1,89 @@
+import copy
+
+# dataset settings
+dataset_type = 'ORLDataset'
+meta_json = '../data/coco/meta/train2017_10nn_instance_correspondence.json'
+data_train_root = '../data/coco/train2017'
+# file_client_args = dict(backend='disk')
+view_pipeline = [
+ dict(
+ type='RandomResizedCrop',
+ size=224,
+ interpolation='bicubic',
+ backend='pillow'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='RandomApply',
+ transforms=[
+ dict(
+ type='ColorJitter',
+ brightness=0.4,
+ contrast=0.4,
+ saturation=0.2,
+ hue=0.1)
+ ],
+ prob=0.8),
+ dict(
+ type='RandomGrayscale',
+ prob=0.2,
+ keep_channels=True,
+ channel_weights=(0.114, 0.587, 0.2989),
+ color_format='rgb'),
+ dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1.),
+ dict(type='RandomSolarize', prob=0)
+]
+
+view_patch_pipeline = [
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='RandomApply',
+ transforms=[
+ dict(
+ type='ColorJitter',
+ brightness=0.4,
+ contrast=0.4,
+ saturation=0.2,
+ hue=0.1)
+ ],
+ prob=0.8),
+ dict(
+ type='RandomGrayscale',
+ prob=0.2,
+ keep_channels=True,
+ channel_weights=(0.114, 0.587, 0.2989),
+ color_format='rgb'),
+ dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1.),
+ dict(type='RandomSolarize', prob=0)
+]
+view_pipeline1 = copy.deepcopy(view_pipeline)
+view_pipeline2 = copy.deepcopy(view_pipeline)
+view_patch_pipeline1 = copy.deepcopy(view_patch_pipeline)
+view_patch_pipeline2 = copy.deepcopy(view_patch_pipeline)
+view_pipeline2[4]['prob'] = 0.1 # gaussian blur
+view_pipeline2[5]['prob'] = 0.2 # solarization
+view_patch_pipeline2[3]['prob'] = 0.1 # gaussian blur
+view_patch_pipeline2[4]['prob'] = 0.2 # solarization
+
+train_dataloader = dict(
+ batch_size=64,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ collate_fn=dict(type='default_collate'),
+ dataset=dict(
+ type=dataset_type,
+ root=data_train_root,
+ json_file=meta_json,
+ topk_knn_image=10,
+ img_pipeline1=view_pipeline1,
+ img_pipeline2=view_pipeline2,
+ patch_pipeline1=view_patch_pipeline1,
+ patch_pipeline2=view_patch_pipeline2,
+ patch_size=96,
+ interpolation=2,
+ shift=(-0.5, 0.5),
+ scale=(0.5, 2.),
+ ratio=(0.5, 2.),
+ iou_thr=0.5,
+ attempt_num=200,
+ ))
diff --git a/configs/selfsup/_base_/models/orl.py b/configs/selfsup/_base_/models/orl.py
new file mode 100644
index 000000000..2dccdd71f
--- /dev/null
+++ b/configs/selfsup/_base_/models/orl.py
@@ -0,0 +1,41 @@
+# model settings
+model = dict(
+ type='ORL',
+ base_momentum=0.99,
+ data_preprocessor=dict(
+ mean=(0.485, 0.456, 0.406),
+ std=(0.229, 0.224, 0.225),
+ # mean=(123.675, 116.28, 103.53),
+ # std=(58.395, 57.12, 57.375),
+ bgr_to_rgb=False),
+ pretrained=None,
+ global_loss_weight=1.,
+ loc_intra_loss_weight=1.,
+ loc_inter_loss_weight=1.,
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ in_channels=3,
+ out_indices=[4], # 0: conv-1, x: stage-x
+ norm_cfg=dict(type='SyncBN')),
+ neck=dict(
+ type='NonLinearNeck',
+ in_channels=2048,
+ hid_channels=4096,
+ out_channels=256,
+ num_layers=2,
+ with_bias=False,
+ with_last_bn=False,
+ with_avg_pool=True),
+ head=dict(
+ type='LatentPredictHead',
+ predictor=dict(
+ type='NonLinearNeck',
+ in_channels=256,
+ hid_channels=4096,
+ out_channels=256,
+ num_layers=2,
+ with_bias=False,
+ with_last_bn=False,
+ with_avg_pool=False),
+ loss=dict(type='CosineSimilarityLoss')))
diff --git a/configs/selfsup/orl/README.md b/configs/selfsup/orl/README.md
new file mode 100644
index 000000000..ebb5b2e03
--- /dev/null
+++ b/configs/selfsup/orl/README.md
@@ -0,0 +1,141 @@
+# ORL
+
+> [Unsupervised Object-Level Representation Learning from Scene Images](https://arxiv.org/abs/2106.11952)
+
+
+
+## Abstract
+
+Contrastive self-supervised learning has largely narrowed the gap to supervised pre-training on ImageNet. However, its success highly relies on the object-centric priors of ImageNet, i.e., different augmented views of the same image correspond to the same object. Such a heavily curated constraint becomes immediately infeasible when pre-trained on more complex scene images with many objects. To overcome this limitation, we introduce Object-level Representation Learning (ORL), a new self-supervised learning framework towards scene images. Our key insight is to leverage image-level self-supervised pre-training as the prior to discover object-level semantic correspondence, thus realizing object-level representation learning from scene images. Extensive experiments on COCO show that ORL significantly improves the performance of self-supervised learning on scene images, even surpassing supervised ImageNet pre-training on several downstream tasks. Furthermore, ORL improves the downstream performance when more unlabeled scene images are available, demonstrating its great potential of harnessing unlabeled data in the wild. We hope our approach can motivate future research on more general-purpose unsupervised representation learning from scene data.
+
+
+

+
+
+## Usage
+
+ORL is mainly composed of three stages. In Stage 1, we pre-train an image-level contrastive learning model, e.g., BYOL. In Stage 2, we first use the pre-trained model to retrieve KNNs for each image in the embedding space to obtain image-level visually similar pairs. We then use unsupervised region proposal algorithms (e.g., selective search) to generate rough RoIs for each image pair. Afterwards, we reuse the pre-trained model to retrieve the top-ranked RoI pairs, i.e., the correspondence. We find these RoI pairs mostly cover objects or object parts. In Stage 3, with the corresponding RoI pairs discovered across images, we finally perform object-level contrastive learning using the same architecture as Stage 1.
+
+### Stage 1: Image-level pre-training
+
+In Stage 1, ORL pre-trains an image-level contrastive learning model. At the end of pre-training, it extracts the features of all images in the training set and retrieves the KNNs of each image in the embedding space to obtain image-level visually similar pairs.
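+
+The retrieval step is a cosine-similarity top-k search over the extracted embeddings. Below is a minimal sketch of the idea (the actual logic lives in `ORLHook.retrieve_knn`; the random tensor is a stand-in for the extracted features):
+
+```python
+import torch
+import torch.nn.functional as F
+
+feats = torch.randn(1000, 256)     # stand-in for the embeddings of all images
+feats = F.normalize(feats, dim=1)  # l2-normalize: dot product = cosine similarity
+similarity = torch.mm(feats, feats.T)  # NxN similarity (batched in practice)
+# take the (k+1) nearest neighbors and drop the first one (the image itself)
+knn_ids = torch.topk(similarity, k=10 + 1, dim=1).indices[:, 1:]
+```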
+
+```shell
+# Train with multiple GPUs
+bash tools/dist_train.sh \
+configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py \
+${GPUS} \
+--work-dir work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/
+```
+
+or
+
+```shell
+# Train on a cluster managed with slurm
+GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=${GPUS} CPUS_PER_TASK=${CPUS_PER_TASK} \
+bash tools/slurm_train.sh ${PARTITION} ${JOB_NAME} \
+configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py \
+--work-dir work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/
+```
+
+The corresponding KNN image ids will be saved as a json file `train2017_10nn_instance.json` under `../data/coco/meta/`.
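+
+For reference, the saved file has the following layout (field names taken from `ORLHook.retrieve_knn`; the concrete values are illustrative):
+
+```python
+{
+    'info': {'knn_image_num': 10},
+    'images': {
+        'file_name': ['000000391895.jpg', ...],  # all train2017 file names
+        'id': [391895, ...],
+    },
+    'pseudo_annotations': {
+        'image_id': [391895, ...],
+        'knn_image_id': [[23, 881, ...], ...],  # indices of the 10 nearest neighbors
+    },
+}
+```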
+
+### Stage 2: Correspondence discovery
+
+- **RoI generation**
+
+ORL applies selective search to generate region proposals for all images in the training set:
+
+```shell
+# Generate proposals with a single GPU
+bash tools/dist_selective_search_single_gpu.sh \
+configs/selfsup/orl/stage2/selective_search.py \
+../data/coco/meta/train2017_selective_search_proposal.json \
+--work-dir work_dirs/selfsup/orl/stage2/selective_search
+```
+
+or
+
+```shell
+# Generate proposals on a cluster managed with slurm
+GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=1 CPUS_PER_TASK=${CPUS_PER_TASK} \
+bash tools/slurm_selective_search_single_gpu.sh ${PARTITION} \
+configs/selfsup/orl/stage2/selective_search.py \
+../data/coco/meta/train2017_selective_search_proposal.json \
+--work-dir work_dirs/selfsup/orl/stage2/selective_search
+```
+
+The script and config only support single-image, single-GPU inference, since different images can have different numbers of region proposals generated by selective search, which cannot be gathered across multiple GPUs. You can also skip this step by directly downloading the generated proposals [here](https://drive.google.com/drive/folders/1yYsyGiDjjVSOzIUkhxwO_NitUPLC-An_?usp=sharing).
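+
+Under the hood this is plain OpenCV selective search (see `selective_search` in `mmselfsup/datasets/orl_dataset.py`). A minimal sketch of what `SSDataset` runs per image, assuming `opencv-contrib-python` is installed and `demo.jpg` is a stand-in path:
+
+```python
+import cv2
+
+img = cv2.imread('demo.jpg')  # BGR image, as SSDataset converts PIL -> BGR
+ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
+ss.setBaseImage(img)
+ss.switchToSelectiveSearchFast()  # the 'fast' mode used by the config
+boxes = ss.process()              # Nx4 array of (x, y, w, h) proposals
+```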
+
+- **RoI pair retrieval**
+
+ORL reuses the model pre-trained in stage 1 to retrieve the top-ranked RoI pairs, i.e., correspondence.
+
+```shell
+# Retrieve correspondence with a single GPU
+bash tools/dist_generate_correspondence_single_gpu.sh \
+configs/selfsup/orl/stage2/correspondence.py \
+work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/epoch_800.pth \
+../data/coco/meta/train2017_10nn_instance.json \
+../data/coco/meta/train2017_10nn_instance_correspondence.json \
+--work-dir work_dirs/selfsup/orl/stage2/correspondence
+```
+
+or
+
+```shell
+# Retrieve correspondence on a cluster managed with slurm
+GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=1 CPUS_PER_TASK=${CPUS_PER_TASK} \
+bash tools/slurm_generate_correspondence_single_gpu.sh ${PARTITION} \
+configs/selfsup/orl/stage2/correspondence.py \
+work_dirs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco/epoch_800.pth \
+../data/coco/meta/train2017_10nn_instance.json \
+../data/coco/meta/train2017_10nn_instance_correspondence.json \
+--work-dir work_dirs/selfsup/orl/stage2/correspondence
+```
+
+The script and config also only support single-image, single-GPU inference, since different image pairs can have different numbers of inter-RoI pairs, which cannot be gathered across multiple GPUs. The final correspondence json file `train2017_10nn_instance_correspondence.json` will be saved under `../data/coco/meta/`.
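+
+For reference, the correspondence file mirrors what `COCOORLJson` in `mmselfsup/datasets/orl_dataset.py` reads back in Stage 3 (field names from the code; concrete values are illustrative):
+
+```python
+{
+    'info': {'knn_image_num': 10},
+    'images': {'file_name': ['000000391895.jpg', ...]},
+    'pseudo_annotations': {
+        'bbox': [[[9, 20, 140, 96], ...], ...],  # intra-image RoIs, (x, y, w, h)
+        'knn_image_id': [[23, 881, ...], ...],   # KNN images per image
+        # top-ranked RoI pairs per (image, knn image): each entry packs the
+        # query box and the key box as 8 numbers, shape NxKx(topk_bbox_num)x8
+        'knn_bbox_pair': [...],
+    },
+}
+```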
+
+### Stage 3: Object-level pre-training
+
+After obtaining the correspondence file in Stage 2, ORL then performs object-level pre-training:
+
+```shell
+# Train with multiple GPUs
+bash tools/dist_train.sh \
+configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py \
+${GPUS} \
+--work-dir work_dirs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco/
+```
+
+or
+
+```shell
+# Train on a cluster managed with slurm
+GPUS_PER_NODE=${GPUS_PER_NODE} GPUS=${GPUS} CPUS_PER_TASK=${CPUS_PER_TASK} \
+bash tools/slurm_train.sh ${PARTITION} ${JOB_NAME} \
+configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py \
+--work-dir work_dirs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco/
+```
+
+## Models and Benchmarks
+
+Here, we report the low-shot image classification results on VOC07 of the model pre-trained on COCO train2017. For each k (the number of labeled examples per class), we report the mAP averaged across five runs:
+
+| Self-Supervised Config | Best Layer | Weight | k=1 | k=2 | k=4 | k=8 | k=16 | k=32 | k=64 | k=96 |
+| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
+| [stage3/orl_resnet50_8xb64-coslr-800e_coco](https://github.com/zhaozh10/mmselfsup/blob/2b14f8b06e4ba2596e90f19e4bac0c13757d80f7/configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py) | feature5 | [Pre-trained](https://drive.google.com/drive/folders/1oWzNZpoN_SPc56Gr-l3AlgGSv8jG1izG?usp=sharing) | 42.25 | 51.81 | 63.46 | 72.16 | 77.86 | 81.17 | 83.73 | 84.59 |
+
+## Citation
+
+```bibtex
+@inproceedings{xie2021unsupervised,
+ title={Unsupervised Object-Level Representation Learning from Scene Images},
+ author={Xie, Jiahao and Zhan, Xiaohang and Liu, Ziwei and Ong, Yew Soon and Loy, Chen Change},
+ booktitle={NeurIPS},
+ year={2021}
+}
+```
diff --git a/configs/selfsup/orl/metafile.yml b/configs/selfsup/orl/metafile.yml
new file mode 100644
index 000000000..da2b64422
--- /dev/null
+++ b/configs/selfsup/orl/metafile.yml
@@ -0,0 +1,26 @@
+Collections:
+ - Name: ORL
+ Metadata:
+    Training Data: COCO train2017
+ Training Techniques:
+ - SGD
+ Training Resources: 8x RTX3090 GPUs
+ Architecture:
+ - ResNet50
+ Paper:
+ URL: https://arxiv.org/abs/2106.11952
+ Title: "Unsupervised Object-Level Representation Learning from Scene Images"
+  README: configs/selfsup/orl/README.md
+Models:
+ - Name: orl_resnet50_8xb64-coslr-800e_coco
+ In Collection: ORL
+ Metadata:
+ Epochs: 800
+ Batch Size: 512
+ Results:
+ - Task: Self-Supervised Low-shot Image Classification
+ Dataset: VOC07
+ Metrics:
+ mAP: 42.25|51.81|63.46|72.16|77.86|81.17|83.73|84.59
+    Config: configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py
+ Weights: https://drive.google.com/drive/folders/1oWzNZpoN_SPc56Gr-l3AlgGSv8jG1izG?usp=sharing
diff --git a/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py b/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py
new file mode 100644
index 000000000..d990c8258
--- /dev/null
+++ b/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py
@@ -0,0 +1,82 @@
+_base_ = [
+ '../../_base_/models/byol.py',
+ '../../_base_/datasets/coco_orl_stage1.py',
+ '../../_base_/schedules/sgd_coslr-200e_in1k.py',
+ '../../_base_/default_runtime.py',
+]
+# model settings
+model = dict(
+ neck=dict(
+ type='NonLinearNeck',
+ in_channels=2048,
+ hid_channels=4096,
+ out_channels=256,
+ num_layers=2,
+ with_bias=False,
+ with_last_bn=False,
+ with_avg_pool=True),
+ head=dict(
+ type='LatentPredictHead',
+ predictor=dict(
+ type='NonLinearNeck',
+ in_channels=256,
+ hid_channels=4096,
+ out_channels=256,
+ num_layers=2,
+ with_bias=False,
+ with_last_bn=False,
+ with_avg_pool=False),
+ loss=dict(type='CosineSimilarityLoss')))
+
+update_interval = 1 # interval for gradient accumulation
+# Amp optimizer
+optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9)
+optim_wrapper = dict(
+ type='AmpOptimWrapper',
+ optimizer=optimizer,
+ accumulative_counts=update_interval,
+)
+# running setting
+warmup_epochs = 4
+total_epochs = 800
+Knearest = 10
+# learning policy
+param_scheduler = [
+ # warmup
+ dict(
+ type='LinearLR',
+ start_factor=0.0001,
+ by_epoch=True,
+ end=warmup_epochs,
+        # update the learning rate after every iteration
+ convert_to_iter_based=True),
+    # CosineAnnealingLR/StepLR/...
+ dict(
+ type='CosineAnnealingLR',
+ eta_min=0.,
+ T_max=total_epochs,
+ by_epoch=True,
+ begin=warmup_epochs,
+ end=total_epochs)
+]
+
+default_hooks = dict(checkpoint=dict(interval=100))
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs)
+custom_hooks = [
+ dict(
+ type='ORLHook',
+ keys=Knearest,
+ extract_dataloader=dict(
+ batch_size=512,
+ num_workers=4,
+ persistent_workers=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=True),
+ collate_fn=dict(type='default_collate'),
+ dataset=dict(
+ type={{_base_.dataset_type}},
+ data_root={{_base_.data_root}},
+ ann_file='annotations/instances_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline={{_base_.train_pipeline}})),
+ normalize=True),
+]
diff --git a/configs/selfsup/orl/stage2/correspondence.py b/configs/selfsup/orl/stage2/correspondence.py
new file mode 100644
index 000000000..12a81cd3d
--- /dev/null
+++ b/configs/selfsup/orl/stage2/correspondence.py
@@ -0,0 +1,99 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+dist_params = dict(backend='nccl', port=29500)
+# model settings
+model = dict(
+ type='Correspondence',
+ base_momentum=0.99,
+ pretrained=None,
+ knn_image_num=10,
+ topk_bbox_ratio=0.1,
+ # data_preprocessor=dict(
+ # mean=(123.675, 116.28, 103.53),
+ # std=(58.395, 57.12, 57.375),
+ # bgr_to_rgb=True),
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ in_channels=3,
+ out_indices=[4], # 0: conv-1, x: stage-x
+ norm_cfg=dict(type='SyncBN')),
+ neck=dict(
+ type='NonLinearNeck',
+ in_channels=2048,
+ hid_channels=4096,
+ out_channels=256,
+ num_layers=2,
+ with_bias=False,
+ with_last_bn=False,
+ with_avg_pool=True),
+ head=dict(
+ type='LatentPredictHead',
+ predictor=dict(
+ type='NonLinearNeck',
+ in_channels=256,
+ hid_channels=4096,
+ out_channels=256,
+ num_layers=2,
+ with_bias=False,
+ with_last_bn=False,
+ with_avg_pool=False),
+ loss=dict(type='CosineSimilarityLoss')),
+)
+# dataset settings
+train_knn_json = '../data/coco/meta/train2017_10nn_instance.json'
+train_ss_json = '../data/coco/meta/train2017_selective_search_proposal.json'
+data_train_root = '../data/coco/train2017'
+dataset_type = 'CorrespondDataset'
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+dataset_dict = dict(
+ type=dataset_type,
+ knn_json_file=train_knn_json,
+ ss_json_file=train_ss_json,
+ root=data_train_root,
+ part=0, # [0, num_parts)
+ num_parts=1, # process the whole dataset
+ data_len=118287,
+ # **data_source_cfg),
+ norm_cfg=img_norm_cfg,
+ patch_size=224,
+ min_size=96,
+ max_ratio=3,
+ max_iou_thr=0.5,
+ topN=100,
+ knn_image_num=10,
+ topk_bbox_ratio=0.1)
+val_dataloader = dict(
+ # support single-image single-gpu inference only
+ batch_size=1,
+ num_workers=0,
+ persistent_workers=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ collate_fn=dict(type='default_collate'),
+ dataset=dataset_dict)
+
+# additional hooks
+update_interval = 1 # interval for gradient accumulation
+custom_hooks = [
+ dict(type='BYOLHook', end_momentum=1., update_interval=update_interval)
+]
+# Amp optimizer
+optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9)
+optim_wrapper = dict(
+ type='AmpOptimWrapper',
+ optimizer=optimizer,
+ accumulative_counts=update_interval,
+)
+# learning policy
+lr_config = dict(
+ policy='CosineAnnealing',
+ min_lr=0.,
+ warmup='linear',
+ warmup_iters=4,
+ warmup_ratio=0.0001, # cannot be 0
+ warmup_by_epoch=True)
+checkpoint_config = dict(interval=10)
+# runtime settings
+total_epochs = 800
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs)
diff --git a/configs/selfsup/orl/stage2/selective_search.py b/configs/selfsup/orl/stage2/selective_search.py
new file mode 100644
index 000000000..7611b13a9
--- /dev/null
+++ b/configs/selfsup/orl/stage2/selective_search.py
@@ -0,0 +1,25 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+]
+
+# model settings
+model = dict(type='SelectiveSearch')
+dist_params = dict(backend='nccl', port=29500)
+# dataset settings
+data_train_json = '../data/coco/annotations/instances_train2017.json'
+data_train_root = '../data/coco/train2017'
+dataset_type = 'SSDataset'
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=8,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ collate_fn=dict(type='default_collate'),
+ dataset=dict(
+ type=dataset_type,
+ root=data_train_root,
+ json_file=data_train_json,
+ method='fast',
+ min_size=None,
+ max_ratio=None,
+ topN=None))
diff --git a/configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py b/configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py
new file mode 100644
index 000000000..2d1bb65c8
--- /dev/null
+++ b/configs/selfsup/orl/stage3/orl_resnet50_8xb64-coslr-800e_coco.py
@@ -0,0 +1,39 @@
+_base_ = [
+ '../../_base_/models/orl.py',
+ '../../_base_/datasets/coco_orl_stage3.py',
+ '../../_base_/schedules/sgd_coslr-200e_in1k.py',
+ '../../_base_/default_runtime.py',
+]
+update_interval = 1 # interval for gradient accumulation
+# Amp optimizer
+optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9)
+optim_wrapper = dict(
+ type='AmpOptimWrapper',
+ optimizer=optimizer,
+ accumulative_counts=update_interval,
+)
+warmup_epochs = 4
+total_epochs = 800
+# learning policy
+param_scheduler = [
+ # warmup
+ dict(
+ type='LinearLR',
+ start_factor=0.0001,
+ by_epoch=True,
+ end=warmup_epochs,
+        # update the learning rate after every iteration
+ convert_to_iter_based=True),
+    # CosineAnnealingLR/StepLR/...
+ dict(
+ type='CosineAnnealingLR',
+ eta_min=0.,
+ T_max=total_epochs,
+ by_epoch=True,
+ begin=warmup_epochs,
+ end=total_epochs)
+]
+
+# runtime settings
+default_hooks = dict(checkpoint=dict(interval=100))
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs)
diff --git a/mmselfsup/datasets/__init__.py b/mmselfsup/datasets/__init__.py
index 1cc27fe04..422f2e36e 100644
--- a/mmselfsup/datasets/__init__.py
+++ b/mmselfsup/datasets/__init__.py
@@ -2,11 +2,12 @@
from .builder import DATASETS, build_dataset
from .deepcluster_dataset import DeepClusterImageNet
from .image_list_dataset import ImageList
+from .orl_dataset import CorrespondDataset, ORLDataset, SSDataset
from .places205 import Places205
from .samplers import * # noqa: F401,F403
from .transforms import * # noqa: F401,F403
__all__ = [
'DATASETS', 'build_dataset', 'Places205', 'DeepClusterImageNet',
- 'ImageList'
+ 'ImageList', 'SSDataset', 'CorrespondDataset', 'ORLDataset'
]
diff --git a/mmselfsup/datasets/orl_dataset.py b/mmselfsup/datasets/orl_dataset.py
new file mode 100644
index 000000000..925987279
--- /dev/null
+++ b/mmselfsup/datasets/orl_dataset.py
@@ -0,0 +1,541 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import os
+import random
+
+import cv2
+import mmengine
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+from mmengine import build_from_cfg
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+from mmselfsup.registry import DATASETS, TRANSFORMS
+
+
+def get_max_iou(pred_boxes: np.ndarray, gt_box: np.ndarray) -> np.float32:
+    """
+    pred_boxes : coordinates of multiple predicted bounding boxes,
+        an (N, 4) array in (x, y, w, h) format
+    gt_box : coordinates of the ground-truth bounding box (x, y, w, h)
+    return : the max IoU between pred_boxes and gt_box
+    """
+ # 1.get the coordinate of inters
+ ixmin = np.maximum(pred_boxes[:, 0], gt_box[0])
+ ixmax = np.minimum(pred_boxes[:, 0] + pred_boxes[:, 2],
+ gt_box[0] + gt_box[2])
+ iymin = np.maximum(pred_boxes[:, 1], gt_box[1])
+ iymax = np.minimum(pred_boxes[:, 1] + pred_boxes[:, 3],
+ gt_box[1] + gt_box[3])
+
+ iw = np.maximum(ixmax - ixmin, 0.)
+ ih = np.maximum(iymax - iymin, 0.)
+
+ # 2. calculate the area of inters
+ inters = iw * ih
+
+ # 3. calculate the area of union
+ uni = (
+ pred_boxes[:, 2] * pred_boxes[:, 3] + gt_box[2] * gt_box[3] - inters)
+
+ # 4. calculate the overlaps and find the max overlap
+ # between pred_boxes and gt_box
+ iou = inters / uni
+ iou_max = np.max(iou)
+
+ return iou_max
+
+
+def correspondence_box_filter(boxes: list,
+                              min_size=20,
+                              max_ratio=None,
+                              topN=None,
+                              max_iou_thr=None) -> list:
+    """Filter bbox proposals for correspondence discovery.
+
+    boxes : candidate bounding boxes in (x, y, w, h) format
+    min_size : filter bboxes smaller than min_size
+    max_ratio : filter bboxes with too large height/width ratio
+    topN : keep only the top-ranked proposals
+    max_iou_thr : filter proposals overlapping an already kept proposal
+        by more than this IoU
+    return : the remaining proposals
+    """
+ proposal = []
+
+ for box in boxes:
+ # Calculate width and height of the box
+ w, h = box[2], box[3]
+
+ # Filter for size
+ if min_size:
+ if w < min_size or h < min_size:
+ continue
+
+ # Filter for box ratio
+ if max_ratio:
+ if w / h > max_ratio or h / w > max_ratio:
+ continue
+
+ # Filter for overlap
+ if max_iou_thr:
+ if len(proposal):
+ iou_max = get_max_iou(np.array(proposal), np.array(box))
+ if iou_max > max_iou_thr:
+ continue
+
+ proposal.append(box)
+
+ if not len(proposal): # ensure at least one box for each image
+ proposal.append(boxes[0])
+
+    if topN:
+        if topN <= len(proposal):
+            return proposal[:topN]
+        else:
+            return proposal
+    else:
+        return proposal
+
+
+def selective_search(image, method='fast') -> list:
+    """
+    image : BGR image as a numpy array (OpenCV format)
+    method : the mode of selective search, default: fast
+    return : bboxes generated by the selective search algorithm
+    """
+ # initialize OpenCV's selective search implementation
+ ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
+ # set the input image
+ ss.setBaseImage(image)
+ # check to see if we are using the *fast* but *less accurate* version
+ # of selective search
+ if method == 'fast':
+ # print("[INFO] using *fast* selective search")
+ ss.switchToSelectiveSearchFast()
+ # otherwise we are using the *slower* but *more accurate* version
+ else:
+ # print("[INFO] using *quality* selective search")
+ ss.switchToSelectiveSearchQuality()
+ # run selective search on the input image
+ boxes = ss.process()
+ return boxes
+
+
+def box_filter(boxes: list, min_size=None, max_ratio=None, topN=None) -> list:
+ """Filter bboxes in stage 2: selective search."""
+ proposal = []
+
+ for box in boxes:
+ # Calculate width and height of the box
+ w, h = box[2], box[3]
+
+ # Filter for size
+ if min_size:
+ if w < min_size or h < min_size:
+ continue
+
+ # Filter for box ratio
+ if max_ratio:
+ if w / h > max_ratio or h / w > max_ratio:
+ continue
+
+ proposal.append(box)
+
+ if topN:
+ if topN <= len(proposal):
+ return proposal[:topN]
+ else:
+ return proposal
+ else:
+ return proposal
+
+
+def get_iou(pred_box: list, gt_box: list) -> float:
+    """
+    pred_box : coordinates of the predicted bounding box (x, y, w, h)
+    gt_box : coordinates of the ground-truth bounding box (x, y, w, h)
+    return : the IoU score
+    """
+ # 1.get the coordinate of inters
+ ixmin = max(pred_box[0], gt_box[0])
+ ixmax = min(pred_box[0] + pred_box[2], gt_box[0] + gt_box[2])
+ iymin = max(pred_box[1], gt_box[1])
+ iymax = min(pred_box[1] + pred_box[3], gt_box[1] + gt_box[3])
+
+ iw = max(ixmax - ixmin, 0.)
+ ih = max(iymax - iymin, 0.)
+
+ # 2. calculate the area of inters
+ inters = iw * ih
+
+ # 3. calculate the area of union
+ uni = (pred_box[2] * pred_box[3] + gt_box[2] * gt_box[3] - inters)
+
+ # 4. calculate the overlaps between pred_box and gt_box
+ iou = inters / float(uni)
+
+ return iou
+
+
+def aug_bbox(img,
+ box: list,
+ shift: tuple,
+ scale: tuple,
+ ratio: tuple,
+ iou_thr: float,
+ attempt_num=200) -> list:
+ """
+ img : PIL.Image object
+ box : the chosen bbox
+ shift :
+ """
+ img_w, img_h = img.size
+ x, y, w, h = box[0], box[1], box[2], box[3]
+ cx, cy = (x + 0.5 * w), (y + 0.5 * h)
+ area = w * h
+ for attempt in range(attempt_num):
+ aug_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aug_ratio = math.exp(random.uniform(*log_ratio))
+ aug_w = int(round(math.sqrt(aug_area * aug_ratio)))
+ aug_h = int(round(math.sqrt(aug_area / aug_ratio)))
+ aug_cx = cx + random.uniform(*shift) * w
+ aug_cy = cy + random.uniform(*shift) * h
+ aug_x, aug_y = int(round(aug_cx - 0.5 * aug_w)), int(
+ round(aug_cy - 0.5 * aug_h))
+ if aug_x >= 0 and aug_y >= 0 and (aug_x + aug_w) <= img_w and (
+ aug_y + aug_h) <= img_h:
+ aug_box = [aug_x, aug_y, aug_w, aug_h]
+ if iou_thr is not None:
+ iou = get_iou(aug_box, box)
+ if iou > iou_thr:
+ return aug_box
+ else:
+ return aug_box
+ return box
+
+
+@DATASETS.register_module()
+class SSDataset(Dataset):
+ """Dataset for generating selective search proposals."""
+
+ def __init__(self,
+ root: str,
+ json_file: str,
+ method='fast',
+ min_size=None,
+ max_ratio=None,
+ topN=None):
+ data = mmengine.load(json_file)
+ self.fns = [item['file_name'] for item in data['images']]
+ self.fns = [os.path.join(root, fn) for fn in self.fns]
+ self.initialized = False
+ self.method = method
+ self.min_size = min_size
+ self.max_ratio = max_ratio
+ self.topN = topN
+
+ def __len__(self):
+ return len(self.fns)
+
+ def __getitem__(self, idx: int) -> dict:
+ img = Image.open(self.fns[idx])
+ img = img.convert('RGB')
+ img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ boxes = selective_search(img_cv2, self.method)
+ if self.topN is not None:
+ boxes = box_filter(boxes, self.min_size, self.max_ratio, self.topN)
+ boxes = torch.from_numpy(np.array(boxes))
+ # bbox: Bx4
+ # B is the total number of original/topN selective search bboxes
+ return dict(bbox=boxes)
+
+
+class CorrespondJson(object):
+ """Load knn_images, proposals and corresponding bboxes based on idx."""
+
+ def __init__(
+ self,
+ root: str,
+ knn_json_file: str,
+ ss_json_file: str,
+ knn_image_num: int,
+ part=0,
+ num_parts=1,
+ data_len=118287,
+ ):
+ assert part in np.arange(num_parts).tolist(), \
+ 'part order must be within [0, num_parts)'
+
+ print('loading knn json file...')
+ data = mmengine.load(knn_json_file)
+ print('loaded knn json file!')
+ print('loading selective search json file, this may take minutes...')
+ if isinstance(ss_json_file, list):
+ data_ss_list = [mmengine.load(ss) for ss in ss_json_file]
+ self.bboxes = []
+ for ss in data_ss_list:
+ self.bboxes += ss['bbox']
+ else:
+ data_ss = mmengine.load(ss_json_file)
+ self.bboxes = data_ss['bbox']
+ print('loaded selective search json file!')
+ # divide the whole dataset into several parts
+ # to enable parallel roi pair retrieval.
+ # each part should be run on single gpu
+ # and all parts can be run on multiple gpus in parallel.
+ part_len = int(data_len / num_parts)
+ print('processing part {}...'.format(part))
+        if part == num_parts - 1: # last part
+            self.fns = data['images']['file_name']
+            self.fns = [os.path.join(root, fn) for fn in self.fns]
+            self.part_fns = self.fns[part * part_len:]
+            self.part_labels = data['pseudo_annotations']['knn_image_id'][
+                part * part_len:]
+            self.part_bboxes = self.bboxes[part * part_len:]
+ else:
+ self.fns = data['images']['file_name']
+ self.fns = [os.path.join(root, fn) for fn in self.fns]
+ self.part_fns = self.fns[part * part_len:(part + 1) * part_len]
+ self.part_labels = data['pseudo_annotations']['knn_image_id'][
+ part * part_len:(part + 1) * part_len]
+ self.part_bboxes = self.bboxes[part * part_len:(part + 1) *
+ part_len]
+ self.knn_image_num = knn_image_num
+
+ def get_length(self):
+ return len(self.part_fns)
+
+ def get_sample(self, idx: int):
+ img = Image.open(self.part_fns[idx])
+ img = img.convert('RGB')
+ # load knn images
+ target = self.part_labels[idx][:self.knn_image_num]
+ knn_imgs = [Image.open(self.fns[i]) for i in target]
+ knn_imgs = [knn_img.convert('RGB') for knn_img in knn_imgs]
+ # load selective search proposals
+ bbox = self.part_bboxes[idx]
+ knn_bboxes = [self.bboxes[i] for i in target]
+ return img, knn_imgs, bbox, knn_bboxes
+
+
+@DATASETS.register_module()
+class CorrespondDataset(Dataset):
+ """Dataset for generating corresponding intra- and inter-RoIs."""
+
+ def __init__(
+ self,
+ root: str,
+ knn_json_file: str,
+ ss_json_file: str,
+ part=0,
+ num_parts=1,
+ data_len=118287,
+ norm_cfg=None,
+ patch_size=224,
+ min_size=96,
+ max_ratio=3,
+ topN=100,
+ max_iou_thr=0.5,
+ knn_image_num=10,
+ topk_bbox_ratio=0.1,
+ ):
+
+ self.data_source = CorrespondJson(root, knn_json_file, ss_json_file,
+ knn_image_num, part, num_parts,
+ data_len)
+ self.format_pipeline = Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(norm_cfg['mean'], norm_cfg['std'])
+ ])
+ self.patch_size = patch_size
+ self.min_size = min_size
+ self.max_ratio = max_ratio
+ self.topN = topN
+ self.max_iou_thr = max_iou_thr
+ self.knn_image_num = knn_image_num
+ self.topk_bbox_ratio = topk_bbox_ratio
+
+ def __len__(self):
+ return self.data_source.get_length()
+
+ def __getitem__(self, idx: int) -> dict:
+ img, knn_imgs, box, knn_boxes = self.data_source.get_sample(idx)
+        filtered_box = correspondence_box_filter(
+            box, self.min_size, self.max_ratio, self.topN, self.max_iou_thr)
+        filtered_knn_boxes = [
+            correspondence_box_filter(knn_box, self.min_size, self.max_ratio,
+                                      self.topN, self.max_iou_thr)
+            for knn_box in knn_boxes
+        ]
+ patch_list = []
+ for x, y, w, h in filtered_box:
+ patch = TF.resized_crop(img, y, x, h, w,
+ (self.patch_size, self.patch_size))
+ patch = self.format_pipeline(patch)
+ patch_list.append(patch)
+ knn_patch_lists = []
+ for k in range(len(knn_imgs)):
+ knn_patch_list = []
+ for x, y, w, h in filtered_knn_boxes[k]:
+ patch = TF.resized_crop(knn_imgs[k], y, x, h, w,
+ (self.patch_size, self.patch_size))
+ patch = self.format_pipeline(patch)
+ knn_patch_list.append(patch)
+ knn_patch_lists.append(torch.stack(knn_patch_list))
+
+ filtered_box = torch.from_numpy(np.array(filtered_box))
+ filtered_knn_boxes = [
+ torch.from_numpy(np.array(knn_box))
+ for knn_box in filtered_knn_boxes
+ ]
+ knn_img_keys = ['{}nn_img'.format(k) for k in range(len(knn_imgs))]
+ knn_bbox_keys = ['{}nn_bbox'.format(k) for k in range(len(knn_imgs))]
+ # img: BCHW, knn_img: K BCHW, bbox: Bx4, knn_bbox= K Bx4
+ # K is the number of knn images, B is the number of filtered bboxes
+ dict1 = dict(img=torch.stack(patch_list))
+ dict2 = dict(bbox=filtered_box)
+ dict3 = dict(img_keys=dict(zip(knn_img_keys, knn_patch_lists)))
+ dict4 = dict(bbox_keys=dict(zip(knn_bbox_keys, filtered_knn_boxes)))
+ return {**dict1, **dict2, **dict3, **dict4}
+
+
+class COCOORLJson(object):
+    """Load images, intra-RoIs and inter-RoI pairs from the correspondence
+    json file generated in Stage 2."""
+
+ def __init__(self, root: str, json_file: str, topk_knn_image: int):
+ data = mmengine.load(json_file)
+ self.fns = data['images']['file_name']
+ self.intra_bboxes = data['pseudo_annotations']['bbox']
+ self.total_knn_image_num = data['info']['knn_image_num']
+ self.knn_image_ids = data['pseudo_annotations']['knn_image_id']
+ self.knn_bbox_pairs = data['pseudo_annotations'][
+ 'knn_bbox_pair'] # NxKx(topk_bbox_num)x8
+ self.fns = [os.path.join(root, fn) for fn in self.fns]
+ self.topk_knn_image = topk_knn_image
+ assert self.topk_knn_image <= self.total_knn_image_num, \
+ 'Top-k knn image number exceeds total number of knn images!'
+
+ def get_length(self):
+ return len(self.fns)
+
+ def get_sample(self, idx):
+ # randomly select one knn image
+ rnd = random.randint(0, self.topk_knn_image - 1)
+ target_id = self.knn_image_ids[idx][rnd]
+ img = Image.open(self.fns[idx])
+ knn_img = Image.open(self.fns[target_id])
+ img = img.convert('RGB')
+ knn_img = knn_img.convert('RGB')
+ # load proposals
+ intra_bbox = self.intra_bboxes[idx]
+ knn_bbox = self.knn_bbox_pairs[idx][rnd] # (topk_bbox_num)x8
+ return img, knn_img, intra_bbox, knn_bbox
+
+
+@DATASETS.register_module()
+class ORLDataset(Dataset):
+ """Dataset for ORL."""
+
+ def __init__(self,
+ root: str,
+ json_file: str,
+ topk_knn_image: int,
+ img_pipeline1: list,
+ img_pipeline2: list,
+ patch_pipeline1: list,
+ patch_pipeline2: list,
+ patch_size=224,
+ interpolation: int = 2,
+ shift=(-0.5, 0.5),
+ scale=(0.5, 2.),
+ ratio=(0.5, 2.),
+ iou_thr=0.5,
+ attempt_num=200,
+ prefetch=False):
+ self.data_source = COCOORLJson(root, json_file, topk_knn_image)
+ self.format_pipeline = Compose([
+ transforms.ToTensor(),
+ # transforms.Normalize(img_norm_cfg['mean'], img_norm_cfg['std'])
+ ])
+ img_pipeline1 = [build_from_cfg(p, TRANSFORMS) for p in img_pipeline1]
+ img_pipeline2 = [build_from_cfg(p, TRANSFORMS) for p in img_pipeline2]
+ patch_pipeline1 = [
+ build_from_cfg(p, TRANSFORMS) for p in patch_pipeline1
+ ]
+ patch_pipeline2 = [
+ build_from_cfg(p, TRANSFORMS) for p in patch_pipeline2
+ ]
+ self.img_pipeline1 = Compose(img_pipeline1)
+ self.img_pipeline2 = Compose(img_pipeline2)
+ self.patch_pipeline1 = Compose(patch_pipeline1)
+ self.patch_pipeline2 = Compose(patch_pipeline2)
+ self.patch_size = patch_size
+ self.interpolation = interpolation
+ self.shift = shift
+ self.scale = scale
+ self.ratio = ratio
+ self.iou_thr = iou_thr
+ self.attempt_num = attempt_num
+ self.prefetch = prefetch
+
+ def __len__(self):
+ return self.data_source.get_length()
+
+    def __getitem__(self, idx: int) -> dict:
+ img, knn_img, intra_box, knn_box = self.data_source.get_sample(idx)
+ ibox1 = random.choice(intra_box)
+ ibox2 = aug_bbox(img, ibox1, self.shift, self.scale, self.ratio,
+ self.iou_thr, self.attempt_num)
+ kbox_pair = random.choice(knn_box)
+ kbox1, kbox2 = kbox_pair[:4], kbox_pair[4:]
+ ipatch1 = TF.resized_crop(
+ img,
+ ibox1[1],
+ ibox1[0],
+ ibox1[3],
+ ibox1[2], (self.patch_size, self.patch_size),
+ interpolation=self.interpolation)
+ ipatch2 = TF.resized_crop(
+ img,
+ ibox2[1],
+ ibox2[0],
+ ibox2[3],
+ ibox2[2], (self.patch_size, self.patch_size),
+ interpolation=self.interpolation)
+ kpatch1 = TF.resized_crop(
+ img,
+ kbox1[1],
+ kbox1[0],
+ kbox1[3],
+ kbox1[2], (self.patch_size, self.patch_size),
+ interpolation=self.interpolation)
+ kpatch2 = TF.resized_crop(
+ knn_img,
+ kbox2[1],
+ kbox2[0],
+ kbox2[3],
+ kbox2[2], (self.patch_size, self.patch_size),
+ interpolation=self.interpolation)
+ img1 = self.img_pipeline1({'img': np.array(img)})
+ img2 = self.img_pipeline2({'img': np.array(img)})
+ ipatch1 = self.patch_pipeline1({'img': np.array(ipatch1)})
+ ipatch2 = self.patch_pipeline2({'img': np.array(ipatch2)})
+ kpatch1 = self.patch_pipeline1({'img': np.array(kpatch1)})
+ kpatch2 = self.patch_pipeline2({'img': np.array(kpatch2)})
+ img1 = self.format_pipeline(img1['img'])
+ img2 = self.format_pipeline(img2['img'])
+ ipatch1 = self.format_pipeline(ipatch1['img'])
+ ipatch2 = self.format_pipeline(ipatch2['img'])
+ kpatch1 = self.format_pipeline(kpatch1['img'])
+ kpatch2 = self.format_pipeline(kpatch2['img'])
+
+ assert img1.shape[0] == 3
+ img_cat = torch.cat((img1.unsqueeze(0), img2.unsqueeze(0)), dim=0)
+ ipatch_cat = torch.cat((ipatch1.unsqueeze(0), ipatch2.unsqueeze(0)),
+ dim=0)
+ kpatch_cat = torch.cat((kpatch1.unsqueeze(0), kpatch2.unsqueeze(0)),
+ dim=0)
+ return dict(img=[img_cat, ipatch_cat, kpatch_cat], sample_idx=idx)
diff --git a/mmselfsup/engine/hooks/__init__.py b/mmselfsup/engine/hooks/__init__.py
index 147d254f5..82c82092d 100644
--- a/mmselfsup/engine/hooks/__init__.py
+++ b/mmselfsup/engine/hooks/__init__.py
@@ -2,9 +2,11 @@
from .deepcluster_hook import DeepClusterHook
from .densecl_hook import DenseCLHook
from .odc_hook import ODCHook
+from .orl_hook import ORLHook
from .simsiam_hook import SimSiamHook
from .swav_hook import SwAVHook
__all__ = [
- 'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook'
+ 'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook',
+ 'ORLHook'
]
diff --git a/mmselfsup/engine/hooks/orl_hook.py b/mmselfsup/engine/hooks/orl_hook.py
new file mode 100644
index 000000000..e476c67d0
--- /dev/null
+++ b/mmselfsup/engine/hooks/orl_hook.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import time
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmengine.dist import is_distributed
+from mmengine.hooks import Hook
+from mmengine.logging import print_log
+
+from mmselfsup.models.utils import Extractor
+from mmselfsup.registry import HOOKS
+
+
+@HOOKS.register_module()
+class ORLHook(Hook):
+ """ORL feature extractor hook."""
+
+ def __init__(self,
+ extract_dataloader: dict,
+ keys: int = 10,
+ normalize=True,
+ seed: Optional[int] = None) -> None:
+
+ self.dist_mode = is_distributed()
+ self.keys = keys
+ self.knn_batchsize = extract_dataloader['batch_size']
+ self.dataset = extract_dataloader['dataset']
+ self.extractor = Extractor(
+ extract_dataloader=extract_dataloader,
+ seed=seed,
+ dist_mode=self.dist_mode,
+ pool_cfg=None)
+ self.normalize = normalize
+
+    def retrieve_knn(self, features: torch.Tensor):
+        """Retrieve knn image ids for each image in COCO train2017.
+
+        Args:
+            features (torch.Tensor): Embeddings of all images in
+                COCO train2017.
+        """
+ # load data
+ data_root = self.dataset['data_root']
+ data_ann = self.dataset['ann_file']
+ data_prefix = self.dataset['data_prefix']['img']
+ train_json = data_root + data_ann
+ train_root = data_root + data_prefix
+ # train_json = '../data/coco/annotations/instances_train2017.json'
+ # train_root = '../data/coco/train2017/'
+ with open(train_json, 'r') as json_file:
+ data = json.load(json_file)
+
+ train_fns = [train_root + item['file_name'] for item in data['images']]
+ imgids = [item['id'] for item in data['images']]
+ knn_imgids = []
+ # batch processing
+ batch = self.knn_batchsize
+ keys = self.keys
+
+ feat_bank = features
+ for i in range(0, len(train_fns), batch):
+            print('[INFO] processing batch: {}'.format(i // batch + 1))
+ start = time.time()
+ if (i + batch) < len(train_fns):
+ query_feats = feat_bank[i:i + batch, :]
+ else:
+ query_feats = feat_bank[i:len(train_fns), :]
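+            # with normalize=True (default), feat_bank rows are l2-normalized,
+            # so this matrix product yields cosine similarities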
+ similarity = torch.mm(query_feats, feat_bank.T)
+ I_knn = torch.topk(similarity, keys + 1, dim=1)[1].cpu()
+ I_knn = I_knn[:, 1:] # exclude itself (i.e., 1st nn)
+ knn_list = I_knn.numpy().tolist()
+            knn_imgids.extend(knn_list)
+ end = time.time()
+            print('[INFO] batch {} took {:.4f} seconds'.format(
+                i // batch + 1, end - start))
+
+ # 118287 for coco, 241690 for coco+
+ num_image = len(train_fns)
+ save_dir = data_root + '/meta/'
+ save_path = save_dir + 'train2017_{}nn_instance.json'.format(keys)
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ assert len(imgids) == len(knn_imgids) == len(train_fns) == num_image, \
+ f'Mismatch number of training images, got: {len(knn_imgids)}'
+ # dict
+ data_new = {}
+ info = {}
+ image_info = {}
+ pseudo_anno = {}
+ info['knn_image_num'] = keys
+ image_info['file_name'] = [
+ item['file_name'] for item in data['images']
+ ]
+ image_info['id'] = [item['id'] for item in data['images']]
+ pseudo_anno['image_id'] = imgids
+ pseudo_anno['knn_image_id'] = knn_imgids
+ data_new['info'] = info
+ data_new['images'] = image_info
+ data_new['pseudo_annotations'] = pseudo_anno
+ with open(save_path, 'w') as f:
+ json.dump(data_new, f)
+ print('[INFO] image-level knn json file has been saved to {}'.format(
+ save_path))
+
+ def after_run(self, runner):
+ self._extract_func(runner)
+
+ def _extract_func(self, runner):
+ # step 1: get features
+ runner.model.eval()
+ features = self.extractor(runner.model.module)
+
+ # step 2: save features
+ if not self.dist_mode or (self.dist_mode and runner.rank == 0):
+            if self.normalize:
+                features = nn.functional.normalize(features['feat'], dim=1)
+            else:
+                features = features['feat']
+ np.save(
+ '{}/feature_epoch_{}.npy'.format(runner.work_dir,
+ runner.epoch),
+ features.cpu().numpy())
+            print_log(
+                'Feature extraction done! total features: {}\t'
+                'feature dimension: {}'.format(
+                    features.size(0), features.size(1)),
+                logger='current')
+
+        # step 3: retrieve knn
+ if runner.rank == 0:
+ self.retrieve_knn(features)
diff --git a/mmselfsup/models/algorithms/__init__.py b/mmselfsup/models/algorithms/__init__.py
index 590782430..19c5c46c9 100644
--- a/mmselfsup/models/algorithms/__init__.py
+++ b/mmselfsup/models/algorithms/__init__.py
@@ -15,6 +15,7 @@
from .mocov3 import MoCoV3
from .npid import NPID
from .odc import ODC
+from .orl import ORL, Correspondence, SelectiveSearch
from .relative_loc import RelativeLoc
from .rotation_pred import RotationPred
from .simclr import SimCLR
@@ -23,8 +24,29 @@
from .swav import SwAV
__all__ = [
- 'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',
- 'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam',
- 'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA',
- 'MixMIM'
+ 'BaseModel',
+ 'BarlowTwins',
+ 'BEiT',
+ 'BYOL',
+ 'DeepCluster',
+ 'DenseCL',
+ 'MoCo',
+ 'NPID',
+ 'ODC',
+ 'RelativeLoc',
+ 'RotationPred',
+ 'SimCLR',
+ 'SimSiam',
+ 'SwAV',
+ 'MAE',
+ 'MoCoV3',
+ 'SimMIM',
+ 'CAE',
+ 'MaskFeat',
+ 'MILAN',
+ 'EVA',
+ 'MixMIM',
+ 'SelectiveSearch',
+ 'Correspondence',
+ 'ORL',
]
diff --git a/mmselfsup/models/algorithms/orl.py b/mmselfsup/models/algorithms/orl.py
new file mode 100644
index 000000000..acc7b5ed3
--- /dev/null
+++ b/mmselfsup/models/algorithms/orl.py
@@ -0,0 +1,302 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmselfsup.registry import MODELS
+from mmselfsup.structures import SelfSupDataSample
+from ..utils import CosineEMA
+from .base import BaseModel
+
+
+@MODELS.register_module()
+class SelectiveSearch(nn.Module):
+ """Selective-search proposal generation."""
+
+ def __init__(self, **kwargs):
+        super().__init__()
+
+ def forward_test(self, bbox, **kwargs):
+ assert bbox.dim() == 3, \
+ 'Input bbox must have 3 dims, got: {}'.format(bbox.dim())
+ # bbox: 1xBx4
+ return dict(bbox=bbox.cpu())
+
+ def forward(self, mode='test', **kwargs):
+ assert mode == 'test', \
+ 'Support test inference mode only, got: {}'.format(mode)
+ return self.forward_test(**kwargs)
+
+
+@MODELS.register_module()
+class Correspondence(BaseModel):
+ """Correspondence discovery in Stage 2 of ORL.
+
+ Args:
+ backbone (dict): Config dict for module of backbone ConvNet.
+ neck (dict): Config dict for module of deep features to
+ compact feature vectors. Default: None.
+ head (dict): Config dict for module of loss functions.
+ Default: None.
+ pretrained (str, optional): Path to pre-trained weights.
+ Default: None.
+ base_momentum (float): The base momentum coefficient for
+ the target network. Default: 0.99.
+ knn_image_num (int): The number of KNN images. Default: 10.
+ topk_bbox_ratio (float): The ratio of retrieved top-ranked RoI pairs.
+ Default: 0.1.
+ """
+
+ def __init__(self,
+ backbone: dict,
+ neck: dict,
+ head: dict,
+ base_momentum: float = 0.99,
+ pretrained: Optional[str] = None,
+ init_cfg: Optional[Union[List[dict], dict]] = None,
+ knn_image_num: int = 10,
+ topk_bbox_ratio: float = 0.1) -> None:
+ super().__init__(
+ backbone=backbone,
+ neck=neck,
+ head=head,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+ # create momentum model
+ self.online_net = nn.Sequential(self.backbone, self.neck)
+ self.target_net = CosineEMA(self.online_net, momentum=base_momentum)
+
+ self.base_momentum = base_momentum
+ self.momentum = base_momentum
+
+ self.knn_image_num = knn_image_num
+ self.topk_bbox_ratio = topk_bbox_ratio
+
+    def predict(self, img: torch.Tensor, bbox: torch.Tensor,
+                img_keys: dict, bbox_keys: dict, **kwargs) -> dict:
+
+ knn_imgs = [
+ img_keys.get('{}nn_img'.format(k))
+ for k in range(self.knn_image_num)
+ ]
+ knn_bboxes = [
+ bbox_keys.get('{}nn_bbox'.format(k))
+ for k in range(self.knn_image_num)
+ ]
+ assert img.size(0) == 1, \
+ f'batch size must be 1, got: {img.size(0)}'
+ assert img.dim() == 5, \
+ f'img must have 5 dims, got: {img.dim()}'
+ assert bbox.dim() == 3, \
+ f'bbox must have 3 dims, got: {bbox.dim()}'
+ assert knn_imgs[0].dim() == 5, \
+ f'knn_img must have 5 dims, got: {knn_imgs[0].dim()}'
+ assert knn_bboxes[0].dim() == 3, \
+ f'knn_bbox must have 3 dims, got: {knn_bboxes[0].dim()}'
+ img = img.view(
+ img.size(0) * img.size(1), img.size(2), img.size(3),
+ img.size(4)) # (1B)xCxHxW
+ knn_imgs = [
+ knn_img.view(
+ knn_img.size(0) * knn_img.size(1),
+ knn_img.size(2),
+ # K (1B)xCxHxW
+ knn_img.size(3),
+ knn_img.size(4)) for knn_img in knn_imgs
+ ]
+ with torch.no_grad():
+ feat = self.backbone(img)[0].clone().detach()
+ knn_feats = [
+ self.backbone(knn_img)[0].clone().detach()
+ for knn_img in knn_imgs
+ ]
+ feat = F.adaptive_avg_pool2d(feat, (1, 1))
+ knn_feats = [
+ F.adaptive_avg_pool2d(knn_feat, (1, 1))
+ for knn_feat in knn_feats
+ ]
+ feat = feat.view(feat.size(0), -1) # (1B)xC
+ knn_feats = [
+ knn_feat.view(knn_feat.size(0), -1) for knn_feat in knn_feats
+ ] # K (1B)xC
+ feat_norm = F.normalize(feat, dim=1)
+ knn_feats_norm = [
+ F.normalize(knn_feat, dim=1) for knn_feat in knn_feats
+ ]
+ # smaps: a list containing K similarity matrix (BxB Tensor)
+ smaps = [
+ torch.mm(feat_norm, knn_feat_norm.transpose(0, 1))
+ for knn_feat_norm in knn_feats_norm
+ ] # K BxB
+ top_query_inds = []
+ top_key_inds = []
+ for smap in smaps:
+ topk_num = int(self.topk_bbox_ratio * smap.size(0))
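+            # take the global top-k entries of the BxB similarity map, then
+            # decode the flat indices back into (query RoI, key RoI) pairs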
+ _, top_ind = torch.topk(smap.flatten(),
+ topk_num if topk_num > 0 else 1)
+ top_query_ind = top_ind // smap.size(1)
+ top_key_ind = top_ind % smap.size(1)
+ top_query_inds.append(top_query_ind)
+ top_key_inds.append(top_key_ind)
+ bbox = bbox.view(bbox.size(0) * bbox.size(1),
+ bbox.size(2)) # (1B)x4
+ knn_bboxes = [
+ knn_bbox.view(
+ knn_bbox.size(0) * knn_bbox.size(1), knn_bbox.size(2))
+ for knn_bbox in knn_bboxes
+ ] # K (1B)x4
+ # K (topk_bbox_num)x8
+ topk_box_pairs_list = [
+ torch.cat((bbox[qind], kbox[kind]),
+ dim=1).cpu() for kbox, qind, kind in zip(
+ knn_bboxes, top_query_inds, top_key_inds)
+ ]
+ knn_bbox_keys = [
+ '{}nn_bbox'.format(k) for k in range(len(topk_box_pairs_list))
+ ]
+ dict1 = dict(intra_bbox=bbox.cpu())
+ dict2 = dict(zip(knn_bbox_keys, topk_box_pairs_list))
+ # intra_bbox: Bx4, inter_bbox: K (topk_bbox_num)x8
+ # B is the number of filtered bboxes, K is the number of knn images,
+ return {**dict1, **dict2}
+
+ def forward(self, img, bbox, img_keys, bbox_keys, mode='test', **kwargs):
+ if mode == 'test':
+ return self.predict(img, bbox, img_keys, bbox_keys, **kwargs)
+ else:
+ raise Exception('No such mode: {}'.format(mode))
+
+
+@MODELS.register_module()
+class ORL(BaseModel):
+ """ORL.
+
+ Args:
+ backbone (dict): Config dict for module of
+ backbone ConvNet.
+ neck (dict):
+ Config dict for module of deep features to compact feature vectors.
+ Default: None.
+ head (dict):
+ Config dict for module of loss functions.
+ Default: None.
+ pretrained (str, optional):
+ Path to pre-trained weights.
+ Default: None.
+ base_momentum (float):
+ The base momentum coefficient for the target network.
+ Default: 0.99.
+        global_loss_weight (float):
+            Loss weight for the global image branch. Default: 1.
+        loc_intra_loss_weight (float):
+            Loss weight for the local intra-RoI branch. Default: 1.
+        loc_inter_loss_weight (float):
+            Loss weight for the local inter-RoI branch. Default: 1.
+ """
+
+ def __init__(self,
+ backbone: dict,
+ neck: dict,
+ head: dict,
+ base_momentum: float = 0.99,
+ pretrained: Optional[str] = None,
+ init_cfg: Optional[Union[List[dict], dict]] = None,
+ global_loss_weight: float = 1.,
+ loc_intra_loss_weight: float = 1.,
+ loc_inter_loss_weight: float = 1.,
+ **kwargs):
+ super().__init__(
+ backbone=backbone,
+ neck=neck,
+ head=head,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+ self.global_loss_weight = global_loss_weight
+ self.loc_intra_weight = loc_intra_loss_weight
+ self.loc_inter_weight = loc_inter_loss_weight
+ self.online_net = nn.Sequential(self.backbone, self.neck)
+ self.target_net = CosineEMA(self.online_net, momentum=base_momentum)
+
+ self.loc_intra_head = MODELS.build(head)
+ self.loc_inter_head = MODELS.build(head)
+
+ self.base_momentum = base_momentum
+ self.momentum = base_momentum
+
+    def loss(self, inputs: List[torch.Tensor],
+             data_samples: List[SelfSupDataSample], **kwargs) -> dict:
+        """Forward computation during training.
+
+        Args:
+            inputs (List[torch.Tensor]): A list of three tensors
+                [img, ipatch, kpatch]. ``img`` contains the two global image
+                views, ``ipatch`` the two intra-RoI patch views and
+                ``kpatch`` the two inter-RoI patch views, each of shape
+                (N, 2, C, H, W). Typically these should be mean centered
+                and std scaled.
+            data_samples (List[SelfSupDataSample]): The data samples.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+ assert isinstance(inputs, list)
+ global_img, ipatch, kpatch = inputs
+ assert global_img.dim() == 5, \
+ 'Input must have 5 dims, got: {}'.format(global_img.dim())
+ img_v1 = global_img[:, 0, ...].contiguous()
+ img_v2 = global_img[:, 1, ...].contiguous()
+ assert ipatch.dim() == 5, \
+ 'Input must have 5 dims, got: {}'.format(ipatch.dim())
+ ipatch_v1 = ipatch[:, 0, ...].contiguous()
+ ipatch_v2 = ipatch[:, 1, ...].contiguous()
+ assert kpatch.dim() == 5, \
+ 'Input must have 5 dims, got: {}'.format(kpatch.dim())
+ kpatch_v1 = kpatch[:, 0, ...].contiguous()
+ kpatch_v2 = kpatch[:, 1, ...].contiguous()
+ # compute online features
+ global_online_v1 = self.online_net(img_v1)[0]
+ global_online_v2 = self.online_net(img_v2)[0]
+ loc_intra_v1 = self.online_net(ipatch_v1)[0]
+ loc_intra_v2 = self.online_net(ipatch_v2)[0]
+ loc_inter_v1 = self.online_net(kpatch_v1)[0]
+ loc_inter_v2 = self.online_net(kpatch_v2)[0]
+ # compute target features
+ with torch.no_grad():
+ global_target_v1 = self.target_net(img_v1)[0].clone().detach()
+ global_target_v2 = self.target_net(img_v2)[0].clone().detach()
+ loc_intra_tar_v1 = self.target_net(ipatch_v1)[0].clone().detach()
+ loc_intra_tar_v2 = self.target_net(ipatch_v2)[0].clone().detach()
+ loc_inter_tar_v1 = self.target_net(kpatch_v1)[0].clone().detach()
+ loc_inter_tar_v2 = self.target_net(kpatch_v2)[0].clone().detach()
+ # compute losses
+ global_loss =\
+ self.head(global_online_v1, global_target_v2) + \
+ self.head(global_online_v2, global_target_v1)
+
+ local_intra_loss =\
+ self.loc_intra_head(loc_intra_v1, loc_intra_tar_v2) + \
+ self.loc_intra_head(loc_intra_v2, loc_intra_tar_v1)
+ local_inter_loss = \
+ self.loc_inter_head(loc_inter_v1, loc_inter_tar_v2) + \
+ self.loc_inter_head(loc_inter_v2, loc_inter_tar_v1)
+        loss_global = self.global_loss_weight * global_loss
+        loss_local_intra = self.loc_intra_weight * local_intra_loss
+        loss_local_inter = self.loc_inter_weight * local_inter_loss
+        losses = dict(loss=loss_global + loss_local_intra + loss_local_inter)
+        return losses
+
+ def forward(self,
+ inputs: List[torch.Tensor],
+ data_samples: Optional[List[SelfSupDataSample]] = None,
+ mode: str = 'tensor'):
+ if mode == 'loss':
+ return self.loss(inputs, data_samples)
+ else:
+ raise RuntimeError(f'Invalid mode "{mode}".')
diff --git a/tools/dist_generate_correspondence_single_gpu.sh b/tools/dist_generate_correspondence_single_gpu.sh
new file mode 100644
index 000000000..da06bf7b9
--- /dev/null
+++ b/tools/dist_generate_correspondence_single_gpu.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+
+CFG=$1
+CHECKPOINT=$2
+INPUT=$3
+OUTPUT=$4
+PY_ARGS=${@:5}
+GPUS=1
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch \
+ --nnodes=$NNODES \
+ --node_rank=$NODE_RANK \
+ --master_addr=$MASTER_ADDR \
+ --nproc_per_node=$GPUS \
+ --master_port=$PORT \
+ $(dirname "$0")/generate_correspondence.py \
+ $CFG \
+ $CHECKPOINT \
+ $INPUT \
+ $OUTPUT --launcher pytorch ${PY_ARGS}
diff --git a/tools/dist_selective_search_single_gpu.sh b/tools/dist_selective_search_single_gpu.sh
new file mode 100644
index 000000000..0ebcda40e
--- /dev/null
+++ b/tools/dist_selective_search_single_gpu.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+
+CFG=$1
+OUTPUT=$2
+PY_ARGS=${@:3}
+GPUS=1
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+PORT=${PORT:-29500}
+MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+python -m torch.distributed.launch \
+ --nnodes=$NNODES \
+ --node_rank=$NODE_RANK \
+ --master_addr=$MASTER_ADDR \
+ --nproc_per_node=$GPUS \
+ --master_port=$PORT \
+ $(dirname "$0")/selective_search.py \
+ $CFG $OUTPUT --launcher pytorch ${PY_ARGS}
diff --git a/tools/generate_correspondence.py b/tools/generate_correspondence.py
new file mode 100644
index 000000000..8aac74270
--- /dev/null
+++ b/tools/generate_correspondence.py
@@ -0,0 +1,277 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import logging
+import os
+import os.path as osp
+import time
+
+import mmengine
+import torch
+from mmengine.config import Config
+from mmengine.dist import get_dist_info, init_dist
+from mmengine.logging import MMLogger, print_log
+from mmengine.model import MMDistributedDataParallel
+from mmengine.registry import build_model_from_cfg
+from mmengine.runner import Runner, load_checkpoint
+
+from mmselfsup.registry import MODELS
+from mmselfsup.utils import register_all_modules
+
+
+def nondist_single_forward_collect(func, data_loader, length):
+ """Forward and collect network outputs.
+
+ This function performs forward propagation and collects outputs.
+    It can be used to collect results, features, losses, etc.
+
+    Args:
+        func (function): The function to process data. The output must be
+            a dictionary of CPU tensors.
+        data_loader (DataLoader): The dataloader to iterate over.
+        length (int): Expected length of output arrays.
+
+    Returns:
+        results_all (dict(list)): The concatenated outputs.
+ """
+ results = []
+ prog_bar = mmengine.ProgressBar(len(data_loader))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = func(**data)
+ results.append(result)
+ prog_bar.update()
+
+    results_all = {}
+    # intra-image bboxes: one list per image
+    results_all['intra_bbox'] = [
+        batch['intra_bbox'].numpy().tolist() for batch in results
+    ]
+    assert len(results_all['intra_bbox']) == length
+    # inter-image (knn) bbox pairs: every key other than 'intra_bbox'
+    # belongs to one of the knn images; use a distinct loop variable so
+    # the outer key is not shadowed
+    inter_results_list = []
+    for batch in results:
+        merge_batch_results = [
+            batch[key].numpy().tolist() for key in batch.keys()
+            if key != 'intra_bbox'
+        ]
+        inter_results_list.append(merge_batch_results)
+    results_all['inter_bbox'] = inter_results_list
+    assert len(results_all['inter_bbox']) == length
+    return results_all
+
+
+def dist_single_forward_collect(func, data_loader, rank, length):
+ """Forward and collect network outputs in a distributed manner.
+
+ This function performs forward propagation and collects outputs.
+    It can be used to collect results, features, losses, etc.
+
+    Args:
+        func (function): The function to process data. The output must be
+            a dictionary of CPU tensors.
+        data_loader (DataLoader): The dataloader to iterate over.
+        rank (int): This process id.
+        length (int): Expected total length of output arrays (kept for
+            API symmetry with the non-distributed variant).
+
+    Returns:
+        results_all (dict(list)): The concatenated outputs.
+ """
+ results = []
+ if rank == 0:
+ prog_bar = mmengine.ProgressBar(len(data_loader))
+ for idx, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = func(**data) # dict{key: tensor}
+ results.append(result)
+
+ if rank == 0:
+ prog_bar.update()
+
+    results_all = {}
+    # intra-image bboxes: one list per image
+    results_all['intra_bbox'] = [
+        batch['intra_bbox'].numpy().tolist() for batch in results
+    ]
+    # inter-image (knn) bbox pairs: every key other than 'intra_bbox'
+    # belongs to one of the knn images
+    inter_results_list = []
+    for batch in results:
+        merge_batch_results = [
+            batch[key].numpy().tolist() for key in batch.keys()
+            if key != 'intra_bbox'
+        ]
+        inter_results_list.append(merge_batch_results)
+    results_all['inter_bbox'] = inter_results_list
+    return results_all
+
+
+def single_gpu_test(model, data_loader):
+ model.eval()
+
+ def func(**x):
+ return model(mode='test', **x)
+
+ results = nondist_single_forward_collect(func, data_loader,
+ len(data_loader.dataset))
+ return results
+
+
+def multi_gpu_test(model, data_loader):
+ model.eval()
+
+ def func(**x):
+ return model(mode='test', **x)
+
+ rank, world_size = get_dist_info()
+ results = dist_single_forward_collect(func, data_loader, rank,
+ len(data_loader.dataset))
+ return results
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Generate correspondence in Stage 2 of ORL')
+ parser.add_argument('config', help='test config file path')
+ parser.add_argument('checkpoint', help='checkpoint file')
+ parser.add_argument('input', type=str, help='input knn instance json file')
+ parser.add_argument(
+ 'output', type=str, help='output correspondence json file')
+ parser.add_argument(
+ '--work-dir',
+ type=str,
+ default=None,
+ help='the dir to save logs and models')
+ parser.add_argument(
+ '--resume',
+ nargs='?',
+ type=str,
+ const='auto',
+        help='If a checkpoint path is specified, resume from it; '
+        'otherwise, try to auto-resume from the latest checkpoint '
+        'in the work directory.')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument(
+ '--port',
+ type=int,
+ default=29500,
+ help='port only works when launcher=="slurm"')
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ return args
+
+
+def evaluate(
+ json_file,
+ dataset_info,
+ intra_bbox,
+ inter_bbox,
+):
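+    """Assemble the Stage 2 correspondence annotations.
+
+    Merges the input knn-instance json with the collected intra-image
+    bboxes and inter-image (knn) bbox pairs into the json file consumed
+    by the Stage 3 dataset.
+    """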
+    assert len(intra_bbox) == len(inter_bbox), \
+        'Mismatched number of images in the training set, ' \
+        'got: intra: {} inter: {}'.format(len(intra_bbox), len(inter_bbox))
+ data = mmengine.load(json_file)
+    data_new = {}  # top-level output dict
+    # sub-dicts: generation settings, image metadata, pseudo annotations
+    info = {}
+    image_info = {}
+    pseudo_anno = {}
+ info['bbox_min_size'] = dataset_info['min_size']
+ info['bbox_max_aspect_ratio'] = dataset_info['max_ratio']
+ info['bbox_max_iou'] = dataset_info['max_iou_thr']
+ info['intra_bbox_num'] = dataset_info['topN']
+ info['knn_image_num'] = dataset_info['knn_image_num']
+ info['knn_bbox_pair_ratio'] = dataset_info['topk_bbox_ratio']
+ image_info['file_name'] = data['images']['file_name']
+ image_info['id'] = data['images']['id']
+ pseudo_anno['image_id'] = data['pseudo_annotations']['image_id']
+ pseudo_anno['bbox'] = intra_bbox
+ pseudo_anno['knn_image_id'] = data['pseudo_annotations']['knn_image_id']
+ pseudo_anno['knn_bbox_pair'] = inter_bbox
+ data_new['info'] = info
+ data_new['images'] = image_info
+ data_new['pseudo_annotations'] = pseudo_anno
+ return data_new
+
+
+def main():
+ args = parse_args()
+ register_all_modules(init_default_scope=True)
+
+ cfg = Config.fromfile(args.config)
+ cfg.launcher = args.launcher
+
+ # work_dir is determined in this priority: CLI > segment in file > filename
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ cfg.work_dir = args.work_dir
+ elif cfg.get('work_dir', None) is None:
+ # use config filename as default work_dir if cfg.work_dir is None
+ work_type = args.config.split('/')[1]
+ cfg.work_dir = osp.join('./work_dirs', work_type,
+ osp.splitext(osp.basename(args.config))[0])
+
+ # ensure to use checkpoint rather than pretraining
+ cfg.model.pretrained = None
+
+    # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'none':
+ distributed = False
+ else:
+ distributed = True
+ init_dist(args.launcher, **cfg.dist_params)
+
+ # logger
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir,
+                        'correspondence_{}.log'.format(timestamp))
+    os.makedirs(cfg.work_dir, exist_ok=True)
+ logging.basicConfig(filename=log_file, level=cfg.log_level)
+
+ logger = MMLogger.get_instance(
+ 'mmengine', log_file=log_file, log_level=cfg.log_level)
+
+ # build the model
+ model = build_model_from_cfg(cfg.model, registry=MODELS)
+
+    # build the dataloader
+    data_loader = Runner.build_dataloader(cfg.val_dataloader)
+
+    # load the checkpoint
+    load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+ if not distributed:
+ outputs = single_gpu_test(model.cuda(), data_loader)
+ else:
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False)
+ outputs = multi_gpu_test(model, data_loader) # dict{key: list}
+
+ rank, _ = get_dist_info()
+ if rank == 0:
+ out = evaluate(args.input, cfg.dataset_dict, **outputs)
+ with open(args.output, 'w') as f:
+ json.dump(out, f)
+ print_log(
+ 'Correspondence json file has been saved to: {}'.format(
+ args.output),
+ logger=logger)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/selective_search.py b/tools/selective_search.py
new file mode 100644
index 000000000..7cbcad0bc
--- /dev/null
+++ b/tools/selective_search.py
@@ -0,0 +1,261 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import json
+import logging
+import os
+import os.path as osp
+import time
+
+import mmengine
+import torch
+import torch.multiprocessing
+from mmengine.config import Config, DictAction
+from mmengine.dist import get_dist_info, init_dist
+from mmengine.logging import MMLogger, print_log
+from mmengine.registry import build_model_from_cfg
+from mmengine.runner import Runner
+
+from mmselfsup.registry import MODELS
+from mmselfsup.utils import register_all_modules
+
+# selective search can return thousands of proposals per image; the
+# default sharing strategy may exhaust file descriptors with multiple
+# dataloader workers
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+
+def evaluate(bbox, **kwargs):
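+    """Pack the collected selective search proposals into a json dict."""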
+
+ if not isinstance(bbox, list):
+ bbox = bbox.tolist()
+ data_ss = {}
+ data_ss['bbox'] = bbox
+ return data_ss
+
+
+def nondist_single_forward_collect(func, data_loader, length):
+ """Forward and collect network outputs.
+
+ This function performs forward propagation and collects outputs.
+ It can be used to collect results, features, losses, etc.
+
+    Args:
+        func (function): The function to process data. The output must be
+            a dictionary of CPU tensors.
+        data_loader (DataLoader): The dataloader to iterate over.
+        length (int): Expected length of output arrays.
+
+ Returns:
+ results_all (dict(list)): The concatenated outputs.
+ """
+ results = []
+ prog_bar = mmengine.ProgressBar(len(data_loader))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = func(**data)
+ results.append(result)
+ prog_bar.update()
+
+ results_all = {}
+ for k in results[0].keys():
+ results_all[k] = [
+ batch[k].squeeze().numpy().tolist() for batch in results
+ ]
+ assert len(results_all[k]) == length
+ return results_all
+
+
+def dist_single_forward_collect(func, data_loader, rank, length):
+ """Forward and collect network outputs in a distributed manner.
+
+ This function performs forward propagation and collects outputs.
+ It can be used to collect results, features, losses, etc.
+
+    Args:
+        func (function): The function to process data. The output must be
+            a dictionary of CPU tensors.
+        data_loader (DataLoader): The dataloader to iterate over.
+        rank (int): This process id.
+        length (int): Expected total length of output arrays (kept for
+            API symmetry with the non-distributed variant).
+
+ Returns:
+ results_all (dict(list)): The concatenated outputs.
+ """
+ results = []
+ if rank == 0:
+ prog_bar = mmengine.ProgressBar(len(data_loader))
+ for idx, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = func(**data) # dict{key: tensor}
+ results.append(result)
+
+ if rank == 0:
+ prog_bar.update()
+
+ results_all = {}
+ for k in results[0].keys():
+ results_list = [
+ batch[k].squeeze().numpy().tolist() for batch in results
+ ]
+ results_all[k] = results_list
+        # no length check here: with multiple ranks each process only
+        # holds its own shard of the results
+ return results_all
+
+
+def single_gpu_test(model, data_loader):
+ model.eval()
+
+ def func(**x):
+ return model(mode='test', **x)
+
+ results = nondist_single_forward_collect(func, data_loader,
+ len(data_loader.dataset))
+ return results
+
+
+def multi_gpu_test(model, data_loader):
+ model.eval()
+
+ def func(**x):
+ return model(mode='test', **x)
+
+ rank, world_size = get_dist_info()
+ results = dist_single_forward_collect(func, data_loader, rank,
+ len(data_loader.dataset))
+ return results
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Generate selective search proposals in Stage 2 of ORL')
+ parser.add_argument(
+ 'config',
+ default='configs/selfsup/orl/stage2/selective_search.py',
+ type=str,
+ help='train config file path')
+ parser.add_argument(
+ 'output',
+        default='../data/coco/meta/train2017_selective_search_proposal.json',
+ type=str,
+ help='output total selective search proposal json file')
+ parser.add_argument(
+ '--work-dir',
+ type=str,
+ default=None,
+ help='the dir to save logs and models')
+ parser.add_argument(
+ '--resume',
+ nargs='?',
+ type=str,
+ const='auto',
+        help='If a checkpoint path is specified, resume from it; '
+        'otherwise, try to auto-resume from the latest checkpoint '
+        'in the work directory.')
+ parser.add_argument(
+ '--amp',
+ action='store_true',
+ help='enable automatic-mixed-precision training')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. If the value to '
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ 'Note that the quotation marks are necessary and that no white space '
+ 'is allowed.')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+    # register all modules in mmselfsup into the registries and init the
+    # default scope, since this script does not construct a Runner
+    register_all_modules(init_default_scope=True)
+
+ cfg = Config.fromfile(args.config)
+ cfg.launcher = args.launcher
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+
+ # work_dir is determined in this priority: CLI > segment in file > filename
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ cfg.work_dir = args.work_dir
+ elif cfg.get('work_dir', None) is None:
+ # use config filename as default work_dir if cfg.work_dir is None
+ work_type = args.config.split('/')[1]
+ cfg.work_dir = osp.join('./work_dirs', work_type,
+ osp.splitext(osp.basename(args.config))[0])
+
+ # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'none':
+ distributed = False
+ else:
+ distributed = True
+ init_dist(args.launcher, **cfg.dist_params)
+
+ # logger
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ log_file = osp.join(cfg.work_dir,
+ 'SelectiveSearch_{}.log'.format(timestamp))
+    os.makedirs(cfg.work_dir, exist_ok=True)
+ logging.basicConfig(filename=log_file, level=cfg.log_level)
+
+ logger = MMLogger.get_instance(
+ 'mmengine', log_file=log_file, log_level=cfg.log_level)
+
+ # build the model
+ model = build_model_from_cfg(cfg.model, registry=MODELS)
+
+ # build the dataloader
+ data_loader = Runner.build_dataloader(cfg.val_dataloader)
+
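+    # the model is deliberately neither moved to GPU nor wrapped in DDP:
+    # selective search proposal generation is CPU-bound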
+ if not distributed:
+ outputs = single_gpu_test(model, data_loader)
+ else:
+ outputs = multi_gpu_test(model, data_loader) # dict{key: list}
+
+ rank, _ = get_dist_info()
+ if rank == 0:
+ out = evaluate(**outputs)
+ with open(args.output, 'w') as f:
+ json.dump(out, f)
+ print_log(
+ 'Selective search proposal json file has been saved to: {}'.format(
+ args.output),
+ logger=logger)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/slurm_generate_correspondence_single_gpu.sh b/tools/slurm_generate_correspondence_single_gpu.sh
new file mode 100644
index 000000000..3aa0d03ac
--- /dev/null
+++ b/tools/slurm_generate_correspondence_single_gpu.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
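+# Usage (paths are illustrative):
+#   bash tools/slurm_generate_correspondence_single_gpu.sh \
+#     ${PARTITION} ${CONFIG} ${CHECKPOINT} ${KNN_JSON} ${CORRESPONDENCE_JSON}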
+set -x
+PARTITION=$1
+JOB_NAME='correspondence'
+CFG=$2
+CHECKPOINT=$3
+INPUT=$4
+OUTPUT=$5
+PY_ARGS=${@:6}
+GPUS=${GPUS:-1}
+GPUS_PER_NODE=${GPUS_PER_NODE:-1}
+CPUS_PER_TASK=${CPUS_PER_TASK:-5}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ ${SRUN_ARGS} \
+ python -u tools/generate_correspondence.py \
+ $CFG \
+ $CHECKPOINT \
+ $INPUT \
+ $OUTPUT --launcher="slurm" ${PY_ARGS}
diff --git a/tools/slurm_selective_search_single_gpu.sh b/tools/slurm_selective_search_single_gpu.sh
new file mode 100644
index 000000000..4968a255e
--- /dev/null
+++ b/tools/slurm_selective_search_single_gpu.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+
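+# Usage (paths are illustrative):
+#   bash tools/slurm_selective_search_single_gpu.sh \
+#     ${PARTITION} ${CONFIG} ${OUTPUT_JSON}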
+set -x
+PARTITION=$1
+JOB_NAME="selective_search"
+CFG=$2
+OUTPUT=$3
+PY_ARGS=${@:4}
+GPUS=${GPUS:-1}
+GPUS_PER_NODE=${GPUS_PER_NODE:-1}
+CPUS_PER_TASK=${CPUS_PER_TASK:-4}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
+srun -p ${PARTITION} \
+ --job-name=${JOB_NAME} \
+ --gres=gpu:${GPUS_PER_NODE} \
+ --ntasks=${GPUS} \
+ --ntasks-per-node=${GPUS_PER_NODE} \
+ --cpus-per-task=${CPUS_PER_TASK} \
+ --kill-on-bad-exit=1 \
+ ${SRUN_ARGS} \
+ python -u tools/selective_search.py $CFG $OUTPUT --launcher="slurm" ${PY_ARGS}
diff --git a/tools/slurm_train.sh b/tools/slurm_train.sh
index ac36d5082..ae1ebc465 100644
--- a/tools/slurm_train.sh
+++ b/tools/slurm_train.sh
@@ -1,7 +1,6 @@
#!/usr/bin/env bash
set -x
-
PARTITION=$1
JOB_NAME=$2
CONFIG=$3