Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions dinov2/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger
from dinov2.utils.config import setup
from dinov2.utils.utils import CosineScheduler
from dinov2.utils.utils import MemEfficientCosineScheduler

from dinov2.train.ssl_meta_arch import SSLMetaArch

Expand Down Expand Up @@ -89,15 +89,14 @@ def build_schedulers(cfg):
start_warmup_value=cfg.teacher["warmup_teacher_temp"],
)

lr_schedule = CosineScheduler(**lr)
wd_schedule = CosineScheduler(**wd)
momentum_schedule = CosineScheduler(**momentum)
teacher_temp_schedule = CosineScheduler(**teacher_temp)
last_layer_lr_schedule = CosineScheduler(**lr)

last_layer_lr_schedule.schedule[
: cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
] = 0 # mimicking the original schedules
lr_schedule = MemEfficientCosineScheduler(**lr)
wd_schedule = MemEfficientCosineScheduler(**wd)
momentum_schedule = MemEfficientCosineScheduler(**momentum)
teacher_temp_schedule = MemEfficientCosineScheduler(**teacher_temp)
# this is a hack to mimic the original schedules
_lr = lr.copy()
_lr.update(freeze_iters=cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH)
last_layer_lr_schedule = MemEfficientCosineScheduler(**_lr)

logger.info("Schedulers ready.")

Expand Down
31 changes: 31 additions & 0 deletions dinov2/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,38 @@ def __getitem__(self, it):
return self.final_value
else:
return self.schedule[it]


class MemEfficientCosineScheduler:
    """Cosine annealing schedule with linear warmup and an optional frozen prefix.

    Values are computed on demand in ``__getitem__`` instead of being
    precomputed into an array, so memory stays O(1) regardless of
    ``total_iters``.

    Piecewise behavior for iteration ``it``:
      * ``it >= total_iters``                          -> ``final_value``
      * ``it <  freeze_iters``                         -> ``0.0`` (frozen)
      * ``freeze_iters <= it < freeze + warmup``       -> linear ramp from
        ``start_warmup_value`` toward ``base_value``
      * otherwise                                      -> cosine decay from
        ``base_value`` to ``final_value`` over the remaining iterations
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        # Peak value reached after warmup, and the value annealed to at the end.
        self.base_value = base_value
        self.final_value = final_value
        self.total_iters = total_iters
        # Linear warmup length and its starting value.
        self.warmup_iters = warmup_iters
        self.start_warmup_value = start_warmup_value
        # Iterations at the start where the schedule is forced to zero.
        self.freeze_iters = freeze_iters

    def __getitem__(self, it):
        # Past the schedule end: clamp to the final value.
        if it >= self.total_iters:
            return self.final_value

        # Frozen prefix (e.g. last-layer freeze): force the value to zero.
        if it < self.freeze_iters:
            return 0.0

        warmup_end = self.freeze_iters + self.warmup_iters
        if it < warmup_end:
            # Linear warmup; max(1, ...) guards against a zero-length warmup.
            frac = (it - self.freeze_iters) / max(1, self.warmup_iters)
            return self.start_warmup_value + frac * (self.base_value - self.start_warmup_value)

        # Cosine annealing over the iterations remaining after freeze + warmup.
        step = it - warmup_end
        span = self.total_iters - self.warmup_iters - self.freeze_iters
        cos_term = 1 + np.cos(np.pi * step / span)
        return self.final_value + 0.5 * (self.base_value - self.final_value) * cos_term


def has_batchnorms(model):
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
Expand Down