17 changes: 2 additions & 15 deletions dinov2/fsdp/__init__.py
@@ -8,18 +8,16 @@

import torch
import dinov2.distributed as distributed
from functools import partial
from fvcore.common.checkpoint import Checkpointer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp._runtime_utils import _reshard


def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
def parse_fsdp_config(model_cfg):
sharding_strategy_dict = {
"NO_SHARD": ShardingStrategy.NO_SHARD,
"SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
@@ -40,18 +38,7 @@ def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):

sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]

local_rank = distributed.get_local_rank()

fsdp_wrapper = partial(
FSDP,
sharding_strategy=sharding_strategy_config,
mixed_precision=mixed_precision_config,
device_id=local_rank,
sync_module_states=True,
use_orig_params=True,
auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
)
return fsdp_wrapper
return dict(sharding_strategy=sharding_strategy_config, mixed_precision=mixed_precision_config)


def is_fsdp(x):
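Note on the new helper: parse_fsdp_config now only builds the FSDP keyword arguments; the actual wrapping moves to torch.distributed.fsdp.wrap.wrap calls made inside an enable_wrap context (see the ssl_meta_arch.py and train.py hunks below). A minimal usage sketch, assuming an initialized process group, a model_cfg with the sharding_strategy and mixed_precision fields from the existing configs, and a module already on the local CUDA device; the helper name shard_module is illustrative:

# Minimal sketch (not part of this diff): consuming parse_fsdp_config's output.
# Assumes torch.distributed is initialized and `module` is already on the local
# CUDA device, since sync_module_states=True needs GPU-resident parameters.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

from dinov2.fsdp import parse_fsdp_config


def shard_module(module: nn.Module, model_cfg) -> nn.Module:
    # wrap() is a no-op outside an enable_wrap() context; inside it, the module
    # is wrapped with FSDP using the merged keyword arguments.
    with enable_wrap(wrapper_cls=FSDP, sync_module_states=True, use_orig_params=True):
        return wrap(module, **parse_fsdp_config(model_cfg))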
9 changes: 1 addition & 8 deletions dinov2/models/vision_transformer.py
@@ -34,13 +34,6 @@ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, incl
return module


class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x


class DinoVisionTransformer(nn.Module):
def __init__(
self,
@@ -144,7 +137,7 @@ def f(*args, **kwargs):
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
self.blocks = nn.ModuleList([nn.Sequential(*p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
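For reference, the chunking scheme this hunk touches pads each chunk with nn.Identity() so that block i keeps index i inside its chunk, which keeps parameter names stable whether or not chunking is enabled; nn.Sequential then provides the same sequential forward that the removed BlockChunk implemented by hand. An illustrative, self-contained sketch (depth, chunksize and the nn.Linear stand-ins are toy assumptions, not the real model):

# Illustrative sketch of the block-chunking layout (toy modules, not model code).
import torch.nn as nn

depth, chunksize = 12, 4  # toy values
blocks_list = [nn.Linear(8, 8) for _ in range(depth)]  # stand-ins for transformer blocks

chunked_blocks = []
for i in range(0, depth, chunksize):
    # pad with nn.Identity() so block i sits at index i inside its chunk
    chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])

# nn.Sequential takes modules as positional arguments, hence the unpacking.
blocks = nn.ModuleList([nn.Sequential(*p) for p in chunked_blocks])

# Chunk 1 holds blocks 4..7 at their original indices; the prefix is identities.
assert blocks[1][4] is blocks_list[4]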
30 changes: 15 additions & 15 deletions dinov2/train/ssl_meta_arch.py
@@ -14,10 +14,9 @@
from dinov2.layers import DINOHead
from dinov2.utils.utils import has_batchnorms
from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups
from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model

from dinov2.models.vision_transformer import BlockChunk
from dinov2.fsdp import parse_fsdp_config, ShardedGradScaler, reshard_fsdp_model

from torch.distributed.fsdp.wrap import wrap

try:
from xformers.ops import fmha
@@ -120,6 +119,10 @@ def __init__(self, cfg):
p.requires_grad = False
logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.")

for k, v in self.student.items():
logger.info(f"Syncing student and teacher submodule: {k}")
self.teacher[k].load_state_dict(v.state_dict())

def forward(self, inputs):
raise NotImplementedError

@@ -353,16 +356,15 @@ def fsdp_synchronize_streams(self):
) = self.student.backbone._streams = self.teacher.backbone._streams
self.need_to_synchronize_fsdp_streams = False

def update_teacher(self, m):
@torch.no_grad()
def update_teacher(self, m: float) -> None:
student_param_list = []
teacher_param_list = []
with torch.no_grad():
for k in self.student.keys():
for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])):
student_param_list += ms.params
teacher_param_list += mt.params
torch._foreach_mul_(teacher_param_list, m)
torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m)
for k in self.student.keys():
student_param_list.extend(self.student[k].parameters())
teacher_param_list.extend(self.teacher[k].parameters())
torch._foreach_mul_(teacher_param_list, m)
torch._foreach_add_(teacher_param_list, student_param_list, alpha=1.0 - m)

def train(self):
super().train()
@@ -391,10 +393,8 @@ def prepare_for_distributed_training(self):
logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
if has_batchnorms(self.student):
raise NotImplementedError
# below will synchronize all student subnetworks across gpus:
for k, v in self.student.items():
self.teacher[k].load_state_dict(self.student[k].state_dict())
student_model_cfg = self.cfg.compute_precision.student[k]
self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
self.student[k] = wrap(v, **parse_fsdp_config(student_model_cfg))
teacher_model_cfg = self.cfg.compute_precision.teacher[k]
self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
self.student[k] = wrap(self.teacher[k], **parse_fsdp_config(teacher_model_cfg))

Review comment: this line should assign to the teacher, i.e. self.teacher[k] = ....
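Side note on the update_teacher change above: instead of walking FSDP flat-parameter shards via get_fsdp_modules, it now collects the original parameters directly and applies the fused torch._foreach_* ops, teacher <- m * teacher + (1 - m) * student. With use_orig_params=True each rank should see matching shard-aligned views for student and teacher, so the elementwise update stays local. A standalone sketch of the pattern with toy nn.ModuleDict pairs standing in for the FSDP-wrapped backbones and heads:

# Standalone sketch of the EMA update pattern used in update_teacher()
# (toy modules; the real code iterates FSDP-wrapped submodules).
import torch
import torch.nn as nn

student = nn.ModuleDict({"backbone": nn.Linear(4, 4), "dino_head": nn.Linear(4, 2)})
teacher = nn.ModuleDict({"backbone": nn.Linear(4, 4), "dino_head": nn.Linear(4, 2)})
teacher.load_state_dict(student.state_dict())  # start from identical weights


@torch.no_grad()
def update_teacher(m: float) -> None:
    student_params, teacher_params = [], []
    for k in student.keys():
        student_params.extend(student[k].parameters())
        teacher_params.extend(teacher[k].parameters())
    # teacher <- m * teacher + (1 - m) * student, fused over all tensors at once
    torch._foreach_mul_(teacher_params, m)
    torch._foreach_add_(teacher_params, student_params, alpha=1.0 - m)


update_teacher(0.999)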

11 changes: 10 additions & 1 deletion dinov2/train/train.py
@@ -19,6 +19,9 @@
from dinov2.logging import MetricLogger
from dinov2.utils.config import setup
from dinov2.utils.utils import CosineScheduler
from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy
from torch.distributed.fsdp import FullyShardedDataParallel
from dinov2.layers.block import Block

from dinov2.train.ssl_meta_arch import SSLMetaArch

@@ -298,7 +301,13 @@ def main(args):
cfg = setup(args)

model = SSLMetaArch(cfg).to(torch.device("cuda"))
model.prepare_for_distributed_training()
with enable_wrap(
wrapper_cls=FullyShardedDataParallel,
auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}),
sync_module_states=True,
use_orig_params=True,
):
model.prepare_for_distributed_training()

logger.info("Model:\n{}".format(model))
if args.eval_only:
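On the enable_wrap block added above: transformer_auto_wrap_policy is essentially a predicate over submodules. FSDP's auto-wrapper queries it while traversing the model, and a True return with recurse=False means that module becomes its own FSDP unit, so each dinov2.layers.block.Block ends up individually sharded inside the outer wrap. A small sketch probing the policy directly; ToyBlock and the positional call are illustrative assumptions, not code from this PR:

# Sketch: transformer_auto_wrap_policy as a plain predicate over modules.
# recurse=True means "keep traversing"; recurse=False asks for the final
# wrap/no-wrap decision for this module.
from functools import partial

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):  # hypothetical stand-in for dinov2.layers.block.Block
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(8, 8)


policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={ToyBlock})

# A ToyBlock should become its own FSDP unit; a bare Linear should not.
assert policy(ToyBlock(), False, 0)
assert not policy(nn.Linear(8, 8), False, 0)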