From 60f5833f1258f381739119cab4b02765e537f782 Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Sat, 28 Jan 2023 10:23:38 +0800 Subject: [PATCH 01/17] add greenmim infer --- .../_base_/models/greenmim_swin-base.py | 23 + ...in-base_16xb128-amp-coslr-100e_in1k-192.py | 49 + mmselfsup/models/algorithms/__init__.py | 3 +- mmselfsup/models/algorithms/greenmim.py | 103 +++ mmselfsup/models/backbones/__init__.py | 3 +- mmselfsup/models/backbones/greenmim.py | 874 ++++++++++++++++++ mmselfsup/models/heads/__init__.py | 4 +- mmselfsup/models/heads/greenmim_head.py | 41 + mmselfsup/models/losses/greenmim_loss.py | 40 + mmselfsup/models/necks/__init__.py | 4 +- mmselfsup/models/necks/greenmim_neck.py | 125 +++ 11 files changed, 1265 insertions(+), 4 deletions(-) create mode 100644 configs/selfsup/_base_/models/greenmim_swin-base.py create mode 100644 configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py create mode 100644 mmselfsup/models/algorithms/greenmim.py create mode 100644 mmselfsup/models/backbones/greenmim.py create mode 100644 mmselfsup/models/heads/greenmim_head.py create mode 100644 mmselfsup/models/losses/greenmim_loss.py create mode 100644 mmselfsup/models/necks/greenmim_neck.py diff --git a/configs/selfsup/_base_/models/greenmim_swin-base.py b/configs/selfsup/_base_/models/greenmim_swin-base.py new file mode 100644 index 000000000..34aa5cec3 --- /dev/null +++ b/configs/selfsup/_base_/models/greenmim_swin-base.py @@ -0,0 +1,23 @@ +# model settings +img_size = 224 +patch_size = 4 + +model = dict( + type='GreenMIM', + data_preprocessor=dict( + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='GreenMIMSwinTransformer', + arch='B', + img_size=img_size, + patch_size=patch_size, + drop_path_rate=0.0, + stage_cfgs=dict(block_cfgs=dict(window_size=7))), + neck=dict(type='GreenMIMNeck', in_channels=3, encoder_stride=32, img_size=img_size, patch_size=patch_size), + head=dict( + type='GreenMIMHead', + patch_size=patch_size, + norm_pix_loss=False, + loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3))) diff --git a/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py b/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py new file mode 100644 index 000000000..cd1fc4be8 --- /dev/null +++ b/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py @@ -0,0 +1,49 @@ +_base_ = [ + '../_base_/models/greenmim_swin-base.py', + '../_base_/datasets/imagenet_mae.py', + '../_base_/schedules/adamw_coslr-200e_in1k.py', + '../_base_/default_runtime.py', +] + +# dataset 16 GPUs x 128 +train_dataloader = dict(batch_size=128, num_workers=16) + +# optimizer wrapper +optimizer = dict( + type='AdamW', lr=2e-4 * 2048 / 512, betas=(0.9, 0.999), eps=1e-8) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=optimizer, + clip_grad=dict(max_norm=5.0), + paramwise_cfg=dict( + custom_keys={ + 'norm': dict(decay_mult=0.0), + 'bias': dict(decay_mult=0.0), + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.) 
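+            # norms, biases, and the position-embedding tables are excluded
+            # from weight decay, as is common for Swin-style MIM pre-training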
+ })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-6 / 2e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=90, + eta_min=1e-5 * 2048 / 512, + by_epoch=True, + begin=10, + end=100, + convert_to_iter_based=True) +] + +# schedule +train_cfg = dict(max_epochs=100) + +# runtime +default_hooks = dict(logger=dict(type='LoggerHook', interval=100)) diff --git a/mmselfsup/models/algorithms/__init__.py b/mmselfsup/models/algorithms/__init__.py index 590782430..9c88a70bd 100644 --- a/mmselfsup/models/algorithms/__init__.py +++ b/mmselfsup/models/algorithms/__init__.py @@ -21,10 +21,11 @@ from .simmim import SimMIM from .simsiam import SimSiam from .swav import SwAV +from .greenmim import GreenMIM __all__ = [ 'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL', 'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam', 'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA', - 'MixMIM' + 'MixMIM', 'GreenMIM' ] diff --git a/mmselfsup/models/algorithms/greenmim.py b/mmselfsup/models/algorithms/greenmim.py new file mode 100644 index 000000000..ae9ccd9f4 --- /dev/null +++ b/mmselfsup/models/algorithms/greenmim.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple + +import torch +from mmengine.structures import BaseDataElement + +from mmselfsup.registry import MODELS +from mmselfsup.structures import SelfSupDataSample +from .base import BaseModel + +@MODELS.register_module() +class GreenMIM(BaseModel): + """GreenMIM. + + Implementation of `GreenMIM: Green Hierarchical Vision Transformer for Masked Image Modeling + `_. + """ + + def extract_feat(self, + inputs: List[torch.Tensor], + data_samples: Optional[List[SelfSupDataSample]] = None, + **kwarg) -> Tuple[torch.Tensor]: + """The forward function to extract features from neck. + + Args: + inputs (List[torch.Tensor]): The input images. + + Returns: + Tuple[torch.Tensor]: Neck outputs. + """ + latent, mask, ids_restore = self.backbone(inputs[0]) + pred = self.neck(latent, ids_restore) + self.mask = mask + return pred + + def reconstruct(self, + features: torch.Tensor, + data_samples: Optional[List[SelfSupDataSample]] = None, + **kwargs) -> SelfSupDataSample: + """The function is for image reconstruction. + + Args: + features (torch.Tensor): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. + + Returns: + SelfSupDataSample: The prediction from model. 
+ """ + mean = kwargs['mean'] + std = kwargs['std'] + features = features * std + mean + + pred = self.head.unpatchify(features) + pred = torch.einsum('nchw->nhwc', pred).detach().cpu() + + mask = self.mask.detach() + mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 * + 3) # (N, H*W, p*p*3) + mask = self.head.unpatchify(mask) # 1 is removing, 0 is keeping + mask = torch.einsum('nchw->nhwc', mask).detach().cpu() + + results = SelfSupDataSample() + results.mask = BaseDataElement(**dict(value=mask)) + results.pred = BaseDataElement(**dict(value=pred)) + + return results + + def patchify(self, imgs, patch_size): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = patch_size + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + return x + + def loss(self, inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + latent, mask, ids_restore = self.backbone(inputs[0]) + pred = self.neck(latent, ids_restore) + target = self.patchify(inputs[0], self.backbone.final_patch_size) + loss = self.head(pred, target, mask) + losses = dict(loss=loss) + return losses diff --git a/mmselfsup/models/backbones/__init__.py b/mmselfsup/models/backbones/__init__.py index 0a40c999b..dab9c4ebb 100644 --- a/mmselfsup/models/backbones/__init__.py +++ b/mmselfsup/models/backbones/__init__.py @@ -9,9 +9,10 @@ from .resnet import ResNet, ResNetSobel, ResNetV1d from .resnext import ResNeXt from .simmim_swin import SimMIMSwinTransformer +from .greenmim import GreenMIMSwinTransformer __all__ = [ 'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT', 'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT', 'MILANViT', - 'MixMIMTransformerPretrain' + 'MixMIMTransformerPretrain', 'GreenMIMSwinTransformer' ] diff --git a/mmselfsup/models/backbones/greenmim.py b/mmselfsup/models/backbones/greenmim.py new file mode 100644 index 000000000..61835c80d --- /dev/null +++ b/mmselfsup/models/backbones/greenmim.py @@ -0,0 +1,874 @@ +import math +from functools import partial +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +import torch.utils.checkpoint as checkpoint +from mmselfsup.registry import MODELS +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.vision_transformer import Block + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def get_coordinates(h, w, device='cpu'): + coords_h = 
torch.arange(h, device=device) + coords_w = torch.arange(w, device=device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + return coords + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + # NOTE: the index is not used at pretraining and is kept for compatibility + coords = get_coordinates(*window_size) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None, pos_idx=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + # projection + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # B_, nH, N, N + + # relative position bias + assert pos_idx.dim() == 3, f"Expect the pos_idx/mask to be a 3-d tensor, but got{pos_idx.dim()}" + rel_pos_mask = torch.masked_fill(torch.ones_like(mask), mask=mask.bool(), value=0.0) + pos_idx_m = torch.masked_fill(pos_idx, mask.bool(), value=0).view(-1) + relative_position_bias = self.relative_position_bias_table[pos_idx_m].view( + -1, N, N, self.num_heads) # nW, Wh*Ww, Wh*Ww,nH + relative_position_bias = relative_position_bias * rel_pos_mask.view(-1, N, N, 1) + + nW = relative_position_bias.shape[0] + relative_position_bias = relative_position_bias.permute(0, 3, 1, 2).contiguous() # nW, nH, Wh*Ww, Wh*Ww + attn = attn.view(B_ 
// nW, nW, self.num_heads, N, N) + relative_position_bias.unsqueeze(0) + + # attention mask + attn = attn + mask.view(1, nW, 1, N, N) + attn = attn.view(B_, self.num_heads, N, N) + + # normalization + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + # aggregation + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, attn_mask, rel_pos_idx): + shortcut = x + x = self.norm1(x) + + # W-MSA/SW-MSA + x = self.attn(x, mask=attn_mask, pos_idx=rel_pos_idx) # B*nW, N_vis, C + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, mask_prev): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + ratio = H // 7 if H % 7 == 0 else H // 6 # FIXME + x = x.view(B, -1, ratio//2, 2, ratio//2, 2, C) + x = x.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, L//4, 4 * C) + + # merging by a linear layer + x = self.norm(x) + x = self.reduction(x) + + mask_new = mask_prev.view(1, -1, ratio//2, 2, ratio//2, 2).sum(dim=(3, 5)) + assert torch.unique(mask_new).shape[0] == 2 # should be [0, 4] + mask_new = (mask_new > 0).reshape(1, -1) + coords_new = get_coordinates(H//2, W//2, x.device).reshape(1, 2, -1) + coords_new = coords_new.transpose(2, 1)[mask_new].reshape(1, -1, 2) + return x, coords_new, mask_new + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +def knapsack(W, wt): + '''Args: + W (int): capacity + wt (tuple[int]): the numbers of elements within each window + ''' + val = wt + n = len(val) + K = [[0 for w in range(W + 1)] + for i in range(n + 1)] + + # Build table K[][] in bottom up manner + for i in range(n + 1): + for w in range(W + 1): + if i == 0 or w == 0: + K[i][w] = 0 + elif wt[i - 1] <= w: + K[i][w] = max(val[i - 1] + + K[i - 1][w - wt[i - 1]], + K[i - 1][w]) + else: + K[i][w] = K[i - 1][w] + + # stores the result of Knapsack + res = res_ret = K[n][W] + + # stores the selected indexes + w = W + idx = [] + for i in range(n, 0, -1): + if res <= 0: + break + # Either the result comes from the top (K[i-1][w]) + # or from (val[i-1] + K[i-1] [w-wt[i-1]]) as in Knapsack table. + # If it comes from the latter one, it means the item is included. + if res == K[i - 1][w]: + continue + else: + # This item is included. 
+            idx.append(i - 1)
+            # Since this weight is included, its value is deducted
+            res = res - val[i - 1]
+            w = w - wt[i - 1]
+
+    return res_ret, idx[::-1]  # return the indices in increasing order
+
+
+def group_windows(group_size, num_ele_win):
+    '''Greedily apply the DP algorithm to group the elements.
+    Args:
+        group_size (int): maximal size of the group
+        num_ele_win (list[int]): number of visible elements of each window
+    Outputs:
+        num_ele_group (list[int]): number of elements of each group
+        grouped_idx (list[list[int]]): the selected indices of each group
+    '''
+    wt = num_ele_win.copy()
+    ori_idx = list(range(len(wt)))
+    grouped_idx = []
+    num_ele_group = []
+
+    while len(wt) > 0:
+        res, idx = knapsack(group_size, wt)
+        num_ele_group.append(res)
+
+        # append the selected idx
+        selected_ori_idx = [ori_idx[i] for i in idx]
+        grouped_idx.append(selected_ori_idx)
+
+        # remaining idx
+        wt = [wt[i] for i in range(len(ori_idx)) if i not in idx]
+        ori_idx = [ori_idx[i] for i in range(len(ori_idx)) if i not in idx]
+
+    return num_ele_group, grouped_idx
+
+
+class GroupingModule:
+    def __init__(self, window_size, shift_size, group_size=None):
+        self.window_size = window_size
+        self.shift_size = shift_size
+        assert shift_size >= 0 and shift_size < window_size
+
+        self.group_size = group_size or self.window_size**2
+        self.attn_mask = None
+        self.rel_pos_idx = None
+
+    def _get_group_id(self, coords):
+        group_id = coords.clone()
+        group_id += (self.window_size - self.shift_size) % self.window_size
+        group_id = group_id // self.window_size
+        group_id = group_id[0, :, 0] * group_id.shape[1] + group_id[0, :, 1]  # (N_vis, )
+        return group_id
+
+    def _get_attn_mask(self, group_id):
+        pos_mask = (group_id == -1)
+        pos_mask = torch.logical_and(pos_mask[:, :, None], pos_mask[:, None, :])
+        gid = group_id.float()
+        attn_mask_float = gid.unsqueeze(2) - gid.unsqueeze(1)
+        attn_mask = torch.logical_or(attn_mask_float != 0, pos_mask)
+        attn_mask_float.masked_fill_(attn_mask, -100.)
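+        # -100. is a large negative bias: after the softmax in
+        # WindowAttention, token pairs from different windows (or padding)
+        # receive near-zero attention weight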
+ return attn_mask_float + + def _get_rel_pos_idx(self, coords): + # num_groups, group_size, group_size, 2 + rel_pos_idx = coords[:, :, None, :] - coords[:, None, :, :] + rel_pos_idx += self.window_size - 1 + rel_pos_idx[..., 0] *= 2 * self.window_size - 1 + rel_pos_idx = rel_pos_idx.sum(dim=-1) + return rel_pos_idx + + def _prepare_masking(self, coords): + # coords: (B, N_vis, 2) + group_id = self._get_group_id(coords) # (N_vis, ) + attn_mask = self._get_attn_mask(group_id.unsqueeze(0)) + rel_pos_idx = self._get_rel_pos_idx(coords[:1]) + + # do not shuffle + self.idx_shuffle = None + self.idx_unshuffle = None + + return attn_mask, rel_pos_idx + + def _prepare_grouping(self, coords): + # find out and merge the elements within each local window + # coords: (B, N_vis, 2) + group_id = self._get_group_id(coords) # (N_vis, ) + idx_merge = torch.argsort(group_id) + group_id = group_id[idx_merge].contiguous() + exact_win_sz = torch.unique_consecutive(group_id, return_counts=True)[1].tolist() + + # group the windows by DP algorithm + self.group_size = min(self.window_size**2, max(exact_win_sz)) + num_ele_group, grouped_idx = group_windows(self.group_size, exact_win_sz) + + # pad the splits if their sizes are smaller than the group size + idx_merge_spl = idx_merge.split(exact_win_sz) + group_id_spl = group_id.split(exact_win_sz) + shuffled_idx, attn_mask = [], [] + for num_ele, gidx in zip(num_ele_group, grouped_idx): + pad_r = self.group_size - num_ele + # shuffle indexes: (group_size) + sidx = torch.cat([idx_merge_spl[i] for i in gidx], dim=0) + shuffled_idx.append(F.pad(sidx, (0, pad_r), value=-1)) + # attention mask: (group_size) + amask = torch.cat([group_id_spl[i] for i in gidx], dim=0) + attn_mask.append(F.pad(amask, (0, pad_r), value=-1)) + + # shuffle indexes: (num_groups * group_size, ) + self.idx_shuffle = torch.cat(shuffled_idx, dim=0) + # unshuffle indexes that exclude the padded indexes: (N_vis, ) + self.idx_unshuffle = torch.argsort(self.idx_shuffle)[-sum(num_ele_group):] + self.idx_shuffle[self.idx_shuffle==-1] = 0 # index_select does not permit negative index + + # attention mask: (num_groups, group_size, group_size) + attn_mask = torch.stack(attn_mask, dim=0) + attn_mask = self._get_attn_mask(attn_mask) + + # relative position indexes: (num_groups, group_size, group_size) + coords_shuffled = coords[0][self.idx_shuffle].reshape(-1, self.group_size, 2) + rel_pos_idx = self._get_rel_pos_idx(coords_shuffled) # num_groups, group_size, group_size + rel_pos_mask = torch.ones_like(rel_pos_idx).masked_fill_(attn_mask.bool(), 0) + rel_pos_idx = rel_pos_idx * rel_pos_mask + + return attn_mask, rel_pos_idx + + def prepare(self, coords, mode): + self._mode = mode + if mode == 'masking': + return self._prepare_masking(coords) + elif mode == 'grouping': + return self._prepare_grouping(coords) + else: + raise KeyError("") + + def group(self, x): + if self._mode == 'grouping': + self.ori_shape = x.shape + x = torch.index_select(x, 1, self.idx_shuffle) # (B, nG*GS, C) + x = x.reshape(-1, self.group_size, x.shape[-1]) # (B*nG, GS, C) + return x + + def merge(self, x): + if self._mode == 'grouping': + B, N, C = self.ori_shape + x = x.reshape(B, -1, C) # (B, nG*GS, C) + x = torch.index_select(x, 1, self.idx_unshuffle) # (B, N, C) + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. 
+ window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + self.window_size = window_size + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + else: + self.shift_size = window_size // 2 + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + + def forward(self, x, coords, patch_mask, return_x_before_down=False): + # prepare the attention mask + # when the number of visible patches is small, + # all patches are partitioned into a single group + mode = "masking" if x.shape[1] <= 2 * self.window_size**2 else "grouping" + + group_block = GroupingModule(self.window_size, 0) + mask, pos_idx = group_block.prepare(coords, mode) + if self.window_size < min(self.input_resolution) and self.shift_size != 0: + group_block_shift = GroupingModule(self.window_size, self.shift_size) + mask_shift, pos_idx_shift = group_block_shift.prepare(coords, mode) + else: + # do not shift + group_block_shift = group_block + mask_shift, pos_idx_shift = mask, pos_idx + + # forward with grouping/masking + for i, blk in enumerate(self.blocks): + gblk = group_block if i % 2 ==0 else group_block_shift + attn_mask = mask if i % 2 ==0 else mask_shift + rel_pos_idx = pos_idx if i % 2 ==0 else pos_idx_shift + x = gblk.group(x) + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask, rel_pos_idx) + else: + x = blk(x, attn_mask, rel_pos_idx) + x = gblk.merge(x) + + # patch merging + if self.downsample is not None: + x_down, coords, patch_mask = self.downsample(x, patch_mask) + else: + x_down = x + + if return_x_before_down: + return x, x_down, coords, patch_mask + else: + return x_down, coords, patch_mask + + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, window_size={self.window_size},"\ + 
f"shift_size={self.shift_size}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + self.drop_path_rate = drop_path_rate + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x, mask): + # patch embedding + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + # mask out some patches according to the random mask + B, N, C = x.shape + H, W = self.patches_resolution + ratio = N // mask.shape[1] + mask = mask[:1].clone() # we use the same mask for the whole batch + assert ratio * mask.shape[1] == N + window_size = int(ratio**0.5) + if ratio > 1: # mask_size != patch_embed_size + Mh, Mw = [sz // window_size for sz in self.patches_resolution] + mask = mask.reshape(1, Mh, 1, Mw, 1) + mask = mask.expand(-1, -1, window_size, -1, window_size) + mask = mask.reshape(1, -1) + + # record the corresponding coordinates of visible patches + coords_h = torch.arange(H, device=x.device) + coords_w = torch.arange(W, device=x.device) + coords = torch.stack(torch.meshgrid([coords_h, coords_w]), dim=-1) # H W 2 + coords = coords.reshape(1, H*W, 2) + + # 
for convenience, first divide the image into local windows
+        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, N, C)
+        mask = mask.view(1, H // window_size, window_size, W // window_size, window_size)
+        mask = mask.permute(0, 1, 3, 2, 4).reshape(1, N)
+        coords = coords.view(1, H // window_size, window_size, W // window_size, window_size, 2)
+        coords = coords.permute(0, 1, 3, 2, 4, 5).reshape(1, N, 2)
+
+        # mask out patches
+        vis_mask = ~mask  # ~mask means visible
+        x_vis = x[vis_mask.expand(B, -1)].reshape(B, -1, C)
+        coords = coords[vis_mask].reshape(1, -1, 2)  # 1 N_vis 2
+
+        # transformer forward
+        for layer in self.layers:
+            x_vis, coords, vis_mask = layer(x_vis, coords, vis_mask)
+        x_vis = self.norm(x_vis)
+
+        return x_vis
+
+    def forward(self, x, mask):
+        return self.forward_features(x, mask)
+
+    def flops(self):
+        flops = 0
+        flops += self.patch_embed.flops()
+        for i, layer in enumerate(self.layers):
+            flops += layer.flops()
+        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+        flops += self.num_features * self.num_classes
+        return flops
+
+
+@MODELS.register_module()
+class GreenMIMSwinTransformer(nn.Module):  # MAE-style model with a Swin encoder
+    """Masked autoencoder pre-training with a Swin Transformer backbone."""
+
+    def __init__(self, arch='B', stage_cfgs=None, img_size=224, patch_size=4, in_chans=3,
+                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+                 window_size=7, mlp_ratio=4.,
+                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+                 norm_layer=nn.LayerNorm, norm_pix_loss=False,
+                 block_cls=Block, backbone_cls=SwinTransformer,
+                 **kwargs):
+        super().__init__()
+
+        # ------------------------------------------------------------------
+        # MAE encoder specifics
+        self.encoder = backbone_cls(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+                                    num_classes=0, embed_dim=embed_dim, depths=depths, num_heads=num_heads,
+                                    window_size=window_size, norm_layer=norm_layer, **kwargs)
+        num_patches = np.prod(self.encoder.layers[-1].input_resolution)
+        self.num_patches = num_patches
+        patch_size = patch_size * 2**(len(depths) - 1)
+        self.final_patch_size = patch_size
+        # ------------------------------------------------------------------
+
+        self.norm_pix_loss = norm_pix_loss
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+        w = self.encoder.patch_embed.proj.weight.data
+        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+        # initialize nn.Linear and nn.LayerNorm
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            # we use xavier_uniform following official JAX ViT:
+            torch.nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def unpatchify(self, x, patch_size=None):
+        """
+        x: (N, L, patch_size**2 *3)
+        imgs: (N, 3, H, W)
+        """
+        p = patch_size or self.final_patch_size
+        h = w = int(x.shape[1]**.5)
+        assert h * w == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
+        return imgs
+
+    def random_masking(self, x, mask_ratio):
+        """
+        Perform per-sample random masking by per-sample shuffling.
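+        A single mask is sampled (N = 1 below) and later repeated for the
+        whole batch, so all samples share the same set of visible patches.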
+        Per-sample shuffling is done by argsort random noise.
+        x: [N, L, D], sequence
+        """
+        N, L = 1, self.num_patches  # batch, length, dim
+        len_keep = int(L * (1 - mask_ratio))
+
+        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        mask = torch.ones([N, L], device=x.device)
+        mask.scatter_add_(1, ids_keep, torch.full([N, len_keep], fill_value=-1, dtype=mask.dtype, device=x.device))
+        assert (mask.gather(1, ids_shuffle).gather(1, ids_restore) == mask).all()
+
+        # repeat the mask
+        ids_restore = ids_restore.repeat(x.shape[0], 1)
+        mask = mask.repeat(x.shape[0], 1)
+
+        return mask, ids_restore
+
+    def forward(self, x, mask_ratio=0.75):
+        # generate the random mask (B x num_patches); ids_restore records
+        # the permutation that recovers the original token order
+        mask, ids_restore = self.random_masking(x, mask_ratio)
+
+        # L -> L_vis: encode only the visible (un-masked) tokens
+        latent = self.encoder(x, mask.bool())
+
+        return latent, mask, ids_restore
diff --git a/mmselfsup/models/heads/__init__.py b/mmselfsup/models/heads/__init__.py
index 36cb565db..ec740f5a7 100644
--- a/mmselfsup/models/heads/__init__.py
+++ b/mmselfsup/models/heads/__init__.py
@@ -13,10 +13,12 @@ from .multi_cls_head import MultiClsHead
 from .simmim_head import SimMIMHead
 from .swav_head import SwAVHead
+from .greenmim_head import GreenMIMHead
 
 __all__ = [
     'BEiTV1Head', 'BEiTV2Head', 'ContrastiveHead', 'ClsHead',
     'LatentPredictHead', 'LatentCrossCorrelationHead', 'MultiClsHead',
     'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead',
-    'MaskFeatPretrainHead', 'MILANPretrainHead', 'MixMIMPretrainHead'
+    'MaskFeatPretrainHead', 'MILANPretrainHead', 'MixMIMPretrainHead',
+    'GreenMIMHead'
 ]
diff --git a/mmselfsup/models/heads/greenmim_head.py b/mmselfsup/models/heads/greenmim_head.py
new file mode 100644
index 000000000..0e5751145
--- /dev/null
+++ b/mmselfsup/models/heads/greenmim_head.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine.model import BaseModule
+
+from mmselfsup.registry import MODELS
+
+
+@MODELS.register_module()
+class GreenMIMHead(BaseModule):
+    """Pre-train head for GreenMIM.
+
+    Args:
+        patch_size (int): Patch size of each token.
+        norm_pix_loss (bool): Whether to normalize the target pixels of
+            each patch before computing the loss.
+        loss (dict): The config for loss.
+    """
+
+    def __init__(self, patch_size, norm_pix_loss, loss: dict) -> None:
+        super().__init__()
+        self.loss = MODELS.build(loss)
+        self.final_patch_size = patch_size
+        self.norm_pix_loss = norm_pix_loss
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
+        """
+        pred: [N, L, p*p*3]
+        target: [N, L, p*p*3]
+        mask: [N, L], 0 is keep, 1 is remove
+        """
+        if self.norm_pix_loss:  # False in the default config
+            mean = target.mean(dim=-1, keepdim=True)
+            var = target.var(dim=-1, keepdim=True)
+            target = (target - mean) / (var + 1.e-6)**.5
+
+        loss = (pred - target) ** 2  # plain per-pixel MSE
+        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
+
+        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+
+        return loss
diff --git a/mmselfsup/models/losses/greenmim_loss.py b/mmselfsup/models/losses/greenmim_loss.py
new file mode 100644
index 000000000..ceebc9b0c
--- /dev/null
+++ b/mmselfsup/models/losses/greenmim_loss.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
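+# L1 reconstruction loss on the masked patches, analogous to the SimMIM loss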
+
+import torch
+from mmengine.model import BaseModule
+from torch.nn import functional as F
+
+from mmselfsup.registry import MODELS
+
+
+@MODELS.register_module()
+class GreenMIMReconstructionLoss(BaseModule):
+    """Reconstruction loss for GreenMIM.
+
+    Compute the loss in the masked region.
+
+    Args:
+        encoder_in_channels (int): Number of input channels for encoder.
+    """
+
+    def __init__(self, encoder_in_channels: int) -> None:
+        super().__init__()
+        self.encoder_in_channels = encoder_in_channels
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
+        """Forward function of the reconstruction loss.
+
+        Args:
+            pred (torch.Tensor): The reconstructed image.
+            target (torch.Tensor): The target image.
+            mask (torch.Tensor): The mask of the target image.
+
+        Returns:
+            torch.Tensor: The reconstruction loss.
+        """
+        # TODO: this could probably reuse the existing SimMIM
+        # reconstruction loss
+        loss_rec = F.l1_loss(target, pred, reduction='none')
+        loss = (loss_rec * mask).sum() / (mask.sum() +
+                                          1e-5) / self.encoder_in_channels
+
+        return loss
diff --git a/mmselfsup/models/necks/__init__.py b/mmselfsup/models/necks/__init__.py
index 7956fa817..3e6a092ac 100644
--- a/mmselfsup/models/necks/__init__.py
+++ b/mmselfsup/models/necks/__init__.py
@@ -13,10 +13,12 @@ from .relative_loc_neck import RelativeLocNeck
 from .simmim_neck import SimMIMNeck
 from .swav_neck import SwAVNeck
+from .greenmim_neck import GreenMIMNeck
+
 __all__ = [
     'AvgPool2dNeck', 'BEiTV2Neck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck',
     'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck',
     'MAEPretrainDecoder', 'SimMIMNeck', 'CAENeck', 'MixMIMPretrainDecoder',
-    'ClsBatchNormNeck', 'MILANPretrainDecoder'
+    'ClsBatchNormNeck', 'MILANPretrainDecoder', 'GreenMIMNeck'
 ]
diff --git a/mmselfsup/models/necks/greenmim_neck.py b/mmselfsup/models/necks/greenmim_neck.py
new file mode 100644
index 000000000..cb5a55afe
--- /dev/null
+++ b/mmselfsup/models/necks/greenmim_neck.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmengine.model import BaseModule
+from mmselfsup.registry import MODELS
+from timm.models.vision_transformer import Block
+import numpy as np
+
+
+@MODELS.register_module()
+class GreenMIMNeck(BaseModule):
+    """Pre-train neck for GreenMIM.
+
+    This neck reconstructs the original image from the shrunk feature map.
+
+    Args:
+        in_channels (int): Channel dimension of the feature map.
+        encoder_stride (int): The total stride of the encoder.
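+        img_size (int): Input image size.
+        patch_size (int): Patch size of the encoder stem.
+        embed_dim (int): Embedding dimension of the first encoder stage.
+            Defaults to 96.
+        depths (list[int]): Number of blocks in each encoder stage.
+            Defaults to [2, 2, 6, 2].
+        decoder_embed_dim (int): Embedding dimension of the decoder.
+            Defaults to 512.
+        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim in
+            the decoder blocks. Defaults to 4.
+        decoder_depth (int): Number of decoder blocks. Defaults to 8.
+        decoder_num_heads (int): Number of attention heads of each decoder
+            block. Defaults to 16.
+        block_cls (type): Block class used to build the decoder. Defaults
+            to timm's ``Block``.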
+ """ + + def __init__(self, in_channels: int, encoder_stride: int, img_size, patch_size, + embed_dim=96, depths=[2, 2, 6, 2], decoder_embed_dim=512, + mlp_ratio=4., decoder_depth=8, decoder_num_heads=16, block_cls=Block) -> None: + super().__init__() + + patch_resolution = img_size // patch_size + num_patches = (patch_resolution // (2 ** (len(depths) - 1))) ** 2 + # SwinMAE decoder specifics + embed_dim = embed_dim * 2**(len(depths) - 1) + patch_size = patch_size * 2**(len(depths) - 1) + self.patch_size = patch_size + self.decoder_embed = nn.Identity() if embed_dim == decoder_embed_dim else nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding + #### 替换block cls#### + self.decoder_blocks = nn.ModuleList([ + block_cls(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=torch.nn.LayerNorm) + for i in range(decoder_depth)]) + + self.decoder_norm = torch.nn.LayerNorm(decoder_embed_dim) + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_channels, bias=True) # encoder to decoder + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.num_patches**.5), cls_token=False) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.mask_token, std=.02) + + def forward(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + x_ = torch.cat([x, mask_tokens], dim=1) + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + + # add pos embed + x = x_ + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + return x + + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega 
= np.arange(embed_dim // 2, dtype=np.float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb From eb64471050093ca9ff94fa921694a423cc09a90d Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Sat, 11 Feb 2023 16:42:48 +0800 Subject: [PATCH 02/17] fix bugs --- .../_base_/models/greenmim_swin-base.py | 12 +- ...in-base_16xb128-amp-coslr-100e_in1k-192.py | 5 +- .../greenmim/models}/greenmim.py | 7 +- .../greenmim/models/greenmim_backbone.py | 596 ++++++++++++------ .../greenmim/models}/greenmim_head.py | 5 +- .../greenmim/models}/greenmim_neck.py | 86 ++- 6 files changed, 467 insertions(+), 244 deletions(-) rename {configs => projects/greenmim/configs}/selfsup/_base_/models/greenmim_swin-base.py (68%) rename {configs => projects/greenmim/configs}/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py (90%) rename {mmselfsup/models/algorithms => projects/greenmim/models}/greenmim.py (97%) rename mmselfsup/models/backbones/greenmim.py => projects/greenmim/models/greenmim_backbone.py (63%) rename {mmselfsup/models/heads => projects/greenmim/models}/greenmim_head.py (89%) rename {mmselfsup/models/necks => projects/greenmim/models}/greenmim_neck.py (56%) diff --git a/configs/selfsup/_base_/models/greenmim_swin-base.py b/projects/greenmim/configs/selfsup/_base_/models/greenmim_swin-base.py similarity index 68% rename from configs/selfsup/_base_/models/greenmim_swin-base.py rename to projects/greenmim/configs/selfsup/_base_/models/greenmim_swin-base.py index 34aa5cec3..5b8501d06 100644 --- a/configs/selfsup/_base_/models/greenmim_swin-base.py +++ b/projects/greenmim/configs/selfsup/_base_/models/greenmim_swin-base.py @@ -12,10 +12,20 @@ type='GreenMIMSwinTransformer', arch='B', img_size=img_size, + embed_dim=128, + num_heads=[4, 8, 16, 32], + depths=[2, 2, 18, 2], patch_size=patch_size, + decoder_depth=1, drop_path_rate=0.0, stage_cfgs=dict(block_cfgs=dict(window_size=7))), - neck=dict(type='GreenMIMNeck', in_channels=3, encoder_stride=32, img_size=img_size, patch_size=patch_size), + neck=dict( + type='GreenMIMNeck', + in_channels=3, + encoder_stride=32, + img_size=img_size, + patch_size=patch_size, + embed_dim=128), head=dict( type='GreenMIMHead', patch_size=patch_size, diff --git a/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py b/projects/greenmim/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py similarity index 90% rename from configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py rename to projects/greenmim/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py index cd1fc4be8..efd8cc9bf 100644 --- a/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py +++ b/projects/greenmim/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py @@ -6,7 +6,10 @@ ] # dataset 16 GPUs x 128 -train_dataloader = dict(batch_size=128, num_workers=16) +train_dataloader = dict( + batch_size=128, + num_workers=16, + sampler=dict(type='DefaultSampler', seed=0, shuffle=True)) # optimizer wrapper optimizer = dict( diff --git a/mmselfsup/models/algorithms/greenmim.py b/projects/greenmim/models/greenmim.py similarity index 97% rename from mmselfsup/models/algorithms/greenmim.py rename to 
projects/greenmim/models/greenmim.py index ae9ccd9f4..331df1947 100644 --- a/mmselfsup/models/algorithms/greenmim.py +++ b/projects/greenmim/models/greenmim.py @@ -8,12 +8,13 @@ from mmselfsup.structures import SelfSupDataSample from .base import BaseModel + @MODELS.register_module() class GreenMIM(BaseModel): """GreenMIM. - Implementation of `GreenMIM: Green Hierarchical Vision Transformer for Masked Image Modeling - `_. + Implementation of `GreenMIM: Green Hierarchical Vision Transformer for + Masked Image Modeling `_. """ def extract_feat(self, @@ -65,7 +66,7 @@ def reconstruct(self, results.pred = BaseDataElement(**dict(value=pred)) return results - + def patchify(self, imgs, patch_size): """ imgs: (N, 3, H, W) diff --git a/mmselfsup/models/backbones/greenmim.py b/projects/greenmim/models/greenmim_backbone.py similarity index 63% rename from mmselfsup/models/backbones/greenmim.py rename to projects/greenmim/models/greenmim_backbone.py index 61835c80d..a7bdf9556 100644 --- a/mmselfsup/models/backbones/greenmim.py +++ b/projects/greenmim/models/greenmim_backbone.py @@ -1,17 +1,27 @@ -import math +# Copyright (c) OpenMMLab. All rights reserved. from functools import partial +from typing import List, Optional, Union + import numpy as np import torch import torch.nn as nn -from torch.nn import functional as F import torch.utils.checkpoint as checkpoint -from mmselfsup.registry import MODELS +from mmcls.models.backbones.base_backbone import BaseBackbone from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.vision_transformer import Block +from torch.nn import functional as F + +from mmselfsup.registry import MODELS class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -37,43 +47,62 @@ def get_coordinates(h, w, device='cpu'): class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. + r""" Window based multi-head self attention (W-MSA) module with + relative position bias. It supports both of shifted and + non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + qkv_bias (bool, optional): If True, add a learnable bias to + query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale + of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention + weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
+ Default: 0.0 """ - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = qk_scale or head_dim**-0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH - # get pair-wise relative position index for each token inside the window - # NOTE: the index is not used at pretraining and is kept for compatibility + # get pair-wise relative position index for each + # token inside the window + # NOTE: the index is not used at pretraining and + # is kept for compatibility coords = get_coordinates(*window_size) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords = coords_flatten[:, :, # 2, Wh*Ww, Wh*Ww + None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) + self.register_buffer('relative_position_index', + relative_position_index) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -87,32 +116,41 @@ def forward(self, x, mask=None, pos_idx=None): """ Args: x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, + Wh*Ww) or None """ # projection B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) # B_, nH, N, N + attn = (q @ k.transpose(-2, -1)) # B_, nH, N, N # relative position bias - assert pos_idx.dim() == 3, f"Expect the pos_idx/mask to be a 3-d tensor, but got{pos_idx.dim()}" - rel_pos_mask = torch.masked_fill(torch.ones_like(mask), mask=mask.bool(), value=0.0) + assert pos_idx.dim( + ) == 3, 'Expect the pos_idx/mask to be a 3-d tensor,' + f'but got{pos_idx.dim()}' + rel_pos_mask = torch.masked_fill( + torch.ones_like(mask), mask=mask.bool(), value=0.0) pos_idx_m = torch.masked_fill(pos_idx, mask.bool(), value=0).view(-1) - relative_position_bias = self.relative_position_bias_table[pos_idx_m].view( - -1, N, N, self.num_heads) # nW, Wh*Ww, Wh*Ww,nH - 
relative_position_bias = relative_position_bias * rel_pos_mask.view(-1, N, N, 1) + relative_position_bias = self.relative_position_bias_table[ + pos_idx_m].view(-1, N, N, self.num_heads) # nW, Wh*Ww, Wh*Ww,nH + relative_position_bias = relative_position_bias * rel_pos_mask.view( + -1, N, N, 1) nW = relative_position_bias.shape[0] - relative_position_bias = relative_position_bias.permute(0, 3, 1, 2).contiguous() # nW, nH, Wh*Ww, Wh*Ww - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + relative_position_bias.unsqueeze(0) + relative_position_bias = relative_position_bias.permute( + 0, 3, 1, 2).contiguous() # nW, nH, Wh*Ww, Wh*Ww + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + relative_position_bias.unsqueeze(0) # attention mask attn = attn + mask.view(1, nW, 1, N, N) attn = attn.view(B_, self.num_heads, N, N) - + # normalization attn = self.softmax(attn) attn = self.attn_drop(attn) @@ -124,7 +162,8 @@ def forward(self, x, mask=None, pos_idx=None): return x def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + return f'dim={self.dim}, window_size={self.window_size}, \ + num_heads={self.num_heads}' def flops(self, N): # calculate flops for 1 window with token length of N @@ -145,23 +184,37 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. + input_resolution (tuple[int]): Input resolution. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + qkv_bias (bool, optional): If True, add a learnable bias to + query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale + of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + norm_layer (nn.Module, optional): Normalization layer. 
+ Default: nn.LayerNorm """ - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -170,27 +223,39 @@ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0 self.shift_size = shift_size self.mlp_ratio = mlp_ratio if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows + # if window size is larger than input resolution, + # we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + assert 0 <= self.shift_size < self.window_size, \ + 'shift_size must in 0-window_size' self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) def forward(self, x, attn_mask, rel_pos_idx): shortcut = x x = self.norm1(x) - # W-MSA/SW-MSA - x = self.attn(x, mask=attn_mask, pos_idx=rel_pos_idx) # B*nW, N_vis, C + # W-MSA/SW-MSA, B*nW, N_vis, C + x = self.attn(x, mask=attn_mask, pos_idx=rel_pos_idx) # FFN x = shortcut + self.drop_path(x) @@ -199,8 +264,9 @@ def forward(self, x, attn_mask, rel_pos_idx): return x def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + return f'dim={self.dim}, input_resolution={self.input_resolution}, ' + 'num_heads={self.num_heads}, ' f'window_size={self.window_size}, ' + f'shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}' def flops(self): flops = 0 @@ -223,7 +289,8 @@ class PatchMerging(nn.Module): Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + norm_layer (nn.Module, optional): Normalization layer. + Default: nn.LayerNorm """ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): @@ -239,24 +306,26 @@ def forward(self, x, mask_prev): """ H, W = self.input_resolution B, L, C = x.shape - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
- ratio = H // 7 if H % 7 == 0 else H // 6 # FIXME - x = x.view(B, -1, ratio//2, 2, ratio//2, 2, C) - x = x.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, L//4, 4 * C) - + assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.' + ratio = H // 7 if H % 7 == 0 else H // 6 # FIXME + x = x.view(B, -1, ratio // 2, 2, ratio // 2, 2, C) + x = x.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, L // 4, 4 * C) + # merging by a linear layer x = self.norm(x) x = self.reduction(x) - mask_new = mask_prev.view(1, -1, ratio//2, 2, ratio//2, 2).sum(dim=(3, 5)) - assert torch.unique(mask_new).shape[0] == 2 # should be [0, 4] + mask_new = mask_prev.view(1, -1, ratio // 2, 2, ratio // 2, + 2).sum(dim=(3, 5)) + assert torch.unique(mask_new).shape[0] == 2 # should be [0, 4] mask_new = (mask_new > 0).reshape(1, -1) - coords_new = get_coordinates(H//2, W//2, x.device).reshape(1, 2, -1) + coords_new = get_coordinates(H // 2, W // 2, + x.device).reshape(1, 2, -1) coords_new = coords_new.transpose(2, 1)[mask_new].reshape(1, -1, 2) return x, coords_new, mask_new def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" + return f'input_resolution={self.input_resolution}, dim={self.dim}' def flops(self): H, W = self.input_resolution @@ -272,18 +341,16 @@ def knapsack(W, wt): ''' val = wt n = len(val) - K = [[0 for w in range(W + 1)] - for i in range(n + 1)] - + K = [[0 for w in range(W + 1)] for i in range(n + 1)] + # Build table K[][] in bottom up manner for i in range(n + 1): for w in range(W + 1): if i == 0 or w == 0: K[i][w] = 0 elif wt[i - 1] <= w: - K[i][w] = max(val[i - 1] - + K[i - 1][w - wt[i - 1]], - K[i - 1][w]) + K[i][w] = max(val[i - 1] + K[i - 1][w - wt[i - 1]], + K[i - 1][w]) else: K[i][w] = K[i - 1][w] @@ -296,7 +363,7 @@ def knapsack(W, wt): for i in range(n, 0, -1): if res <= 0: break - # Either the result comes from the top (K[i-1][w]) + # Either the result comes from the top (K[i-1][w]) # or from (val[i-1] + K[i-1] [w-wt[i-1]]) as in Knapsack table. # If it comes from the latter one, it means the item is included. if res == K[i - 1][w]: @@ -307,19 +374,20 @@ def knapsack(W, wt): # Since this weight is included, its value is deducted res = res - val[i - 1] w = w - wt[i - 1] - - return res_ret, idx[::-1] # make the idx in an increasing order + + return res_ret, idx[::-1] # make the idx in an increasing order def group_windows(group_size, num_ele_win): - '''Greedily apply the DP algorithm to group the elements. + """Greedily apply the DP algorithm to group the elements. 
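+
+    Example (illustrative): with group_size=4 and num_ele_win=[3, 1, 4, 2],
+    the knapsack step first packs windows {0, 1} into one full group of
+    four elements, then window {2}, and the leftover window {3} forms a
+    final group that is padded up to the group size.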
+ Args: group_size (int): maximal size of the group num_ele_win (list[int]): number of visible elements of each window Outputs: num_ele_group (list[int]): number of elements of each group - grouped_idx (list[list[int]]): the seleted indeices of each group - ''' + grouped_idx (list[list[int]]): the selected indeices of each group + """ wt = num_ele_win.copy() ori_idx = list(range(len(wt))) grouped_idx = [] @@ -336,11 +404,12 @@ def group_windows(group_size, num_ele_win): # remaining idx wt = [wt[i] for i in range(len(ori_idx)) if i not in idx] ori_idx = [ori_idx[i] for i in range(len(ori_idx)) if i not in idx] - + return num_ele_group, grouped_idx class GroupingModule: + def __init__(self, window_size, shift_size, group_size=None): self.window_size = window_size self.shift_size = shift_size @@ -349,23 +418,25 @@ def __init__(self, window_size, shift_size, group_size=None): self.group_size = group_size or self.window_size**2 self.attn_mask = None self.rel_pos_idx = None - + def _get_group_id(self, coords): group_id = coords.clone() group_id += (self.window_size - self.shift_size) % self.window_size group_id = group_id // self.window_size - group_id = group_id[0, :, 0] * group_id.shape[1] + group_id[0, :, 1] # (N_vis, ) + group_id = group_id[0, :, 0] * group_id.shape[1] + group_id[ + 0, :, 1] # (N_vis, ) return group_id - + def _get_attn_mask(self, group_id): pos_mask = (group_id == -1) - pos_mask = torch.logical_and(pos_mask[:, :, None], pos_mask[:, None, :]) + pos_mask = torch.logical_and(pos_mask[:, :, None], pos_mask[:, + None, :]) gid = group_id.float() attn_mask_float = gid.unsqueeze(2) - gid.unsqueeze(1) attn_mask = torch.logical_or(attn_mask_float != 0, pos_mask) attn_mask_float.masked_fill_(attn_mask, -100.) return attn_mask_float - + def _get_rel_pos_idx(self, coords): # num_groups, group_size, group_size, 2 rel_pos_idx = coords[:, :, None, :] - coords[:, None, :, :] @@ -373,10 +444,10 @@ def _get_rel_pos_idx(self, coords): rel_pos_idx[..., 0] *= 2 * self.window_size - 1 rel_pos_idx = rel_pos_idx.sum(dim=-1) return rel_pos_idx - + def _prepare_masking(self, coords): # coords: (B, N_vis, 2) - group_id = self._get_group_id(coords) # (N_vis, ) + group_id = self._get_group_id(coords) # (N_vis, ) attn_mask = self._get_attn_mask(group_id.unsqueeze(0)) rel_pos_idx = self._get_rel_pos_idx(coords[:1]) @@ -385,18 +456,20 @@ def _prepare_masking(self, coords): self.idx_unshuffle = None return attn_mask, rel_pos_idx - + def _prepare_grouping(self, coords): # find out and merge the elements within each local window # coords: (B, N_vis, 2) - group_id = self._get_group_id(coords) # (N_vis, ) + group_id = self._get_group_id(coords) # (N_vis, ) idx_merge = torch.argsort(group_id) group_id = group_id[idx_merge].contiguous() - exact_win_sz = torch.unique_consecutive(group_id, return_counts=True)[1].tolist() + exact_win_sz = torch.unique_consecutive( + group_id, return_counts=True)[1].tolist() # group the windows by DP algorithm self.group_size = min(self.window_size**2, max(exact_win_sz)) - num_ele_group, grouped_idx = group_windows(self.group_size, exact_win_sz) + num_ele_group, grouped_idx = group_windows(self.group_size, + exact_win_sz) # pad the splits if their sizes are smaller than the group size idx_merge_spl = idx_merge.split(exact_win_sz) @@ -410,25 +483,30 @@ def _prepare_grouping(self, coords): # attention mask: (group_size) amask = torch.cat([group_id_spl[i] for i in gidx], dim=0) attn_mask.append(F.pad(amask, (0, pad_r), value=-1)) - + # shuffle indexes: (num_groups * group_size, ) 
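+        # each entry of shuffled_idx holds one group's token indices,
+        # right-padded with -1 up to group_size, so the concatenation
+        # below has num_groups * group_size entries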
self.idx_shuffle = torch.cat(shuffled_idx, dim=0) # unshuffle indexes that exclude the padded indexes: (N_vis, ) - self.idx_unshuffle = torch.argsort(self.idx_shuffle)[-sum(num_ele_group):] - self.idx_shuffle[self.idx_shuffle==-1] = 0 # index_select does not permit negative index + self.idx_unshuffle = torch.argsort( + self.idx_shuffle)[-sum(num_ele_group):] + self.idx_shuffle[self.idx_shuffle == + -1] = 0 # index_select does not permit negative index # attention mask: (num_groups, group_size, group_size) attn_mask = torch.stack(attn_mask, dim=0) attn_mask = self._get_attn_mask(attn_mask) # relative position indexes: (num_groups, group_size, group_size) - coords_shuffled = coords[0][self.idx_shuffle].reshape(-1, self.group_size, 2) - rel_pos_idx = self._get_rel_pos_idx(coords_shuffled) # num_groups, group_size, group_size - rel_pos_mask = torch.ones_like(rel_pos_idx).masked_fill_(attn_mask.bool(), 0) + coords_shuffled = coords[0][self.idx_shuffle].reshape( + -1, self.group_size, 2) + rel_pos_idx = self._get_rel_pos_idx( + coords_shuffled) # num_groups, group_size, group_size + rel_pos_mask = torch.ones_like(rel_pos_idx).masked_fill_( + attn_mask.bool(), 0) rel_pos_idx = rel_pos_idx * rel_pos_mask return attn_mask, rel_pos_idx - + def prepare(self, coords, mode): self._mode = mode if mode == 'masking': @@ -436,25 +514,25 @@ def prepare(self, coords, mode): elif mode == 'grouping': return self._prepare_grouping(coords) else: - raise KeyError("") + raise KeyError('') def group(self, x): if self._mode == 'grouping': self.ori_shape = x.shape - x = torch.index_select(x, 1, self.idx_shuffle) # (B, nG*GS, C) - x = x.reshape(-1, self.group_size, x.shape[-1]) # (B*nG, GS, C) + x = torch.index_select(x, 1, self.idx_shuffle) # (B, nG*GS, C) + x = x.reshape(-1, self.group_size, x.shape[-1]) # (B*nG, GS, C) return x - + def merge(self, x): if self._mode == 'grouping': B, N, C = self.ori_shape - x = x.reshape(B, -1, C) # (B, nG*GS, C) - x = torch.index_select(x, 1, self.idx_unshuffle) # (B, N, C) + x = x.reshape(B, -1, C) # (B, nG*GS, C) + x = torch.index_select(x, 1, self.idx_unshuffle) # (B, N, C) return x class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. + """A basic Swin Transformer layer for one stage. Args: dim (int): Number of input channels. @@ -463,19 +541,37 @@ class BasicLayer(nn.Module): num_heads (int): Number of attention heads. window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + drop_path (float | tuple[float], optional): Stochastic depth rate. + Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. 
Default: + nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at + the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False. """ - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): super().__init__() self.dim = dim @@ -484,7 +580,8 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, self.use_checkpoint = use_checkpoint self.window_size = window_size if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows + # if window size is larger than input resolution, + # we don't partition windows self.shift_size = 0 self.window_size = min(self.input_resolution) else: @@ -492,33 +589,42 @@ def __init__(self, dim, input_resolution, depth, num_heads, window_size, # build blocks self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer) else: self.downsample = None - def forward(self, x, coords, patch_mask, return_x_before_down=False): # prepare the attention mask - # when the number of visible patches is small, + # when the number of visible patches is small, # all patches are partitioned into a single group - mode = "masking" if x.shape[1] <= 2 * self.window_size**2 else "grouping" - + mode = 'masking' if x.shape[ + 1] <= 2 * self.window_size**2 else 'grouping' + group_block = GroupingModule(self.window_size, 0) mask, pos_idx = group_block.prepare(coords, mode) - if self.window_size < min(self.input_resolution) and self.shift_size != 0: - group_block_shift = GroupingModule(self.window_size, self.shift_size) + if self.window_size < min( + self.input_resolution) and self.shift_size != 0: + group_block_shift = GroupingModule(self.window_size, + self.shift_size) mask_shift, pos_idx_shift = group_block_shift.prepare(coords, mode) else: # do not shift @@ -527,31 +633,31 @@ def forward(self, x, coords, patch_mask, return_x_before_down=False): # forward with grouping/masking for i, blk in enumerate(self.blocks): - gblk = group_block if i % 2 ==0 else group_block_shift - attn_mask = mask if i % 2 ==0 else mask_shift - rel_pos_idx = pos_idx if i % 2 ==0 else 
pos_idx_shift + gblk = group_block if i % 2 == 0 else group_block_shift + attn_mask = mask if i % 2 == 0 else mask_shift + rel_pos_idx = pos_idx if i % 2 == 0 else pos_idx_shift x = gblk.group(x) if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, attn_mask, rel_pos_idx) else: x = blk(x, attn_mask, rel_pos_idx) x = gblk.merge(x) - + # patch merging if self.downsample is not None: x_down, coords, patch_mask = self.downsample(x, patch_mask) else: x_down = x - + if return_x_before_down: return x, x_down, coords, patch_mask else: return x_down, coords, patch_mask - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, window_size={self.window_size},"\ - f"shift_size={self.shift_size}, depth={self.depth}" + return f'dim={self.dim}, input_resolution={self.input_resolution},' \ + f'window_size={self.window_size},' \ + f'shift_size={self.shift_size}, depth={self.depth}' def flops(self): flops = 0 @@ -569,15 +675,23 @@ class PatchEmbed(nn.Module): img_size (int): Image size. Default: 224. patch_size (int): Patch token size. Default: 4. in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. + embed_dim (int): Number of linear projection output channels. + Default: 96. norm_layer (nn.Module, optional): Normalization layer. Default: None """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] self.img_size = img_size self.patch_size = patch_size self.patches_resolution = patches_resolution @@ -586,7 +700,8 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_la self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) if norm_layer is not None: self.norm = norm_layer(embed_dim) else: @@ -596,7 +711,8 @@ def forward(self, x): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + f"Input image size ({H}*{W}) doesn't match model" \ + f'({self.img_size[0]}*{self.img_size[1]}).' 
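+        # e.g. a 224x224 input with patch_size 4 gives Ph = Pw = 56,
+        # i.e. 3136 patch tokens of dim embed_dim after the projection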
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C if self.norm is not None: x = self.norm(x) @@ -604,7 +720,8 @@ def forward(self, x): def flops(self): Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = Ho * Wo * self.embed_dim * self.in_chans * ( + self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops @@ -612,35 +729,56 @@ def flops(self): class SwinTransformer(nn.Module): r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 + A PyTorch impl of : `Swin Transformer: Hierarchical Vision + Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 + num_classes (int): Number of classes for classification head. + Default: 1000 embed_dim (int): Patch embedding dimension. Default: 96 depths (tuple(int)): Depth of each Swin Transformer layer. num_heads (tuple(int)): Number of attention heads in different layers. window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. + Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + ape (bool): If True, add absolute position embedding to the + patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. + Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
+ Default: False """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + num_classes=1000, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ape=False, + patch_norm=True, use_checkpoint=False): super().__init__() @@ -649,13 +787,16 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, self.embed_dim = embed_dim self.ape = ape self.patch_norm = patch_norm - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features = int(embed_dim * 2**(self.num_layers - 1)) self.mlp_ratio = mlp_ratio self.drop_path_rate = drop_path_rate # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches patches_resolution = self.patch_embed.patches_resolution @@ -663,30 +804,37 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, # absolute position embedding if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim)) trunc_normal_(self.absolute_pos_embed, std=.02) self.pos_drop = nn.Dropout(p=drop_rate) # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule # build layers self.layers = nn.ModuleList() for i_layer in range(self.num_layers): - layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), - input_resolution=(patches_resolution[0] // (2 ** i_layer), - patches_resolution[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - use_checkpoint=use_checkpoint) + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=(patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) self.layers.append(layer) self.norm = norm_layer(self.num_features) @@ -721,33 +869,37 @@ def forward_features(self, x, mask): B, N, C = x.shape H, W = self.patches_resolution ratio = N // mask.shape[1] - mask = mask[:1].clone() # we 
use the same mask for the whole batch
         assert ratio * mask.shape[1] == N
         window_size = int(ratio**0.5)
-        if ratio > 1: # mask_size != patch_embed_size
+        if ratio > 1:  # mask_size != patch_embed_size
             Mh, Mw = [sz // window_size for sz in self.patches_resolution]
             mask = mask.reshape(1, Mh, 1, Mw, 1)
             mask = mask.expand(-1, -1, window_size, -1, window_size)
             mask = mask.reshape(1, -1)
-
+
         # record the corresponding coordinates of visible patches
         coords_h = torch.arange(H, device=x.device)
         coords_w = torch.arange(W, device=x.device)
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]), dim=-1) # H W 2
-        coords = coords.reshape(1, H*W, 2)
-
+        coords = torch.stack(
+            torch.meshgrid([coords_h, coords_w]), dim=-1)  # H W 2
+        coords = coords.reshape(1, H * W, 2)
+
         # for convenience, first divide the image into local windows
-        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+        x = x.view(B, H // window_size, window_size, W // window_size,
+                   window_size, C)
         x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, N, C)
-        mask = mask.view(1, H // window_size, window_size, W // window_size, window_size)
+        mask = mask.view(1, H // window_size, window_size, W // window_size,
+                         window_size)
         mask = mask.permute(0, 1, 3, 2, 4).reshape(1, N)
-        coords = coords.view(1, H // window_size, window_size, W // window_size, window_size, 2)
+        coords = coords.view(1, H // window_size, window_size,
+                             W // window_size, window_size, 2)
         coords = coords.permute(0, 1, 3, 2, 4, 5).reshape(1, N, 2)

         # mask out patches
-        vis_mask = ~mask # ~mask means visible
+        vis_mask = ~mask  # ~mask means visible
         x_vis = x[vis_mask.expand(B, -1)].reshape(B, -1, C)
-        coords = coords[vis_mask].reshape(1, -1, 2) # 1 N_vis 2
+        coords = coords[vis_mask].reshape(1, -1, 2)  # 1 N_vis 2

         # transformer forward
         for layer in self.layers:
@@ -764,34 +916,57 @@ def flops(self):
         flops += self.patch_embed.flops()
         for i, layer in enumerate(self.layers):
             flops += layer.flops()
-        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+        flops += self.num_features * self.patches_resolution[
+            0] * self.patches_resolution[1] // (2**self.num_layers)
         flops += self.num_features * self.num_classes
         return flops

+
 @MODELS.register_module()
-class GreenMIMSwinTransformer(nn.Module):  # MAE with a Swin-style backbone
-    """ Masked Autoencoder with VisionTransformer backbone
-    """
-    def __init__(self, arch='B', stage_cfgs=None, img_size=224, patch_size=4, in_chans=3,
-                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
-                 window_size=7, mlp_ratio=4.,
-                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
-                 norm_layer=nn.LayerNorm, norm_pix_loss=False,
-                 block_cls=Block, backbone_cls=SwinTransformer,
+class GreenMIMSwinTransformer(BaseBackbone):  # MAE with a Swin-style backbone
+    """Masked Autoencoder with a Swin Transformer backbone."""
+
+    def __init__(self,
+                 arch='B',
+                 stage_cfgs=None,
+                 img_size=224,
+                 patch_size=4,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 decoder_embed_dim=512,
+                 decoder_depth=8,
+                 decoder_num_heads=16,
+                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                 norm_pix_loss=False,
+                 block_cls=Block,
+                 backbone_cls=SwinTransformer,
+                 init_cfg: Optional[Union[List[dict], dict]] = None,
                  **kwargs):
-        super().__init__()
+        super().__init__(init_cfg=init_cfg)
        # --------------------------------------------------------------------------
        # MAE encoder specifics
-        self.encoder = backbone_cls(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
-                                    num_classes=0, embed_dim=embed_dim, depths=depths, num_heads=num_heads,
-                                    window_size=window_size, norm_layer=norm_layer, **kwargs)
+        self.encoder = backbone_cls(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            num_classes=0,
+            embed_dim=embed_dim,
+            depths=depths,
+            num_heads=num_heads,
+            window_size=window_size,
+            norm_layer=norm_layer,
+            **kwargs)
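+        # with the base config (img_size 224, patch_size 4, four stages)
+        # the last stage sees a 7x7 grid, so num_patches below is 49 and
+        # final_patch_size is 4 * 2**3 = 32, matching the neck's
+        # encoder_stride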
        num_patches = np.prod(self.encoder.layers[-1].input_resolution)
        self.num_patches = num_patches
        patch_size = patch_size * 2**(len(depths) - 1)
        self.final_patch_size = patch_size
        # --------------------------------------------------------------------------
-
+
        self.norm_pix_loss = norm_pix_loss
        self.initialize_weights()
@@ -823,25 +998,26 @@ def unpatchify(self, x, patch_size=None):
        p = patch_size or self.final_patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
-
+
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
-        """
-        Perform per-sample random masking by per-sample shuffling.
+        """Perform per-sample random masking by per-sample shuffling.
+
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L = 1, self.num_patches  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
-
+
        torch.manual_seed(0)
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
-
+
        # sort noise for each sample
-        ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
+        ids_shuffle = torch.argsort(
+            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
@@ -849,8 +1025,14 @@
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
-        mask.scatter_add_(1, ids_keep, torch.full([N, len_keep], fill_value=-1, dtype=mask.dtype, device=x.device))
-        assert (mask.gather(1, ids_shuffle).gather(1, ids_restore) == mask).all()
+        mask.scatter_add_(
+            1, ids_keep,
+            torch.full([N, len_keep],
+                       fill_value=-1,
+                       dtype=mask.dtype,
+                       device=x.device))
+        assert (mask.gather(1, ids_shuffle).gather(1,
+                                                   ids_restore) == mask).all()

        # repeat the mask
        ids_restore = ids_restore.repeat(x.shape[0], 1)
@@ -860,15 +1042,11 @@
    def forward(self, x, mask_ratio=0.75):
        # generate random mask: B x Token^2; ids_restore is the correct ID order
+        # x, mask, ids_restore, latent =
+        # torch.load("./x_mask_ids_restore_latent.pth")
        mask, ids_restore = self.random_masking(x, mask_ratio)
        # L -> L_vis: compute features of the patches that are not masked
        latent = self.encoder(x, mask.bool())
        return latent, mask, ids_restore
-
-
-
-
-
-
diff --git a/mmselfsup/models/heads/greenmim_head.py b/projects/greenmim/models/greenmim_head.py
similarity index 89%
rename from mmselfsup/models/heads/greenmim_head.py
rename to projects/greenmim/models/greenmim_head.py
index 0e5751145..949a17fc4 100644
--- a/mmselfsup/models/heads/greenmim_head.py
+++ b/projects/greenmim/models/greenmim_head.py
@@ -20,20 +20,19 @@ def __init__(self, patch_size, norm_pix_loss, loss: dict) -> None:
        self.final_patch_size = patch_size
        self.norm_pix_loss = norm_pix_loss

-    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def forward(self, pred: torch.Tensor, target: torch.Tensor,
+                mask: torch.Tensor) -> torch.Tensor:
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
-        mask: [N, L], 0 is keep, 1 is remove,
+        mask: [N, L], 0 is keep, 1 is remove,
        """
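+        # when norm_pix_loss is enabled the target is normalised per patch
+        # (zero mean, unit variance) before the MSE, following MAE; the
+        # base config keeps it disabled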
        if self.norm_pix_loss:  # this is False
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

-        loss = (pred - target) ** 2 # MSE loss; this part is very simple
+        loss = (pred - target)**2  # MSE loss; this part is very simple
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
diff --git a/mmselfsup/models/necks/greenmim_neck.py b/projects/greenmim/models/greenmim_neck.py
similarity index 56%
rename from mmselfsup/models/necks/greenmim_neck.py
rename to projects/greenmim/models/greenmim_neck.py
index cb5a55afe..bd0832b8d 100644
--- a/mmselfsup/models/necks/greenmim_neck.py
+++ b/projects/greenmim/models/greenmim_neck.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
 import torch
 import torch.nn as nn
 from mmengine.model import BaseModule
-from mmselfsup.registry import MODELS
 from timm.models.vision_transformer import Block
-import numpy as np
+
+from mmselfsup.registry import MODELS


 @MODELS.register_module()
@@ -18,37 +19,59 @@ class GreenMIMNeck(BaseModule):
        encoder_stride (int): The total stride of the encoder.
    """

-    def __init__(self, in_channels: int, encoder_stride: int, img_size, patch_size,
-                 embed_dim=96, depths=[2, 2, 6, 2], decoder_embed_dim=512,
-                 mlp_ratio=4., decoder_depth=8, decoder_num_heads=16, block_cls=Block) -> None:
+    def __init__(self,
+                 in_channels: int,
+                 encoder_stride: int,
+                 img_size,
+                 patch_size,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 decoder_embed_dim=512,
+                 mlp_ratio=4.,
+                 decoder_depth=8,
+                 decoder_num_heads=16,
+                 block_cls=Block) -> None:
        super().__init__()

        patch_resolution = img_size // patch_size
-        num_patches = (patch_resolution // (2 ** (len(depths) - 1))) ** 2
+        num_patches = (patch_resolution // (2**(len(depths) - 1)))**2
        # SwinMAE decoder specifics
        embed_dim = embed_dim * 2**(len(depths) - 1)
        patch_size = patch_size * 2**(len(depths) - 1)
        self.patch_size = patch_size
-        self.decoder_embed = nn.Identity() if embed_dim == decoder_embed_dim else nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+        self.decoder_embed = nn.Identity(
+        ) if embed_dim == decoder_embed_dim else nn.Linear(
+            embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

-        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
-        #### replace the block cls ####
+        self.decoder_pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches, decoder_embed_dim),
+            requires_grad=False)  # fixed sin-cos embedding
        self.decoder_blocks = nn.ModuleList([
-            block_cls(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=torch.nn.LayerNorm)
-            for i in range(decoder_depth)])
+            block_cls(
+                decoder_embed_dim,
+                decoder_num_heads,
+                mlp_ratio,
+                qkv_bias=True,
+                norm_layer=torch.nn.LayerNorm) for i in range(decoder_depth)
+        ])

        self.decoder_norm = torch.nn.LayerNorm(decoder_embed_dim)
-        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_channels, bias=True) # encoder to decoder
-
+        self.decoder_pred = nn.Linear(
+            decoder_embed_dim, patch_size**2 * in_channels,
+            bias=True)  # encoder to decoder
+
    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
-        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.num_patches**.5), cls_token=False)
-        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
- - # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + decoder_pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + int(self.num_patches**.5), + cls_token=False) + self.decoder_pos_embed.data.copy_( + torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + torch.nn.init.normal_(self.mask_token, std=.02) def forward(self, x, ids_restore): @@ -56,9 +79,15 @@ def forward(self, x, ids_restore): x = self.decoder_embed(x) # append mask tokens to sequence - mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + mask_tokens = self.mask_token.repeat(x.shape[0], + ids_restore.shape[1] - x.shape[1], + 1) x_ = torch.cat([x, mask_tokens], dim=1) - x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x_ = torch.gather( + x_, + dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, + x.shape[2])) # unshuffle # add pos embed x = x_ + self.decoder_pos_embed @@ -74,12 +103,12 @@ def forward(self, x, ids_restore): return x - def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, + embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) @@ -89,7 +118,8 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) return pos_embed @@ -97,10 +127,12 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb @@ -118,8 +150,8 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb From 28ad63dcb1257b17dfed33847bbedea5b8f1cabf Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Sat, 11 Feb 2023 18:39:16 +0800 Subject: [PATCH 03/17] fix bugs --- .gitignore | 1 + mmselfsup/models/algorithms/__init__.py | 3 +- mmselfsup/models/backbones/__init__.py | 3 +- mmselfsup/models/heads/__init__.py | 4 +- mmselfsup/models/necks/__init__.py | 4 +- .../_base_/models => }/greenmim_swin-base.py | 2 + ...in-base_16xb128-amp-coslr-100e_in1k-192.py | 13 +-- projects/greenmim/models/__init__.py | 8 ++ projects/greenmim/models/greenmim.py | 2 +- projects/greenmim/tools/train.py | 99 +++++++++++++++++++ 10 files changed, 120 insertions(+), 19 deletions(-) 
rename projects/greenmim/configs/{selfsup/_base_/models => }/greenmim_swin-base.py (92%) rename projects/greenmim/configs/{selfsup/greenmim => }/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py (77%) create mode 100644 projects/greenmim/models/__init__.py create mode 100644 projects/greenmim/tools/train.py diff --git a/.gitignore b/.gitignore index df0976d7e..0bc88e9dd 100644 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,4 @@ INFO # Pytorch *.pth +data diff --git a/mmselfsup/models/algorithms/__init__.py b/mmselfsup/models/algorithms/__init__.py index 9c88a70bd..590782430 100644 --- a/mmselfsup/models/algorithms/__init__.py +++ b/mmselfsup/models/algorithms/__init__.py @@ -21,11 +21,10 @@ from .simmim import SimMIM from .simsiam import SimSiam from .swav import SwAV -from .greenmim import GreenMIM __all__ = [ 'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL', 'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam', 'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA', - 'MixMIM', 'GreenMIM' + 'MixMIM' ] diff --git a/mmselfsup/models/backbones/__init__.py b/mmselfsup/models/backbones/__init__.py index dab9c4ebb..0a40c999b 100644 --- a/mmselfsup/models/backbones/__init__.py +++ b/mmselfsup/models/backbones/__init__.py @@ -9,10 +9,9 @@ from .resnet import ResNet, ResNetSobel, ResNetV1d from .resnext import ResNeXt from .simmim_swin import SimMIMSwinTransformer -from .greenmim import GreenMIMSwinTransformer __all__ = [ 'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT', 'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT', 'MILANViT', - 'MixMIMTransformerPretrain', 'GreenMIMSwinTransformer' + 'MixMIMTransformerPretrain' ] diff --git a/mmselfsup/models/heads/__init__.py b/mmselfsup/models/heads/__init__.py index ec740f5a7..36cb565db 100644 --- a/mmselfsup/models/heads/__init__.py +++ b/mmselfsup/models/heads/__init__.py @@ -13,12 +13,10 @@ from .multi_cls_head import MultiClsHead from .simmim_head import SimMIMHead from .swav_head import SwAVHead -from .greenmim_head import GreenMIMHead __all__ = [ 'BEiTV1Head', 'BEiTV2Head', 'ContrastiveHead', 'ClsHead', 'LatentPredictHead', 'LatentCrossCorrelationHead', 'MultiClsHead', 'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead', - 'MaskFeatPretrainHead', 'MILANPretrainHead', 'MixMIMPretrainHead', - 'GreenMIMHead' + 'MaskFeatPretrainHead', 'MILANPretrainHead', 'MixMIMPretrainHead' ] diff --git a/mmselfsup/models/necks/__init__.py b/mmselfsup/models/necks/__init__.py index 3e6a092ac..7956fa817 100644 --- a/mmselfsup/models/necks/__init__.py +++ b/mmselfsup/models/necks/__init__.py @@ -13,12 +13,10 @@ from .relative_loc_neck import RelativeLocNeck from .simmim_neck import SimMIMNeck from .swav_neck import SwAVNeck -from .greenmim_neck import GreenMIMNeck - __all__ = [ 'AvgPool2dNeck', 'BEiTV2Neck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck', 'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck', 'MAEPretrainDecoder', 'SimMIMNeck', 'CAENeck', 'MixMIMPretrainDecoder', - 'ClsBatchNormNeck', 'MILANPretrainDecoder', 'GreenMIMNeck' + 'ClsBatchNormNeck', 'MILANPretrainDecoder' ] diff --git a/projects/greenmim/configs/selfsup/_base_/models/greenmim_swin-base.py b/projects/greenmim/configs/greenmim_swin-base.py similarity index 92% rename from projects/greenmim/configs/selfsup/_base_/models/greenmim_swin-base.py rename to projects/greenmim/configs/greenmim_swin-base.py index 5b8501d06..ca3119487 100644 --- 
a/projects/greenmim/configs/selfsup/_base_/models/greenmim_swin-base.py +++ b/projects/greenmim/configs/greenmim_swin-base.py @@ -1,3 +1,5 @@ +custom_imports = dict(imports=['models'], allow_failed_imports=False) + # model settings img_size = 224 patch_size = 4 diff --git a/projects/greenmim/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py b/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py similarity index 77% rename from projects/greenmim/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py rename to projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py index efd8cc9bf..7b95650f3 100644 --- a/projects/greenmim/configs/selfsup/greenmim/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py +++ b/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py @@ -1,15 +1,12 @@ _base_ = [ - '../_base_/models/greenmim_swin-base.py', - '../_base_/datasets/imagenet_mae.py', - '../_base_/schedules/adamw_coslr-200e_in1k.py', - '../_base_/default_runtime.py', + './greenmim_swin-base.py', + 'mmselfsup::selfsup/_base_/datasets/imagenet_mae.py', + 'mmselfsup::selfsup/_base_/schedules/adamw_coslr-200e_in1k.py', + 'mmselfsup::selfsup/_base_/default_runtime.py', ] # dataset 16 GPUs x 128 -train_dataloader = dict( - batch_size=128, - num_workers=16, - sampler=dict(type='DefaultSampler', seed=0, shuffle=True)) +train_dataloader = dict(batch_size=128, num_workers=16) # optimizer wrapper optimizer = dict( diff --git a/projects/greenmim/models/__init__.py b/projects/greenmim/models/__init__.py new file mode 100644 index 000000000..88921011b --- /dev/null +++ b/projects/greenmim/models/__init__.py @@ -0,0 +1,8 @@ +from .greenmim import GreenMIM +from .greenmim_backbone import GreenMIMSwinTransformer +from .greenmim_head import GreenMIMHead +from .greenmim_neck import GreenMIMNeck + +__all__ = [ + 'GreenMIM', 'GreenMIMSwinTransformer', 'GreenMIMHead', 'GreenMIMNeck' +] diff --git a/projects/greenmim/models/greenmim.py b/projects/greenmim/models/greenmim.py index 331df1947..5d639c0f8 100644 --- a/projects/greenmim/models/greenmim.py +++ b/projects/greenmim/models/greenmim.py @@ -4,9 +4,9 @@ import torch from mmengine.structures import BaseDataElement +from mmselfsup.models.algorithms.base import BaseModel from mmselfsup.registry import MODELS from mmselfsup.structures import SelfSupDataSample -from .base import BaseModel @MODELS.register_module() diff --git a/projects/greenmim/tools/train.py b/projects/greenmim/tools/train.py new file mode 100644 index 000000000..ef0d3127c --- /dev/null +++ b/projects/greenmim/tools/train.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
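+"""Training entry point for the GreenMIM project.
+
+Usage (paths are illustrative):
+    python projects/greenmim/tools/train.py \
+        projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py \
+        --amp --work-dir work_dirs/greenmim
+"""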
+import argparse
+import os
+import os.path as osp
+
+from mmengine.config import Config, DictAction
+from mmengine.runner import Runner
+
+from mmselfsup.utils import register_all_modules
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train a model')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument(
+        '--resume',
+        nargs='?',
+        type=str,
+        const='auto',
+        help='If a checkpoint path is specified, resume from it; otherwise, '
+        'try to auto-resume from the latest checkpoint '
+        'in the work directory.')
+    parser.add_argument(
+        '--amp',
+        action='store_true',
+        help='enable automatic-mixed-precision training')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    # register all modules in mmselfsup into the registries
+    # do not init the default scope here because it will be init in the runner
+    register_all_modules(init_default_scope=False)
+
+    # load config
+    cfg = Config.fromfile(args.config)
+    cfg.launcher = args.launcher
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        work_type = args.config.split('/')[1]
+        cfg.work_dir = osp.join('./work_dirs', work_type,
+                                osp.splitext(osp.basename(args.config))[0])
+
+    # enable automatic-mixed-precision training
+    if args.amp is True:
+        optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
+        assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
+            '`--amp` is not supported for custom optimizer wrapper type ' \
+            f'`{optim_wrapper}`.'
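+        # the assert above guarantees the configured wrapper is compatible,
+        # so overriding the type and using a dynamic loss scale is safe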
+ cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.setdefault('loss_scale', 'dynamic') + + # resume training + if args.resume == 'auto': + cfg.resume = True + cfg.load_from = None + elif args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main() From 2a38fa3d6152272d526444255ff5c4d4734dfd96 Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Wed, 15 Feb 2023 17:38:08 +0800 Subject: [PATCH 04/17] fix _scope_ bugs --- .../greenmim/configs/adamw_coslr-200e_in1k.py | 19 ++++++++++++ projects/greenmim/configs/default_runtime.py | 29 +++++++++++++++++++ ...in-base_16xb128-amp-coslr-100e_in1k-192.py | 6 ++-- projects/greenmim/configs/imagenet_mae.py | 29 +++++++++++++++++++ projects/greenmim/models/greenmim_backbone.py | 9 +++--- 5 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 projects/greenmim/configs/adamw_coslr-200e_in1k.py create mode 100644 projects/greenmim/configs/default_runtime.py create mode 100644 projects/greenmim/configs/imagenet_mae.py diff --git a/projects/greenmim/configs/adamw_coslr-200e_in1k.py b/projects/greenmim/configs/adamw_coslr-200e_in1k.py new file mode 100644 index 000000000..7ab03a869 --- /dev/null +++ b/projects/greenmim/configs/adamw_coslr-200e_in1k.py @@ -0,0 +1,19 @@ +# optimizer_wrapper +optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05) +optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', T_max=160, by_epoch=True, begin=40, end=200) +] + +# runtime settings +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200) diff --git a/projects/greenmim/configs/default_runtime.py b/projects/greenmim/configs/default_runtime.py new file mode 100644 index 000000000..d672cdc00 --- /dev/null +++ b/projects/greenmim/configs/default_runtime.py @@ -0,0 +1,29 @@ +default_scope = 'mmselfsup' + +default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=10), + sampler_seed=dict(type='DistSamplerSeedHook'), +) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +log_processor = dict( + window_size=10, + custom_cfg=[dict(data_src='', method='mean', windows_size='global')]) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='SelfSupVisualizer', vis_backends=vis_backends, name='visualizer') +# custom_hooks = [dict(type='SelfSupVisualizationHook', interval=1)] + +log_level = 'INFO' +load_from = None +resume = False diff --git a/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py b/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py index 7b95650f3..caac9f791 100644 --- a/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py +++ b/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py @@ -1,8 +1,8 @@ _base_ = [ './greenmim_swin-base.py', - 'mmselfsup::selfsup/_base_/datasets/imagenet_mae.py', - 
'mmselfsup::selfsup/_base_/schedules/adamw_coslr-200e_in1k.py',
-    'mmselfsup::selfsup/_base_/default_runtime.py',
+    './imagenet_mae.py',
+    './adamw_coslr-200e_in1k.py',
+    './default_runtime.py',
 ]

 # dataset 16 GPUs x 128
diff --git a/projects/greenmim/configs/imagenet_mae.py b/projects/greenmim/configs/imagenet_mae.py
new file mode 100644
index 000000000..c642ad471
--- /dev/null
+++ b/projects/greenmim/configs/imagenet_mae.py
@@ -0,0 +1,29 @@
+# dataset settings
+dataset_type = 'mmcls.ImageNet'
+data_root = 'data/imagenet/'
+file_client_args = dict(backend='disk')
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        scale=(0.2, 1.0),
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PackSelfSupInputs', meta_keys=['img_path'])
+]
+
+train_dataloader = dict(
+    batch_size=128,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    collate_fn=dict(type='default_collate'),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='meta/train.txt',
+        data_prefix=dict(img_path='train/'),
+        pipeline=train_pipeline))
diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py
index a7bdf9556..21051b402 100644
--- a/projects/greenmim/models/greenmim_backbone.py
+++ b/projects/greenmim/models/greenmim_backbone.py
@@ -1014,7 +1014,7 @@ def random_masking(self, x, mask_ratio):
         len_keep = int(L * (1 - mask_ratio))
         torch.manual_seed(0)
         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
-
+        print(noise)
         # sort noise for each sample
         ids_shuffle = torch.argsort(
             noise, dim=1)  # ascend: small is keep, large is remove
@@ -1042,11 +1042,12 @@
     def forward(self, x, mask_ratio=0.75):
         # generate random mask: B x Token^2; ids_restore is the correct ID order
-        # x, mask, ids_restore, latent =
-        # torch.load("./x_mask_ids_restore_latent.pth")
+        x, mask, ids_restore, latent_gt = \
+            torch.load('./x_mask_ids_restore_latent.pth')
         mask, ids_restore = self.random_masking(x, mask_ratio)
         # L -> L_vis: compute features of the patches that are not masked
         latent = self.encoder(x, mask.bool())
-
+        print(latent[0][0][:3])
+        print(latent_gt[0][0][:3])
         return latent, mask, ids_restore

From e2d47adf365fe8de0133f0efb5f245615319323f Mon Sep 17 00:00:00 2001
From: xfguo <2601882982@qq.com>
Date: Thu, 16 Feb 2023 17:43:40 +0800
Subject: [PATCH 05/17] ch timm modules to mmcv

---
 projects/greenmim/models/greenmim_backbone.py | 36 ++++++++++---------
 projects/greenmim/models/greenmim_head.py     |  4 +--
 projects/greenmim/models/greenmim_neck.py     | 35 +++++++++---------
 3 files changed, 40 insertions(+), 35 deletions(-)

diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py
index 21051b402..8877b7fc6 100644
--- a/projects/greenmim/models/greenmim_backbone.py
+++ b/projects/greenmim/models/greenmim_backbone.py
@@ -7,14 +7,16 @@
 import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
 from mmcls.models.backbones.base_backbone import BaseBackbone
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from timm.models.vision_transformer import Block
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import trunc_normal_
+from mmengine.utils import to_2tuple
 from torch.nn import functional as F

 from mmselfsup.registry import MODELS


-class Mlp(nn.Module):
+class Mlp(BaseModule):

     def
__init__(self, in_features, @@ -46,7 +48,7 @@ def get_coordinates(h, w, device='cpu'): return coords -class WindowAttention(nn.Module): +class WindowAttention(BaseModule): r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. @@ -179,7 +181,7 @@ def flops(self, N): return flops -class SwinTransformerBlock(nn.Module): +class SwinTransformerBlock(BaseModule): r""" Swin Transformer Block. Args: @@ -196,8 +198,8 @@ class SwinTransformerBlock(nn.Module): drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. + act_layer (BaseModule, optional): Activation layer. Default: nn.GELU + norm_layer (BaseModule, optional): Normalization layer. Default: nn.LayerNorm """ @@ -283,13 +285,13 @@ def flops(self): return flops -class PatchMerging(nn.Module): +class PatchMerging(BaseModule): r""" Patch Merging Layer. Args: input_resolution (tuple[int]): Resolution of input feature. dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. + norm_layer (BaseModule, optional): Normalization layer. Default: nn.LayerNorm """ @@ -531,7 +533,7 @@ def merge(self, x): return x -class BasicLayer(nn.Module): +class BasicLayer(BaseModule): """A basic Swin Transformer layer for one stage. Args: @@ -549,9 +551,9 @@ class BasicLayer(nn.Module): attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: + norm_layer (BaseModule, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at + downsample (BaseModule | None, optional): Downsample layer at the end of the layer. Default: None use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. @@ -668,7 +670,7 @@ def flops(self): return flops -class PatchEmbed(nn.Module): +class PatchEmbed(BaseModule): r""" Image to Patch Embedding Args: @@ -677,7 +679,7 @@ class PatchEmbed(nn.Module): in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None + norm_layer (BaseModule, optional): Normalization layer. Default: None """ def __init__(self, @@ -727,7 +729,7 @@ def flops(self): return flops -class SwinTransformer(nn.Module): +class SwinTransformer(BaseModule): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` @@ -752,7 +754,7 @@ class SwinTransformer(nn.Module): drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + norm_layer (BaseModule): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. 
@@ -942,7 +944,6 @@ def __init__(self, decoder_num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_pix_loss=False, - block_cls=Block, backbone_cls=SwinTransformer, init_cfg: Optional[Union[List[dict], dict]] = None, **kwargs): @@ -1050,4 +1051,5 @@ def forward(self, x, mask_ratio=0.75): latent = self.encoder(x, mask.bool()) print(latent[0][0][:3]) print(latent_gt[0][0][:3]) + return latent, mask, ids_restore diff --git a/projects/greenmim/models/greenmim_head.py b/projects/greenmim/models/greenmim_head.py index 949a17fc4..5938570f4 100644 --- a/projects/greenmim/models/greenmim_head.py +++ b/projects/greenmim/models/greenmim_head.py @@ -27,12 +27,12 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor, pred: [N, L, p*p*3] mask: [N, L], 0 is keep, 1 is remove, """ - if self.norm_pix_loss: # 这个是False + if self.norm_pix_loss: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) target = (target - mean) / (var + 1.e-6)**.5 - loss = (pred - target)**2 # 用的MSE loss,这部分非常简单 + loss = (pred - target)**2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches diff --git a/projects/greenmim/models/greenmim_neck.py b/projects/greenmim/models/greenmim_neck.py index bd0832b8d..678caf345 100644 --- a/projects/greenmim/models/greenmim_neck.py +++ b/projects/greenmim/models/greenmim_neck.py @@ -2,8 +2,8 @@ import numpy as np import torch import torch.nn as nn +from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer from mmengine.model import BaseModule -from timm.models.vision_transformer import Block from mmselfsup.registry import MODELS @@ -19,18 +19,20 @@ class GreenMIMNeck(BaseModule): encoder_stride (int): The total stride of the encoder. """ - def __init__(self, - in_channels: int, - encoder_stride: int, - img_size, - patch_size, - embed_dim=96, - depths=[2, 2, 6, 2], - decoder_embed_dim=512, - mlp_ratio=4., - decoder_depth=8, - decoder_num_heads=16, - block_cls=Block) -> None: + def __init__( + self, + in_channels: int, + encoder_stride: int, + img_size, + patch_size, + embed_dim=96, + depths=[2, 2, 6, 2], + decoder_embed_dim=512, + mlp_ratio=4., + decoder_depth=8, + decoder_num_heads=16, + norm_cfg: dict = dict(type='LN', eps=1e-6), + ) -> None: super().__init__() patch_resolution = img_size // patch_size @@ -48,13 +50,14 @@ def __init__(self, self.decoder_pos_embed = nn.Parameter( torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding + self.decoder_blocks = nn.ModuleList([ - block_cls( + TransformerEncoderLayer( decoder_embed_dim, decoder_num_heads, - mlp_ratio, + int(mlp_ratio * decoder_embed_dim), qkv_bias=True, - norm_layer=torch.nn.LayerNorm) for i in range(decoder_depth) + norm_cfg=norm_cfg) for _ in range(decoder_depth) ]) self.decoder_norm = torch.nn.LayerNorm(decoder_embed_dim) From daa916e24942e83c3772eca8590149decd2e5591 Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Thu, 16 Feb 2023 20:19:23 +0800 Subject: [PATCH 06/17] add type hints --- projects/greenmim/models/greenmim_backbone.py | 319 ++++++++---------- 1 file changed, 142 insertions(+), 177 deletions(-) diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py index 8877b7fc6..d7cc94ea9 100644 --- a/projects/greenmim/models/greenmim_backbone.py +++ b/projects/greenmim/models/greenmim_backbone.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from functools import partial
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -19,11 +19,11 @@ class Mlp(BaseModule):
 
     def __init__(self,
-                 in_features,
-                 hidden_features=None,
-                 out_features=None,
-                 act_layer=nn.GELU,
-                 drop=0.):
+                 in_features: int,
+                 hidden_features: int = None,
+                 out_features: int = None,
+                 act_layer: nn.Module = nn.GELU,
+                 drop: float = 0.) -> None:
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -32,7 +32,7 @@ def __init__(self,
         self.fc2 = nn.Linear(hidden_features, out_features)
         self.drop = nn.Dropout(drop)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
         x = self.act(x)
         x = self.drop(x)
@@ -41,7 +41,7 @@ def forward(self, x):
         return x
 
 
-def get_coordinates(h, w, device='cpu'):
+def get_coordinates(h: int, w: int, device: str = 'cpu') -> torch.Tensor:
     coords_h = torch.arange(h, device=device)
     coords_w = torch.arange(w, device=device)
     coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
@@ -68,13 +68,13 @@ class WindowAttention(BaseModule):
     """
 
     def __init__(self,
-                 dim,
-                 window_size,
-                 num_heads,
-                 qkv_bias=True,
-                 qk_scale=None,
-                 attn_drop=0.,
-                 proj_drop=0.):
+                 dim: int,
+                 window_size: tuple,
+                 num_heads: int,
+                 qkv_bias: bool = True,
+                 qk_scale: float = None,
+                 attn_drop: float = 0.,
+                 proj_drop: float = 0.) -> None:
         super().__init__()
 
         self.dim = dim
@@ -114,7 +114,10 @@ def __init__(self,
         trunc_normal_(self.relative_position_bias_table, std=.02)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, mask=None, pos_idx=None):
+    def forward(self,
+                x: torch.Tensor,
+                mask: torch.Tensor = None,
+                pos_idx: torch.Tensor = None) -> torch.Tensor:
         """
         Args:
             x: input features with shape of (num_windows*B, N, C)
@@ -167,19 +170,6 @@ def extra_repr(self) -> str:
         return f'dim={self.dim}, window_size={self.window_size}, \
             num_heads={self.num_heads}'
 
-    def flops(self, N):
-        # calculate flops for 1 window with token length of N
-        flops = 0
-        # qkv = self.qkv(x)
-        flops += N * self.dim * 3 * self.dim
-        # attn = (q @ k.transpose(-2, -1))
-        flops += self.num_heads * N * (self.dim // self.num_heads) * N
-        # x = (attn @ v)
-        flops += self.num_heads * N * N * (self.dim // self.num_heads)
-        # x = self.proj(x)
-        flops += N * self.dim * self.dim
-        return flops
-
 
 class SwinTransformerBlock(BaseModule):
     r""" Swin Transformer Block.
@@ -198,25 +188,25 @@ class SwinTransformerBlock(BaseModule):
         drop (float, optional): Dropout rate. Default: 0.0
         attn_drop (float, optional): Attention dropout rate. Default: 0.0
         drop_path (float, optional): Stochastic depth rate. Default: 0.0
-        act_layer (BaseModule, optional): Activation layer. Default: nn.GELU
-        norm_layer (BaseModule, optional): Normalization layer.
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.
            Default: nn.LayerNorm
     """
 
     def __init__(self,
-                 dim,
-                 input_resolution,
-                 num_heads,
-                 window_size=7,
-                 shift_size=0,
-                 mlp_ratio=4.,
-                 qkv_bias=True,
-                 qk_scale=None,
-                 drop=0.,
-                 attn_drop=0.,
-                 drop_path=0.,
-                 act_layer=nn.GELU,
-                 norm_layer=nn.LayerNorm):
+                 dim: int,
+                 input_resolution: tuple[int],
+                 num_heads: int,
+                 window_size: int = 7,
+                 shift_size: int = 0,
+                 mlp_ratio: float = 4.,
+                 qkv_bias: bool = True,
+                 qk_scale: float = None,
+                 drop: float = 0.,
+                 attn_drop: float = 0.,
+                 drop_path: float = 0.,
+                 act_layer: nn.Module = nn.GELU,
+                 norm_layer: nn.Module = nn.LayerNorm) -> None:
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -252,7 +242,8 @@ def __init__(self,
                 act_layer=act_layer,
                 drop=drop)
 
-    def forward(self, x, attn_mask, rel_pos_idx):
+    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
+                rel_pos_idx: torch.Tensor) -> torch.Tensor:
         shortcut = x
         x = self.norm1(x)
 
@@ -270,20 +261,6 @@ def extra_repr(self) -> str:
             'num_heads={self.num_heads}, '
             f'window_size={self.window_size}, '
             f'shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}'
 
-    def flops(self):
-        flops = 0
-        H, W = self.input_resolution
-        # norm1
-        flops += self.dim * H * W
-        # W-MSA/SW-MSA
-        nW = H * W / self.window_size / self.window_size
-        flops += nW * self.attn.flops(self.window_size * self.window_size)
-        # mlp
-        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
-        # norm2
-        flops += self.dim * H * W
-        return flops
-
 
 class PatchMerging(BaseModule):
     r""" Patch Merging Layer.
@@ -291,18 +268,22 @@ class PatchMerging(BaseModule):
     Args:
         input_resolution (tuple[int]): Resolution of input feature.
         dim (int): Number of input channels.
-        norm_layer (BaseModule, optional): Normalization layer.
+        norm_layer (nn.Module, optional): Normalization layer.
             Default: nn.LayerNorm
     """
 
-    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+    def __init__(self,
+                 input_resolution: tuple[int],
+                 dim: int,
+                 norm_layer: nn.Module = nn.LayerNorm) -> None:
         super().__init__()
         self.input_resolution = input_resolution
         self.dim = dim
         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
         self.norm = norm_layer(4 * dim)
 
-    def forward(self, x, mask_prev):
+    def forward(self, x: torch.Tensor,
+                mask_prev: torch.Tensor) -> torch.Tensor:
         """
         x: B, H*W, C
         """
@@ -329,14 +310,8 @@ def forward(self, x, mask_prev):
 
     def extra_repr(self) -> str:
         return f'input_resolution={self.input_resolution}, dim={self.dim}'
 
-    def flops(self):
-        H, W = self.input_resolution
-        flops = H * W * self.dim
-        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
-        return flops
-
 
-def knapsack(W, wt):
+def knapsack(W: int, wt: tuple[int]) -> Tuple[list[list[int]], list]:
     '''Args:
         W (int): capacity
         wt (tuple[int]): the numbers of elements within each window
@@ -380,7 +355,8 @@ def knapsack(W, wt):
     return res_ret, idx[::-1]  # make the idx in an increasing order
 
 
-def group_windows(group_size, num_ele_win):
+def group_windows(group_size: int,
+                  num_ele_win: list[int]) -> Tuple[list[int], list[list[int]]]:
     """Greedily apply the DP algorithm to group the elements.
 
     Args:
@@ -410,9 +386,12 @@ def group_windows(group_size, num_ele_win):
     return num_ele_group, grouped_idx
 
 
-class GroupingModule:
+class GroupingModule(BaseModule):
 
-    def __init__(self, window_size, shift_size, group_size=None):
+    def __init__(self,
+                 window_size: int,
+                 shift_size: int,
+                 group_size: int = None) -> None:
         self.window_size = window_size
         self.shift_size = shift_size
         assert shift_size >= 0 and shift_size < window_size
@@ -421,7 +400,7 @@ def __init__(self, window_size, shift_size, group_size=None):
         self.attn_mask = None
         self.rel_pos_idx = None
 
-    def _get_group_id(self, coords):
+    def _get_group_id(self, coords: torch.Tensor) -> torch.Tensor:
         group_id = coords.clone()
         group_id += (self.window_size - self.shift_size) % self.window_size
         group_id = group_id // self.window_size
@@ -429,7 +408,7 @@ def _get_group_id(self, coords):
             0, :, 1]  # (N_vis, )
         return group_id
 
-    def _get_attn_mask(self, group_id):
+    def _get_attn_mask(self, group_id: torch.Tensor) -> torch.Tensor:
         pos_mask = (group_id == -1)
         pos_mask = torch.logical_and(pos_mask[:, :, None],
                                      pos_mask[:, None, :])
@@ -439,7 +418,7 @@ def _get_attn_mask(self, group_id):
         attn_mask_float.masked_fill_(attn_mask, -100.)
         return attn_mask_float
 
-    def _get_rel_pos_idx(self, coords):
+    def _get_rel_pos_idx(self, coords: torch.Tensor) -> torch.Tensor:
         # num_groups, group_size, group_size, 2
         rel_pos_idx = coords[:, :, None, :] - coords[:, None, :, :]
         rel_pos_idx += self.window_size - 1
@@ -447,7 +426,7 @@ def _get_rel_pos_idx(self, coords):
         rel_pos_idx = rel_pos_idx.sum(dim=-1)
         return rel_pos_idx
 
-    def _prepare_masking(self, coords):
+    def _prepare_masking(self, coords: torch.Tensor) -> torch.Tensor:
         # coords: (B, N_vis, 2)
         group_id = self._get_group_id(coords)  # (N_vis, )
         attn_mask = self._get_attn_mask(group_id.unsqueeze(0))
@@ -459,7 +438,8 @@ def _prepare_masking(self, coords):
 
         return attn_mask, rel_pos_idx
 
-    def _prepare_grouping(self, coords):
+    def _prepare_grouping(
+            self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # find out and merge the elements within each local window
         # coords: (B, N_vis, 2)
         group_id = self._get_group_id(coords)  # (N_vis, )
@@ -509,7 +489,8 @@ def _prepare_grouping(self, coords):
 
         return attn_mask, rel_pos_idx
 
-    def prepare(self, coords, mode):
+    def prepare(self, coords: torch.Tensor,
+                mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
         self._mode = mode
         if mode == 'masking':
            return self._prepare_masking(coords)
@@ -518,14 +499,14 @@ def prepare(self, coords, mode):
         else:
            raise KeyError('')
 
-    def group(self, x):
+    def group(self, x: torch.Tensor) -> torch.Tensor:
         if self._mode == 'grouping':
            self.ori_shape = x.shape
            x = torch.index_select(x, 1, self.idx_shuffle)  # (B, nG*GS, C)
            x = x.reshape(-1, self.group_size, x.shape[-1])  # (B*nG, GS, C)
         return x
 
-    def merge(self, x):
+    def merge(self, x: torch.Tensor) -> torch.Tensor:
         if self._mode == 'grouping':
            B, N, C = self.ori_shape
            x = x.reshape(B, -1, C)  # (B, nG*GS, C)
@@ -551,29 +532,29 @@ class BasicLayer(BaseModule):
         attn_drop (float, optional): Attention dropout rate. Default: 0.0
         drop_path (float | tuple[float], optional): Stochastic depth rate.
             Default: 0.0
-        norm_layer (BaseModule, optional): Normalization layer. Default:
+        norm_layer (nn.Module, optional): Normalization layer. Default:
             nn.LayerNorm
-        downsample (BaseModule | None, optional): Downsample layer at
+        downsample (nn.Module | None, optional): Downsample layer at
            the end of the layer. Default: None
         use_checkpoint (bool): Whether to use checkpointing to save
             memory. Default: False.
""" def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=None, - use_checkpoint=False): + dim: int, + input_resolution: tuple[int], + depth: int, + num_heads: int, + window_size: int, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: bool = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + downsample: nn.Module = None, + use_checkpoint: bool = False) -> None: super().__init__() self.dim = dim @@ -614,7 +595,11 @@ def __init__(self, else: self.downsample = None - def forward(self, x, coords, patch_mask, return_x_before_down=False): + def forward(self, + x: torch.Tensor, + coords: torch.Tensor, + patch_mask: torch.Tensor, + return_x_before_down: bool = False) -> torch.Tensor: # prepare the attention mask # when the number of visible patches is small, # all patches are partitioned into a single group @@ -661,14 +646,6 @@ def extra_repr(self) -> str: f'window_size={self.window_size},' \ f'shift_size={self.shift_size}, depth={self.depth}' - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - class PatchEmbed(BaseModule): r""" Image to Patch Embedding @@ -679,15 +656,15 @@ class PatchEmbed(BaseModule): in_chans (int): Number of input image channels. Default: 3. embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (BaseModule, optional): Normalization layer. Default: None + norm_layer (nn.Module, optional): Normalization layer. Default: None """ def __init__(self, - img_size=224, - patch_size=4, - in_chans=3, - embed_dim=96, - norm_layer=None): + img_size: int = 224, + patch_size: int = 4, + in_chans: int = 3, + embed_dim: int = 96, + norm_layer: nn.Module = None) -> None: super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -709,7 +686,7 @@ def __init__(self, else: self.norm = None - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ @@ -720,14 +697,6 @@ def forward(self, x): x = self.norm(x) return x - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * ( - self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - class SwinTransformer(BaseModule): r""" Swin Transformer @@ -754,7 +723,7 @@ class SwinTransformer(BaseModule): drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (BaseModule): Normalization layer. Default: nn.LayerNorm. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False patch_norm (bool): If True, add normalization after patch embedding. 
@@ -764,24 +733,24 @@ class SwinTransformer(BaseModule): """ def __init__(self, - img_size=224, - patch_size=4, - in_chans=3, - num_classes=1000, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.1, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - ape=False, - patch_norm=True, - use_checkpoint=False): + img_size: int = 224, + patch_size: int = 4, + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: int = 96, + depths: list = [2, 2, 6, 2], + num_heads: list = [3, 6, 12, 24], + window_size: int = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_scale: float = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + ape: bool = False, + patch_norm: bool = True, + use_checkpoint: bool = False) -> None: super().__init__() self.num_classes = num_classes @@ -843,7 +812,7 @@ def __init__(self, self.apply(self._init_weights) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -860,7 +829,8 @@ def no_weight_decay(self): def no_weight_decay_keywords(self): return {'relative_position_bias_table'} - def forward_features(self, x, mask): + def forward_features(self, x: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: # patch embedding x = self.patch_embed(x) if self.ape: @@ -910,46 +880,35 @@ def forward_features(self, x, mask): return x_vis - def forward(self, x, mask): + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return self.forward_features(x, mask) - def flops(self): - flops = 0 - flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): - flops += layer.flops() - flops += self.num_features * self.patches_resolution[ - 0] * self.patches_resolution[1] // (2**self.num_layers) - flops += self.num_features * self.num_classes - return flops - @MODELS.register_module() -class GreenMIMSwinTransformer(BaseBackbone): # Swin结构的MAE +class GreenMIMSwinTransformer(BaseBackbone): """Masked Autoencoder with VisionTransformer backbone.""" def __init__(self, - arch='B', - stage_cfgs=None, - img_size=224, - patch_size=4, - in_chans=3, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=7, - mlp_ratio=4., - decoder_embed_dim=512, - decoder_depth=8, - decoder_num_heads=16, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - norm_pix_loss=False, - backbone_cls=SwinTransformer, + arch: str = 'B', + stage_cfgs: dict = None, + img_size: int = 224, + patch_size: int = 4, + in_chans: int = 3, + embed_dim: int = 96, + depths: list = [2, 2, 6, 2], + num_heads: list = [3, 6, 12, 24], + window_size: int = 7, + mlp_ratio: float = 4., + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + norm_pix_loss: bool = False, + backbone_cls: nn.Module = SwinTransformer, init_cfg: Optional[Union[List[dict], dict]] = None, - **kwargs): + **kwargs) -> None: super().__init__(init_cfg=init_cfg) - # -------------------------------------------------------------------------- # MAE encoder specifics self.encoder = backbone_cls( img_size=img_size, @@ -966,8 +925,6 @@ def __init__(self, self.num_patches = num_patches patch_size = patch_size * 2**(len(depths) - 1) self.final_patch_size = 
patch_size
-        # --------------------------------------------------------------------------
-
         self.norm_pix_loss = norm_pix_loss
         self.initialize_weights()
@@ -981,7 +938,7 @@ def initialize_weights(self):
         # initialize nn.Linear and nn.LayerNorm
         self.apply(self._init_weights)
 
-    def _init_weights(self, m):
+    def _init_weights(self, m: nn.Module) -> None:
         if isinstance(m, nn.Linear):
             # we use xavier_uniform following official JAX ViT:
             torch.nn.init.xavier_uniform_(m.weight)
@@ -991,7 +948,9 @@ def _init_weights(self, m):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
 
-    def unpatchify(self, x, patch_size=None):
+    def unpatchify(self,
+                   x: torch.Tensor,
+                   patch_size: int = None) -> torch.Tensor:
         """
         x: (N, L, patch_size**2 *3)
         imgs: (N, 3, H, W)
@@ -1005,7 +964,9 @@ def unpatchify(self, x, patch_size=None):
         imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
         return imgs
 
-    def random_masking(self, x, mask_ratio):
+    def random_masking(
+            self, x: torch.Tensor,
+            mask_ratio: float) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform per-sample random masking by per-sample shuffling.
 
         Per-sample shuffling is done by argsort random noise.
@@ -1041,7 +1002,11 @@ def random_masking(self, x, mask_ratio):
 
         return mask, ids_restore
 
-    def forward(self, x, mask_ratio=0.75):
+    def forward(
+            self,
+            x: torch.Tensor,
+            mask_ratio: float = 0.75,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # generate random mask: B x Token^2,ids_restore:正确的ID顺序
         x, mask, ids_restore, latent_gt = \
             torch.load('./x_mask_ids_restore_latent.pth')

From 91ce6f5cb03a3ae5295ca4503a57acbdf2b9cbb7 Mon Sep 17 00:00:00 2001
From: xfguo <2601882982@qq.com>
Date: Thu, 16 Feb 2023 20:29:58 +0800
Subject: [PATCH 07/17] add type hints

---
 projects/greenmim/models/greenmim.py          |  4 +--
 projects/greenmim/models/greenmim_backbone.py |  1 -
 projects/greenmim/models/greenmim_head.py     |  3 +-
 projects/greenmim/models/greenmim_neck.py     | 34 +++++++++----------
 4 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/projects/greenmim/models/greenmim.py b/projects/greenmim/models/greenmim.py
index 5d639c0f8..6ab20a04b 100644
--- a/projects/greenmim/models/greenmim.py
+++ b/projects/greenmim/models/greenmim.py
@@ -67,10 +67,10 @@ def reconstruct(self,
 
         return results
 
-    def patchify(self, imgs, patch_size):
+    def patchify(self, imgs: torch.Tensor, patch_size: int) -> torch.Tensor:
         """
         imgs: (N, 3, H, W)
-        x: (N, L, patch_size**2 *3)
+        patch_size: int
         """
         p = patch_size
         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py
index d7cc94ea9..4ce401201 100644
--- a/projects/greenmim/models/greenmim_backbone.py
+++ b/projects/greenmim/models/greenmim_backbone.py
@@ -976,7 +976,6 @@ def random_masking(
         len_keep = int(L * (1 - mask_ratio))
         torch.manual_seed(0)
         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
-        print(noise)
         # sort noise for each sample
         ids_shuffle = torch.argsort(
             noise, dim=1)  # ascend: small is keep, large is remove
diff --git a/projects/greenmim/models/greenmim_head.py b/projects/greenmim/models/greenmim_head.py
index 5938570f4..2697780bb 100644
--- a/projects/greenmim/models/greenmim_head.py
+++ b/projects/greenmim/models/greenmim_head.py
@@ -14,7 +14,8 @@ class GreenMIMHead(BaseModule):
         loss (dict): The config for loss.
""" - def __init__(self, patch_size, norm_pix_loss, loss: dict) -> None: + def __init__(self, patch_size: int, norm_pix_loss: bool, + loss: dict) -> None: super().__init__() self.loss = MODELS.build(loss) self.final_patch_size = patch_size diff --git a/projects/greenmim/models/greenmim_neck.py b/projects/greenmim/models/greenmim_neck.py index 678caf345..39b5a36e4 100644 --- a/projects/greenmim/models/greenmim_neck.py +++ b/projects/greenmim/models/greenmim_neck.py @@ -10,27 +10,23 @@ @MODELS.register_module() class GreenMIMNeck(BaseModule): - """Pre-train Neck For SimMIM. + """Pre-train Neck For GreenMIM. This neck reconstructs the original image from the shrunk feature map. - - Args: - in_channels (int): Channel dimension of the feature map. - encoder_stride (int): The total stride of the encoder. """ def __init__( self, in_channels: int, encoder_stride: int, - img_size, - patch_size, - embed_dim=96, - depths=[2, 2, 6, 2], - decoder_embed_dim=512, - mlp_ratio=4., - decoder_depth=8, - decoder_num_heads=16, + img_size: int, + patch_size: int, + embed_dim: int = 96, + depths: list = [2, 2, 6, 2], + decoder_embed_dim: int = 512, + mlp_ratio: float = 4., + decoder_depth: int = 8, + decoder_num_heads: int = 16, norm_cfg: dict = dict(type='LN', eps=1e-6), ) -> None: super().__init__() @@ -77,7 +73,8 @@ def initialize_weights(self): torch.nn.init.normal_(self.mask_token, std=.02) - def forward(self, x, ids_restore): + def forward(self, x: torch.Tensor, + ids_restore: torch.Tensor) -> torch.Tensor: # embed tokens x = self.decoder_embed(x) @@ -106,7 +103,9 @@ def forward(self, x, ids_restore): return x -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): +def get_2d_sincos_pos_embed(embed_dim: int, + grid_size: int, + cls_token: bool = False) -> np.ndarray: """ grid_size: int of the grid height and width return: @@ -126,7 +125,8 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): return pos_embed -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, + grid: np.ndarray) -> np.ndarray: assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h @@ -139,7 +139,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: list) -> np.ndarray: """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) From 8e35a224912afd7c2d4834709b0a4e353d98c478 Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Thu, 16 Feb 2023 20:56:55 +0800 Subject: [PATCH 08/17] del sth --- projects/greenmim/models/greenmim_backbone.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py index 4ce401201..51c190b64 100644 --- a/projects/greenmim/models/greenmim_backbone.py +++ b/projects/greenmim/models/greenmim_backbone.py @@ -195,7 +195,7 @@ class SwinTransformerBlock(BaseModule): def __init__(self, dim: int, - input_resolution: tuple[int], + input_resolution: Tuple[int], num_heads: int, window_size: int = 7, shift_size: int = 0, @@ -273,7 +273,7 @@ class PatchMerging(BaseModule): """ def __init__(self, - input_resolution: tuple[int], + input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: super().__init__() @@ -311,7 +311,7 @@ def extra_repr(self) -> str: return 
f'input_resolution={self.input_resolution}, dim={self.dim}' -def knapsack(W: int, wt: tuple[int]) -> Tuple[list[list[int]], list]: +def knapsack(W: int, wt: Tuple[int]) -> Tuple[List[List[int]], list]: '''Args: W (int): capacity wt (tuple[int]): the numbers of elements within each window @@ -356,7 +356,7 @@ def knapsack(W: int, wt: tuple[int]) -> Tuple[list[list[int]], list]: def group_windows(group_size: int, - num_ele_win: list[int]) -> Tuple[list[int], list[list[int]]]: + num_ele_win: List[int]) -> Tuple[List[int], List[List[int]]]: """Greedily apply the DP algorithm to group the elements. Args: @@ -542,7 +542,7 @@ class BasicLayer(BaseModule): def __init__(self, dim: int, - input_resolution: tuple[int], + input_resolution: Tuple[int], depth: int, num_heads: int, window_size: int, @@ -974,7 +974,7 @@ def random_masking( """ N, L = 1, self.num_patches # batch, length, dim len_keep = int(L * (1 - mask_ratio)) - torch.manual_seed(0) + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] # sort noise for each sample ids_shuffle = torch.argsort( @@ -1006,14 +1006,8 @@ def forward( x: torch.Tensor, mask_ratio: float = 0.75, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # generate random mask: B x Token^2,ids_restore:正确的ID顺序 - x, mask, ids_restore, latent_gt = \ - torch.load('./x_mask_ids_restore_latent.pth') + # generate random mask: B x Token^2,ids_restore mask, ids_restore = self.random_masking(x, mask_ratio) - - # L -> L_vis:计算没有被mask掉的特征 latent = self.encoder(x, mask.bool()) - print(latent[0][0][:3]) - print(latent_gt[0][0][:3]) return latent, mask, ids_restore From 1e0f8eea7edc2151762cdf986bcc80ed9c4613bb Mon Sep 17 00:00:00 2001 From: wangbo-zhao Date: Sun, 19 Feb 2023 09:25:38 +0000 Subject: [PATCH 09/17] [Fix] Fix bugs and refine --- .../selfsup/_base_/datasets/imagenet_mae.py | 5 +- mmselfsup/models/losses/greenmim_loss.py | 3 +- .../greenmim/configs/adamw_coslr-200e_in1k.py | 19 ---- projects/greenmim/configs/default_runtime.py | 29 ------ ...in-base_16xb128-amp-coslr-100e_in1k-192.py | 6 +- projects/greenmim/configs/imagenet_mae.py | 29 ------ projects/greenmim/models/greenmim_backbone.py | 9 +- projects/greenmim/models/greenmim_head.py | 2 +- projects/greenmim/models/greenmim_neck.py | 8 +- projects/greenmim/tools/train.py | 99 ------------------- 10 files changed, 18 insertions(+), 191 deletions(-) delete mode 100644 projects/greenmim/configs/adamw_coslr-200e_in1k.py delete mode 100644 projects/greenmim/configs/default_runtime.py delete mode 100644 projects/greenmim/configs/imagenet_mae.py delete mode 100644 projects/greenmim/tools/train.py diff --git a/configs/selfsup/_base_/datasets/imagenet_mae.py b/configs/selfsup/_base_/datasets/imagenet_mae.py index c642ad471..9ed6eb710 100644 --- a/configs/selfsup/_base_/datasets/imagenet_mae.py +++ b/configs/selfsup/_base_/datasets/imagenet_mae.py @@ -1,7 +1,8 @@ # dataset settings dataset_type = 'mmcls.ImageNet' -data_root = 'data/imagenet/' +data_root = '/data/common/ImageNet/' file_client_args = dict(backend='disk') +ann_file = '/home/nus-zwb/research/data/imagenet/meta/train.txt' train_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), @@ -24,6 +25,6 @@ dataset=dict( type=dataset_type, data_root=data_root, - ann_file='meta/train.txt', + ann_file=ann_file, data_prefix=dict(img_path='train/'), pipeline=train_pipeline)) diff --git a/mmselfsup/models/losses/greenmim_loss.py b/mmselfsup/models/losses/greenmim_loss.py index ceebc9b0c..84abfd79f 100644 --- 
a/mmselfsup/models/losses/greenmim_loss.py
+++ b/mmselfsup/models/losses/greenmim_loss.py
@@ -33,7 +33,8 @@
         Returns:
             torch.Tensor: The reconstruction loss.
         """
-        loss_rec = F.l1_loss(target, pred, reduction='none') # ???: 应该可以复用之前的loss吧
+        loss_rec = F.l1_loss(
+            target, pred, reduction='none')  # ???: reuse the previous loss?
         loss = (loss_rec * mask).sum() / (mask.sum() +
                                           1e-5) / self.encoder_in_channels
 
diff --git a/projects/greenmim/configs/adamw_coslr-200e_in1k.py b/projects/greenmim/configs/adamw_coslr-200e_in1k.py
deleted file mode 100644
index 7ab03a869..000000000
--- a/projects/greenmim/configs/adamw_coslr-200e_in1k.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# optimizer_wrapper
-optimizer = dict(type='AdamW', lr=1.5e-4, betas=(0.9, 0.95), weight_decay=0.05)
-optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
-
-# learning rate scheduler
-param_scheduler = [
-    dict(
-        type='LinearLR',
-        start_factor=1e-4,
-        by_epoch=True,
-        begin=0,
-        end=40,
-        convert_to_iter_based=True),
-    dict(
-        type='CosineAnnealingLR', T_max=160, by_epoch=True, begin=40, end=200)
-]
-
-# runtime settings
-train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200)
diff --git a/projects/greenmim/configs/default_runtime.py b/projects/greenmim/configs/default_runtime.py
deleted file mode 100644
index d672cdc00..000000000
--- a/projects/greenmim/configs/default_runtime.py
+++ /dev/null
@@ -1,29 +0,0 @@
-default_scope = 'mmselfsup'
-
-default_hooks = dict(
-    runtime_info=dict(type='RuntimeInfoHook'),
-    timer=dict(type='IterTimerHook'),
-    logger=dict(type='LoggerHook', interval=50),
-    param_scheduler=dict(type='ParamSchedulerHook'),
-    checkpoint=dict(type='CheckpointHook', interval=10),
-    sampler_seed=dict(type='DistSamplerSeedHook'),
-)
-
-env_cfg = dict(
-    cudnn_benchmark=False,
-    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
-    dist_cfg=dict(backend='nccl'),
-)
-
-log_processor = dict(
-    window_size=10,
-    custom_cfg=[dict(data_src='', method='mean', windows_size='global')])
-
-vis_backends = [dict(type='LocalVisBackend')]
-visualizer = dict(
-    type='SelfSupVisualizer', vis_backends=vis_backends, name='visualizer')
-# custom_hooks = [dict(type='SelfSupVisualizationHook', interval=1)]
-
-log_level = 'INFO'
-load_from = None
-resume = False
diff --git a/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py b/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py
index caac9f791..f9ce8efa5 100644
--- a/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py
+++ b/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py
@@ -1,8 +1,8 @@
 _base_ = [
     './greenmim_swin-base.py',
-    './imagenet_mae.py',
-    './adamw_coslr-200e_in1k.py',
-    './default_runtime.py',
+    '../../../configs/selfsup/_base_/datasets/imagenet_mae.py',
+    '../../../configs/selfsup/_base_/schedules/adamw_coslr-200e_in1k.py',
+    '../../../configs/selfsup/_base_/default_runtime.py',
 ]
 
 # dataset 16 GPUs x 128
diff --git a/projects/greenmim/configs/imagenet_mae.py b/projects/greenmim/configs/imagenet_mae.py
deleted file mode 100644
index c642ad471..000000000
--- a/projects/greenmim/configs/imagenet_mae.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# dataset settings
-dataset_type = 'mmcls.ImageNet'
-data_root = 'data/imagenet/'
-file_client_args = dict(backend='disk')
-
-train_pipeline = [
-    dict(type='LoadImageFromFile', file_client_args=file_client_args),
-    dict(
-        type='RandomResizedCrop',
-        size=224,
-        
scale=(0.2, 1.0), - backend='pillow', - interpolation='bicubic'), - dict(type='RandomFlip', prob=0.5), - dict(type='PackSelfSupInputs', meta_keys=['img_path']) -] - -train_dataloader = dict( - batch_size=128, - num_workers=8, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate'), - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='meta/train.txt', - data_prefix=dict(img_path='train/'), - pipeline=train_pipeline)) diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py index 51c190b64..90a95ac3a 100644 --- a/projects/greenmim/models/greenmim_backbone.py +++ b/projects/greenmim/models/greenmim_backbone.py @@ -810,6 +810,8 @@ def __init__(self, self.norm = norm_layer(self.num_features) + def init_weights(self): + super().init_weights() self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: @@ -927,11 +929,8 @@ def __init__(self, self.final_patch_size = patch_size self.norm_pix_loss = norm_pix_loss - self.initialize_weights() - - def initialize_weights(self): - # initialization - # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + def init_weights(self): + super().init_weights() w = self.encoder.patch_embed.proj.weight.data torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) diff --git a/projects/greenmim/models/greenmim_head.py b/projects/greenmim/models/greenmim_head.py index 2697780bb..916546faf 100644 --- a/projects/greenmim/models/greenmim_head.py +++ b/projects/greenmim/models/greenmim_head.py @@ -7,7 +7,7 @@ @MODELS.register_module() class GreenMIMHead(BaseModule): - """Pretrain Head for SimMIM. + """Pretrain Head for GreenMIMHead. Args: patch_size (int): Patch size of each token. diff --git a/projects/greenmim/models/greenmim_neck.py b/projects/greenmim/models/greenmim_neck.py index 39b5a36e4..9e1537818 100644 --- a/projects/greenmim/models/greenmim_neck.py +++ b/projects/greenmim/models/greenmim_neck.py @@ -60,10 +60,12 @@ def __init__( self.decoder_pred = nn.Linear( decoder_embed_dim, patch_size**2 * in_channels, bias=True) # encoder to decoder + self.num_patches = num_patches + + def init_weights(self): + """Initialize position embedding, patch embedding.""" + super().init_weights() - def initialize_weights(self): - # initialization - # initialize (and freeze) pos_embed by sin-cos embedding decoder_pos_embed = get_2d_sincos_pos_embed( self.decoder_pos_embed.shape[-1], int(self.num_patches**.5), diff --git a/projects/greenmim/tools/train.py b/projects/greenmim/tools/train.py deleted file mode 100644 index ef0d3127c..000000000 --- a/projects/greenmim/tools/train.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import argparse -import os -import os.path as osp - -from mmengine.config import Config, DictAction -from mmengine.runner import Runner - -from mmselfsup.utils import register_all_modules - - -def parse_args(): - parser = argparse.ArgumentParser(description='Train a model') - parser.add_argument('config', help='train config file path') - parser.add_argument('--work-dir', help='the dir to save logs and models') - parser.add_argument( - '--resume', - nargs='?', - type=str, - const='auto', - help='If specify checkpint path, resume from it, while if not ' - 'specify, try to auto resume from the latest checkpoint ' - 'in the work directory.') - parser.add_argument( - '--amp', - action='store_true', - help='enable automatic-mixed-precision training') - parser.add_argument( - '--cfg-options', - nargs='+', - action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' - 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' - 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) - args = parser.parse_args() - if 'LOCAL_RANK' not in os.environ: - os.environ['LOCAL_RANK'] = str(args.local_rank) - - return args - - -def main(): - args = parse_args() - - # register all modules in mmselfsup into the registries - # do not init the default scope here because it will be init in the runner - register_all_modules(init_default_scope=False) - - # load config - cfg = Config.fromfile(args.config) - cfg.launcher = args.launcher - if args.cfg_options is not None: - cfg.merge_from_dict(args.cfg_options) - - # work_dir is determined in this priority: CLI > segment in file > filename - if args.work_dir is not None: - # update configs according to CLI args if args.work_dir is not None - cfg.work_dir = args.work_dir - elif cfg.get('work_dir', None) is None: - # use config filename as default work_dir if cfg.work_dir is None - work_type = args.config.split('/')[1] - cfg.work_dir = osp.join('./work_dirs', work_type, - osp.splitext(osp.basename(args.config))[0]) - - # enable automatic-mixed-precision training - if args.amp is True: - optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper') - assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \ - '`--amp` is not supported custom optimizer wrapper type ' \ - f'`{optim_wrapper}.' 
- cfg.optim_wrapper.type = 'AmpOptimWrapper' - cfg.optim_wrapper.setdefault('loss_scale', 'dynamic') - - # resume training - if args.resume == 'auto': - cfg.resume = True - cfg.load_from = None - elif args.resume is not None: - cfg.resume = True - cfg.load_from = args.resume - - # build the runner from config - runner = Runner.from_cfg(cfg) - - # start training - runner.train() - - -if __name__ == '__main__': - main() From d13424bb25b6a282ac1524c25386bc60a91276e0 Mon Sep 17 00:00:00 2001 From: wangbo-zhao Date: Sun, 19 Feb 2023 09:27:14 +0000 Subject: [PATCH 10/17] [Fix] Fix bugs and refine --- configs/selfsup/_base_/datasets/imagenet_mae.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/selfsup/_base_/datasets/imagenet_mae.py b/configs/selfsup/_base_/datasets/imagenet_mae.py index 9ed6eb710..8466f8d63 100644 --- a/configs/selfsup/_base_/datasets/imagenet_mae.py +++ b/configs/selfsup/_base_/datasets/imagenet_mae.py @@ -1,8 +1,8 @@ + # dataset settings dataset_type = 'mmcls.ImageNet' -data_root = '/data/common/ImageNet/' +data_root = 'data/imagenet/' file_client_args = dict(backend='disk') -ann_file = '/home/nus-zwb/research/data/imagenet/meta/train.txt' train_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), @@ -25,6 +25,6 @@ dataset=dict( type=dataset_type, data_root=data_root, - ann_file=ann_file, + ann_file='meta/train.txt', data_prefix=dict(img_path='train/'), - pipeline=train_pipeline)) + pipeline=train_pipeline)) \ No newline at end of file From db9ca9be17cea147125b4c779e70ab8a40311228 Mon Sep 17 00:00:00 2001 From: wangbo-zhao Date: Sun, 19 Feb 2023 09:28:13 +0000 Subject: [PATCH 11/17] [Fix] Fix bugs and refine --- configs/selfsup/_base_/datasets/imagenet_mae.py | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/selfsup/_base_/datasets/imagenet_mae.py b/configs/selfsup/_base_/datasets/imagenet_mae.py index 8466f8d63..e197f5fb1 100644 --- a/configs/selfsup/_base_/datasets/imagenet_mae.py +++ b/configs/selfsup/_base_/datasets/imagenet_mae.py @@ -1,4 +1,3 @@ - # dataset settings dataset_type = 'mmcls.ImageNet' data_root = 'data/imagenet/' From cd46caf502f96a28d3a7d4c60d76d519f9e8061a Mon Sep 17 00:00:00 2001 From: wangbo-zhao Date: Sun, 19 Feb 2023 09:29:09 +0000 Subject: [PATCH 12/17] [Fix] Fix bugs and refine --- configs/selfsup/_base_/datasets/imagenet_mae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/selfsup/_base_/datasets/imagenet_mae.py b/configs/selfsup/_base_/datasets/imagenet_mae.py index e197f5fb1..c642ad471 100644 --- a/configs/selfsup/_base_/datasets/imagenet_mae.py +++ b/configs/selfsup/_base_/datasets/imagenet_mae.py @@ -26,4 +26,4 @@ data_root=data_root, ann_file='meta/train.txt', data_prefix=dict(img_path='train/'), - pipeline=train_pipeline)) \ No newline at end of file + pipeline=train_pipeline)) From d2732e2b19cf46850107aa4af1040c8e6576b163 Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Mon, 20 Feb 2023 11:10:06 +0800 Subject: [PATCH 13/17] rm greenmim loss, use mae loss --- mmselfsup/models/losses/greenmim_loss.py | 40 ------------------- .../greenmim/configs/greenmim_swin-base.py | 2 +- projects/greenmim/models/greenmim_head.py | 5 +-- 3 files changed, 2 insertions(+), 45 deletions(-) delete mode 100644 mmselfsup/models/losses/greenmim_loss.py diff --git a/mmselfsup/models/losses/greenmim_loss.py b/mmselfsup/models/losses/greenmim_loss.py deleted file mode 100644 index ceebc9b0c..000000000 --- a/mmselfsup/models/losses/greenmim_loss.py +++ 
/dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-import torch
-from mmengine.model import BaseModule
-from torch.nn import functional as F
-
-from mmselfsup.registry import MODELS
-
-
-@MODELS.register_module()
-class GreenMIMReconstructionLoss(BaseModule):
-    """Loss function for MAE.
-
-    Compute the loss in masked region.
-
-    Args:
-        encoder_in_channels (int): Number of input channels for encoder.
-    """
-
-    def __init__(self, encoder_in_channels: int) -> None:
-        super().__init__()
-        self.encoder_in_channels = encoder_in_channels
-
-    def forward(self, pred: torch.Tensor, target: torch.Tensor,
-                mask: torch.Tensor) -> torch.Tensor:
-        """Forward function of MAE Loss.
-
-        Args:
-            pred (torch.Tensor): The reconstructed image.
-            target (torch.Tensor): The target image.
-            mask (torch.Tensor): The mask of the target image.
-
-        Returns:
-            torch.Tensor: The reconstruction loss.
-        """
-        loss_rec = F.l1_loss(target, pred, reduction='none') # ???: 应该可以复用之前的loss吧
-        loss = (loss_rec * mask).sum() / (mask.sum() +
-                                          1e-5) / self.encoder_in_channels
-
-        return loss
diff --git a/projects/greenmim/configs/greenmim_swin-base.py b/projects/greenmim/configs/greenmim_swin-base.py
index ca3119487..ba6cfffd8 100644
--- a/projects/greenmim/configs/greenmim_swin-base.py
+++ b/projects/greenmim/configs/greenmim_swin-base.py
@@ -32,4 +32,4 @@
         type='GreenMIMHead',
         patch_size=patch_size,
         norm_pix_loss=False,
-        loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
+        loss=dict(type='MAEReconstructionLoss')))
diff --git a/projects/greenmim/models/greenmim_head.py b/projects/greenmim/models/greenmim_head.py
index 2697780bb..8e91549b5 100644
--- a/projects/greenmim/models/greenmim_head.py
+++ b/projects/greenmim/models/greenmim_head.py
@@ -33,9 +33,6 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor,
             var = target.var(dim=-1, keepdim=True)
             target = (target - mean) / (var + 1.e-6)**.5
 
-        loss = (pred - target)**2
-        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
-
-        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+        loss = self.loss(pred, target, mask)
 
         return loss

From d9a1aba15627b5a5a9af958af044e1ed8314e738 Mon Sep 17 00:00:00 2001
From: xfguo <2601882982@qq.com>
Date: Mon, 20 Feb 2023 12:14:20 +0800
Subject: [PATCH 14/17] add readme

---
 projects/greenmim/README.md                   | 169 ++++++++++++++++++
 ..._swin-base_16xb128-amp-coslr-100e_in1k.py} |   0
 2 files changed, 169 insertions(+)
 create mode 100644 projects/greenmim/README.md
 rename projects/greenmim/configs/{greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py => greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py} (100%)

diff --git a/projects/greenmim/README.md b/projects/greenmim/README.md
new file mode 100644
index 000000000..e06c20376
--- /dev/null
+++ b/projects/greenmim/README.md
@@ -0,0 +1,169 @@
+# GreenMIM Pre-training Model
+
+- [GreenMIM Pre-training Model](#greenmim-pre-training-model)
+  - [Description](#description)
+  - [Usage](#usage)
+    - [Setup Environment](#setup-environment)
+    - [Data Preparation](#data-preparation)
+    - [Pre-training Commands](#pre-training-commands)
+      - [On Local Single GPU](#on-local-single-gpu)
+      - [On Multiple GPUs](#on-multiple-gpus)
+      - [On Multiple GPUs with Slurm](#on-multiple-gpus-with-slurm)
+  - [Citation](#citation)
+  - [Checklist](#checklist)
+
+## Description
+
+
+
+Author: @xfguo-ucas
+
+This is the implementation of **GreenMIM** with ImageNet.
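+
+GreenMIM makes masked image modeling efficient for hierarchical (Swin)
+encoders: the masked patches are dropped before the encoder, and the
+remaining visible patches are regrouped into equal-sized attention groups,
+so only visible tokens are processed. As a minimal sketch of how the
+registered modules fit together (module names follow this project; the
+exact keyword arguments and the omitted `data_preprocessor` should be taken
+from `configs/greenmim_swin-base.py`):
+
+```python
+from mmselfsup.registry import MODELS
+from mmselfsup.utils import register_all_modules
+
+# assumes the modules under projects/greenmim/models are importable,
+# e.g. after export PYTHONPATH=`pwd`:$PYTHONPATH (see Usage below)
+register_all_modules()
+
+model = MODELS.build(
+    dict(
+        type='GreenMIM',
+        backbone=dict(
+            type='GreenMIMSwinTransformer',
+            arch='B',
+            img_size=224,
+            patch_size=4),
+        neck=dict(
+            type='GreenMIMNeck',
+            in_channels=3,
+            encoder_stride=32,
+            img_size=224,
+            patch_size=4),
+        head=dict(
+            type='GreenMIMHead',
+            patch_size=4,
+            norm_pix_loss=False,
+            loss=dict(type='MAEReconstructionLoss'))))
+```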
+ +## Usage + + + +### Setup Environment + +Requirements: + +- MMSelfSup >= 1.0.0rc7 + +Please refer to [Get Started](https://mmselfsup.readthedocs.io/en/1.x/get_started.html) documentation of MMSelfSup to finish installation. + +### Data Preparation + +You can refer to the [documentation](https://mmclassification.readthedocs.io/en/latest/getting_started.html) in mmcls. + +### Pre-training Commands + +At first, you need to add the current folder to `PYTHONPATH`, so that Python can find your model files. In `projects/greenmim/` root directory, please run command below to add it. + +```shell +export PYTHONPATH=`pwd`:$PYTHONPATH +``` + +Then run the following commands to train the model: + +#### On Local Single GPU + +```bash +# train with mim +mim train mmselfsup ${CONFIG} --work-dir ${WORK_DIR} + +# a specific command example +mim train mmselfsup configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \ + --work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/ + +# train with scripts +python tools/train.py configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \ + --work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/ +``` + +#### On Multiple GPUs + +```bash +# train with mim +# a specific command examples, 8 GPUs here +mim train mmselfsup configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \ + --work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/ \ + --launcher pytorch --gpus 8 + +# train with scripts +bash tools/dist_train.sh configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py 8 +``` + +Note: + +- CONFIG: the config files under the directory `configs/` +- WORK_DIR: the working directory to save configs, logs, and checkpoints + +#### On Multiple GPUs with Slurm + +```bash +# train with mim +mim train mmselfsup configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \ + --work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/ \ + --launcher slurm --gpus 16 --gpus-per-node 8 \ + --partition ${PARTITION} + +# train with scripts +GPUS_PER_NODE=8 GPUS=16 bash tools/slurm_train.sh ${PARTITION} greenmim \ + configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py \ + --work-dir work_dirs/selfsup/greenmim_swin-base_16xb128-amp-coslr-100e_in1k/ +``` + +Note: + +- CONFIG: the config files under the directory `configs/` +- WORK_DIR: the working directory to save configs, logs, and checkpoints +- PARTITION: the slurm partition you are using + +## Citation + +```bibtex +@article{huang2022green, + title={Green Hierarchical Vision Transformer for Masked Image Modeling}, + author={Huang, Lang and You, Shan and Zheng, Mingkai and Wang, Fei and Qian, Chen and Yamasaki, Toshihiko}, + journal={Thirty-Sixth Conference on Neural Information Processing Systems}, + year={2022} +} +``` + +## Checklist + +Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress. + + + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + + + + - [x] Basic docstrings & proper citation + + + + - [x] Inference correctness + + + + - [x] A full README + + + +- [x] Milestone 2: Indicates a successful model implementation. + + - [x] Training-time correctness + + + +- [ ] Milestone 3: Good to be a part of our core package! 
+ + - [ ] Type hints and docstrings + + + + - [ ] Unit tests + + + + - [ ] Code polishing + + + + - [ ] `metafile.yml` and `README.md` + + + +- [ ] Refactor and Move your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py b/projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py similarity index 100% rename from projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k-192.py rename to projects/greenmim/configs/greenmim_swin-base_16xb128-amp-coslr-100e_in1k.py From ee14b44095da29f2e7544efa3e566b4f2f941517 Mon Sep 17 00:00:00 2001 From: xfguo <2601882982@qq.com> Date: Sun, 26 Feb 2023 00:22:29 +0800 Subject: [PATCH 15/17] merge class --- .gitignore | 1 - mmselfsup/models/losses/greenmim_loss.py | 41 ---- projects/greenmim/models/greenmim_backbone.py | 209 +++++++----------- 3 files changed, 80 insertions(+), 171 deletions(-) delete mode 100644 mmselfsup/models/losses/greenmim_loss.py diff --git a/.gitignore b/.gitignore index 0bc88e9dd..df0976d7e 100644 --- a/.gitignore +++ b/.gitignore @@ -131,4 +131,3 @@ INFO # Pytorch *.pth -data diff --git a/mmselfsup/models/losses/greenmim_loss.py b/mmselfsup/models/losses/greenmim_loss.py deleted file mode 100644 index 84abfd79f..000000000 --- a/mmselfsup/models/losses/greenmim_loss.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -import torch -from mmengine.model import BaseModule -from torch.nn import functional as F - -from mmselfsup.registry import MODELS - - -@MODELS.register_module() -class GreenMIMReconstructionLoss(BaseModule): - """Loss function for MAE. - - Compute the loss in masked region. - - Args: - encoder_in_channels (int): Number of input channels for encoder. - """ - - def __init__(self, encoder_in_channels: int) -> None: - super().__init__() - self.encoder_in_channels = encoder_in_channels - - def forward(self, pred: torch.Tensor, target: torch.Tensor, - mask: torch.Tensor) -> torch.Tensor: - """Forward function of MAE Loss. - - Args: - pred (torch.Tensor): The reconstructed image. - target (torch.Tensor): The target image. - mask (torch.Tensor): The mask of the target image. - - Returns: - torch.Tensor: The reconstruction loss. - """ - loss_rec = F.l1_loss( - target, pred, reduction='none') # ???: 应该可以复用之前的loss吧 - loss = (loss_rec * mask).sum() / (mask.sum() + - 1e-5) / self.encoder_in_channels - - return loss diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py index 90a95ac3a..b740f797f 100644 --- a/projects/greenmim/models/greenmim_backbone.py +++ b/projects/greenmim/models/greenmim_backbone.py @@ -698,12 +698,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class SwinTransformer(BaseModule): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision - Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - +@MODELS.register_module() +class GreenMIMSwinTransformer(BaseBackbone): + r"""GreenMIM with SwinTransformer backbone. Args: img_size (int | tuple(int)): Input image size. Default 224 patch_size (int | tuple(int)): Patch size. 
Default: 4 @@ -733,27 +730,34 @@ class SwinTransformer(BaseModule): """ def __init__(self, + arch: str = 'B', + stage_cfgs: dict = None, img_size: int = 224, patch_size: int = 4, in_chans: int = 3, - num_classes: int = 1000, embed_dim: int = 96, depths: list = [2, 2, 6, 2], num_heads: list = [3, 6, 12, 24], window_size: int = 7, mlp_ratio: float = 4., + decoder_embed_dim: int = 512, + decoder_depth: int = 8, + decoder_num_heads: int = 16, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + norm_pix_loss: bool = False, qkv_bias: bool = True, qk_scale: float = None, - drop_rate: float = 0., - attn_drop_rate: float = 0., - drop_path_rate: float = 0.1, - norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), ape: bool = False, patch_norm: bool = True, - use_checkpoint: bool = False) -> None: - super().__init__() + drop_path_rate: float = 0.1, + drop_rate: float = 0., + attn_drop_rate: float = 0., + use_checkpoint: bool = False, + init_cfg: Optional[Union[List[dict], dict]] = None, + **kwargs) -> None: + super().__init__(init_cfg=init_cfg) - self.num_classes = num_classes + # SwinTransformer specifics self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape @@ -810,18 +814,11 @@ def __init__(self, self.norm = norm_layer(self.num_features) - def init_weights(self): - super().init_weights() - self.apply(self._init_weights) - - def _init_weights(self, m: nn.Module) -> None: - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + num_patches = np.prod(self.layers[-1].input_resolution) + self.num_patches = num_patches + patch_size = patch_size * 2**(len(depths) - 1) + self.final_patch_size = patch_size + self.norm_pix_loss = norm_pix_loss @torch.jit.ignore def no_weight_decay(self): @@ -831,113 +828,17 @@ def no_weight_decay(self): def no_weight_decay_keywords(self): return {'relative_position_bias_table'} - def forward_features(self, x: torch.Tensor, - mask: torch.Tensor) -> torch.Tensor: - # patch embedding - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - # mask out some patches according to the random mask - B, N, C = x.shape - H, W = self.patches_resolution - ratio = N // mask.shape[1] - mask = mask[:1].clone() # we use the same mask for the whole batch - assert ratio * mask.shape[1] == N - window_size = int(ratio**0.5) - if ratio > 1: # mask_size != patch_embed_size - Mh, Mw = [sz // window_size for sz in self.patches_resolution] - mask = mask.reshape(1, Mh, 1, Mw, 1) - mask = mask.expand(-1, -1, window_size, -1, window_size) - mask = mask.reshape(1, -1) - - # record the corresponding coordinates of visible patches - coords_h = torch.arange(H, device=x.device) - coords_w = torch.arange(W, device=x.device) - coords = torch.stack( - torch.meshgrid([coords_h, coords_w]), dim=-1) # H W 2 - coords = coords.reshape(1, H * W, 2) - - # for convenient, first divide the image into local windows - x = x.view(B, H // window_size, window_size, W // window_size, - window_size, C) - x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, N, C) - mask = mask.view(1, H // window_size, window_size, W // window_size, - window_size) - mask = mask.permute(0, 1, 3, 2, 4).reshape(1, N) - coords = coords.view(1, H // window_size, window_size, - W // window_size, window_size, 2) - coords = coords.permute(0, 1, 3, 2, 4, 5).reshape(1, N, 2) - - # mask out 
-        vis_mask = ~mask  # ~mask means visible
-        x_vis = x[vis_mask.expand(B, -1)].reshape(B, -1, C)
-        coords = coords[vis_mask].reshape(1, -1, 2)  # 1 N_vis 2
-
-        # transformer forward
-        for layer in self.layers:
-            x_vis, coords, vis_mask = layer(x_vis, coords, vis_mask)
-        x_vis = self.norm(x_vis)
-
-        return x_vis
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-        return self.forward_features(x, mask)
-
-
-@MODELS.register_module()
-class GreenMIMSwinTransformer(BaseBackbone):
-    """Masked Autoencoder with VisionTransformer backbone."""
-
-    def __init__(self,
-                 arch: str = 'B',
-                 stage_cfgs: dict = None,
-                 img_size: int = 224,
-                 patch_size: int = 4,
-                 in_chans: int = 3,
-                 embed_dim: int = 96,
-                 depths: list = [2, 2, 6, 2],
-                 num_heads: list = [3, 6, 12, 24],
-                 window_size: int = 7,
-                 mlp_ratio: float = 4.,
-                 decoder_embed_dim: int = 512,
-                 decoder_depth: int = 8,
-                 decoder_num_heads: int = 16,
-                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
-                 norm_pix_loss: bool = False,
-                 backbone_cls: nn.Module = SwinTransformer,
-                 init_cfg: Optional[Union[List[dict], dict]] = None,
-                 **kwargs) -> None:
-        super().__init__(init_cfg=init_cfg)
-
-        # MAE encoder specifics
-        self.encoder = backbone_cls(
-            img_size=img_size,
-            patch_size=patch_size,
-            in_chans=in_chans,
-            num_classes=0,
-            embed_dim=embed_dim,
-            depths=depths,
-            num_heads=num_heads,
-            window_size=window_size,
-            norm_layer=norm_layer,
-            **kwargs)
-        num_patches = np.prod(self.encoder.layers[-1].input_resolution)
-        self.num_patches = num_patches
-        patch_size = patch_size * 2**(len(depths) - 1)
-        self.final_patch_size = patch_size
-        self.norm_pix_loss = norm_pix_loss

     def init_weights(self):
-        super().init_weights()
-        w = self.encoder.patch_embed.proj.weight.data
+        # initialization
+        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+        w = self.patch_embed.proj.weight.data
         torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

         # initialize nn.Linear and nn.LayerNorm
         self.apply(self._init_weights)
+        super().init_weights()

-    def _init_weights(self, m: nn.Module) -> None:
+    def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             # we use xavier_uniform following official JAX ViT:
             torch.nn.init.xavier_uniform_(m.weight)
@@ -973,7 +874,7 @@ def random_masking(
         """
         N, L = 1, self.num_patches  # batch, length, dim
         len_keep = int(L * (1 - mask_ratio))
-
+        torch.manual_seed(0)  # insert this code at line 977
         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
         # sort noise for each sample
         ids_shuffle = torch.argsort(
@@ -1000,6 +901,56 @@ def random_masking(
         return mask, ids_restore

+    def forward_features(self, x: torch.Tensor,
+                         mask: torch.Tensor) -> torch.Tensor:
+        # patch embedding
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        # mask out some patches according to the random mask
+        B, N, C = x.shape
+        H, W = self.patches_resolution
+        ratio = N // mask.shape[1]
+        mask = mask[:1].clone()  # we use the same mask for the whole batch
+        assert ratio * mask.shape[1] == N
+        window_size = int(ratio**0.5)
+        if ratio > 1:  # mask_size != patch_embed_size
+            Mh, Mw = [sz // window_size for sz in self.patches_resolution]
+            mask = mask.reshape(1, Mh, 1, Mw, 1)
+            mask = mask.expand(-1, -1, window_size, -1, window_size)
+            mask = mask.reshape(1, -1)
+
+        # record the corresponding coordinates of visible patches
+        coords_h = torch.arange(H, device=x.device)
+        coords_w = torch.arange(W, device=x.device)
+        coords = torch.stack(
+            torch.meshgrid([coords_h, coords_w]), dim=-1)  # H W 2
+        coords = coords.reshape(1, H * W, 2)
+
+        # for convenience, first divide the image into local windows
+        x = x.view(B, H // window_size, window_size, W // window_size,
+                   window_size, C)
+        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, N, C)
+        mask = mask.view(1, H // window_size, window_size, W // window_size,
+                         window_size)
+        mask = mask.permute(0, 1, 3, 2, 4).reshape(1, N)
+        coords = coords.view(1, H // window_size, window_size,
+                             W // window_size, window_size, 2)
+        coords = coords.permute(0, 1, 3, 2, 4, 5).reshape(1, N, 2)
+
+        # mask out patches
+        vis_mask = ~mask  # ~mask means visible
+        x_vis = x[vis_mask.expand(B, -1)].reshape(B, -1, C)
+        coords = coords[vis_mask].reshape(1, -1, 2)  # 1 N_vis 2
+
+        # transformer forward
+        for layer in self.layers:
+            x_vis, coords, vis_mask = layer(x_vis, coords, vis_mask)
+        x_vis = self.norm(x_vis)
+        return x_vis
+
     def forward(
         self,
         x: torch.Tensor,
@@ -1007,6 +958,6 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # generate random mask: B x Token^2, ids_restore
         mask, ids_restore = self.random_masking(x, mask_ratio)
-        latent = self.encoder(x, mask.bool())
+        latent = self.forward_features(x, mask.bool())
         return latent, mask, ids_restore

From 8a2c0381bf7fe324d079a3f4b86fcc378e41b901 Mon Sep 17 00:00:00 2001
From: xfguo <2601882982@qq.com>
Date: Sun, 26 Feb 2023 00:25:58 +0800
Subject: [PATCH 16/17] rm random seed

---
 projects/greenmim/models/greenmim_backbone.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py
index b740f797f..78007d14c 100644
--- a/projects/greenmim/models/greenmim_backbone.py
+++ b/projects/greenmim/models/greenmim_backbone.py
@@ -874,7 +874,6 @@ def random_masking(
         """
         N, L = 1, self.num_patches  # batch, length, dim
         len_keep = int(L * (1 - mask_ratio))
-        torch.manual_seed(0)  # insert this code at line 977
         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
         # sort noise for each sample
         ids_shuffle = torch.argsort(

From aae08f212e358acd25f1312b3d8bb6ef4c3b465b Mon Sep 17 00:00:00 2001
From: wangbo-zhao
Date: Sun, 26 Feb 2023 13:08:12 +0000
Subject: [PATCH 17/17] [Refine] refine backbone

---
 projects/greenmim/README.md                   | 55 ------------------
 projects/greenmim/models/greenmim_backbone.py | 23 +++-----
 projects/greenmim/models/greenmim_neck.py     | 18 +++---
 3 files changed, 19 insertions(+), 77 deletions(-)

diff --git a/projects/greenmim/README.md b/projects/greenmim/README.md
index e06c20376..fa03dbeef 100644
--- a/projects/greenmim/README.md
+++ b/projects/greenmim/README.md
@@ -112,58 +112,3 @@ Note:
   year={2022}
 }
 ```
-
-## Checklist
-
-Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress.
-
-
-
-- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
-
-  - [x] Finish the code
-
-
-
-  - [x] Basic docstrings & proper citation
-
-
-
-  - [x] Inference correctness
-
-
-
-  - [x] A full README
-
-
-- [x] Milestone 2: Indicates a successful model implementation.
-  - [x] Training-time correctness
-
-
-- [ ] Milestone 3: Good to be a part of our core package!
-  - [ ] Type hints and docstrings
-
-
-
-  - [ ] Unit tests
-
-
-
-  - [ ] Code polishing
-
-
-
-  - [ ] `metafile.yml` and `README.md`
-
-
-- [ ] Refactor and Move your modules into the core package following the codebase's file hierarchy structure.
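Note on the masking logic touched by PATCH 15/16: `random_masking` follows the MAE recipe — draw uniform noise per patch, argsort it, keep the `len_keep` lowest entries, and invert the shuffle with a second argsort so the binary mask lines up with the original patch order. GreenMIM draws a single mask (`N = 1`) and shares it across the whole batch. Below is a minimal self-contained sketch of that recipe; the function name and the final assert are illustrative and not part of the patch.

```python
import torch


def random_masking_sketch(num_patches: int, mask_ratio: float = 0.75):
    """Return (mask, ids_restore); 0 in ``mask`` means keep, 1 means drop."""
    N, L = 1, num_patches  # one shared mask for the whole batch
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L)  # noise in [0, 1]
    ids_shuffle = torch.argsort(noise, dim=1)  # ascending: small noise kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation

    # build the mask in shuffled order, then unshuffle it so entries
    # refer to the original patch positions
    mask = torch.ones(N, L)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return mask, ids_restore


mask, ids_restore = random_masking_sketch(num_patches=36)
assert int(mask.sum()) == 36 - int(36 * 0.25)  # 27 of 36 patches masked
```

This also suggests why PATCH 16 drops the `torch.manual_seed(0)` call: the translated comment reads like a debugging aid, and seeding inside `random_masking` would make every call reproduce the same noise, hence the same mask at every training step.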
diff --git a/projects/greenmim/models/greenmim_backbone.py b/projects/greenmim/models/greenmim_backbone.py
index 78007d14c..1db701948 100644
--- a/projects/greenmim/models/greenmim_backbone.py
+++ b/projects/greenmim/models/greenmim_backbone.py
@@ -730,8 +730,6 @@ class GreenMIMSwinTransformer(BaseBackbone):
     """

     def __init__(self,
-                 arch: str = 'B',
-                 stage_cfgs: dict = None,
                  img_size: int = 224,
                  patch_size: int = 4,
                  in_chans: int = 3,
@@ -740,9 +738,6 @@ def __init__(self,
                  num_heads: list = [3, 6, 12, 24],
                  window_size: int = 7,
                  mlp_ratio: float = 4.,
-                 decoder_embed_dim: int = 512,
-                 decoder_depth: int = 8,
-                 decoder_num_heads: int = 16,
                  norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
                  norm_pix_loss: bool = False,
                  qkv_bias: bool = True,
@@ -753,8 +748,7 @@ def __init__(self,
                  drop_rate: float = 0.,
                  attn_drop_rate: float = 0.,
                  use_checkpoint: bool = False,
-                 init_cfg: Optional[Union[List[dict], dict]] = None,
-                 **kwargs) -> None:
+                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
         super().__init__(init_cfg=init_cfg)

         # SwinTransformer specifics
@@ -829,15 +823,16 @@ def no_weight_decay_keywords(self):
         return {'relative_position_bias_table'}

     def init_weights(self):
-        # initialization
-        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
-        w = self.patch_embed.proj.weight.data
-        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-
-        # initialize nn.Linear and nn.LayerNorm
-        self.apply(self._init_weights)
         super().init_weights()

+        if not (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+
+            w = self.patch_embed.proj.weight.data
+            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+            self.apply(self._init_weights)
+
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             # we use xavier_uniform following official JAX ViT:
diff --git a/projects/greenmim/models/greenmim_neck.py b/projects/greenmim/models/greenmim_neck.py
index 9e1537818..a1b94bca9 100644
--- a/projects/greenmim/models/greenmim_neck.py
+++ b/projects/greenmim/models/greenmim_neck.py
@@ -18,7 +18,6 @@ class GreenMIMNeck(BaseModule):
     def __init__(
         self,
         in_channels: int,
-        encoder_stride: int,
        img_size: int,
         patch_size: int,
         embed_dim: int = 96,
@@ -66,14 +65,17 @@ def init_weights(self):
         """Initialize position embedding, patch embedding."""
         super().init_weights()

-        decoder_pos_embed = get_2d_sincos_pos_embed(
-            self.decoder_pos_embed.shape[-1],
-            int(self.num_patches**.5),
-            cls_token=False)
-        self.decoder_pos_embed.data.copy_(
-            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+        if not (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):

-        torch.nn.init.normal_(self.mask_token, std=.02)
+            decoder_pos_embed = get_2d_sincos_pos_embed(
+                self.decoder_pos_embed.shape[-1],
+                int(self.num_patches**.5),
+                cls_token=False)
+            self.decoder_pos_embed.data.copy_(
+                torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+
+            torch.nn.init.normal_(self.mask_token, std=.02)

     def forward(self, x: torch.Tensor,
                 ids_restore: torch.Tensor) -> torch.Tensor:
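The recurring pattern in PATCH 17 — wrapping the custom weight initialization in a check on `init_cfg` — keeps `init_weights` from re-randomizing parameters that `super().init_weights()` has just loaded from a pretrained checkpoint. A minimal sketch of the idea outside mmengine follows; the class and its single layer are illustrative stand-ins, since in the real code mmengine's `BaseModule` performs the checkpoint loading.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Toy stand-in for the init logic in GreenMIMSwinTransformer."""

    def __init__(self, init_cfg: dict = None):
        super().__init__()
        self.init_cfg = init_cfg
        self.proj = nn.Linear(8, 8)

    def init_weights(self):
        # In mmengine, super().init_weights() would already have loaded the
        # checkpoint when init_cfg is dict(type='Pretrained', checkpoint=...).
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Only reached when training from scratch: safe to randomize.
            torch.nn.init.xavier_uniform_(self.proj.weight)
            nn.init.constant_(self.proj.bias, 0)


TinyBackbone().init_weights()  # from scratch: xavier init runs
TinyBackbone(dict(type='Pretrained', checkpoint='...')).init_weights()  # skipped
```

Without the guard, the xavier/sincos/normal initializations in the backbone and neck would silently overwrite freshly loaded checkpoint weights, which is exactly the failure mode this refactor avoids.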