
Commit f407c85

add vit/dino implementation (no xformers). implement factory class for generating dinov2. update anomaly_dino to use factory method

1 parent 3e3080c

13 files changed, +1795 -5 lines changed
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Anomalib's Vision Transformer implementation.

References:
    https://github.com/facebookresearch/dinov2/blob/main/dinov2/
"""

# loader
from .dinov2_loader import DinoV2Loader

# vision transformer
from .vision_transformer import (
    DinoVisionTransformer,
    vit_base,
    vit_giant2,
    vit_large,
    vit_small,
)

__all__ = [
    # vision transformer
    "DinoVisionTransformer",
    "vit_base",
    "vit_giant2",
    "vit_large",
    "vit_small",
    # loader
    "DinoV2Loader",
]
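
As a usage sketch of these exports (an assumption for illustration: this `__init__` lives at `anomalib.models.components.dinov2`, consistent with the loader's own imports below, and the kwargs mirror exactly what `DinoV2Loader._create_model` passes):

from anomalib.models.components.dinov2 import vit_small

# Build a ViT-S/14 backbone with the same settings the factory uses.
model = vit_small(
    patch_size=14,
    img_size=518,
    block_chunks=0,
    init_values=1e-8,
    interpolate_antialias=False,
    interpolate_offset=0.1,
)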
Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Loader for DINOv2 Vision Transformer models.

This module provides a simple interface for loading pre-trained DINOv2 Vision Transformer models for the
Dinomaly anomaly detection framework.

Example:
    model = DinoV2Loader.from_name("dinov2_vit_base_14")
    model = DinoV2Loader.from_name("dinomaly_vit_base_14")
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import ClassVar
from urllib.request import urlretrieve

import torch

from anomalib.data.utils import DownloadInfo
from anomalib.data.utils.download import DownloadProgressBar
from anomalib.models.components.dinov2 import vision_transformer as dinov2_models
from anomalib.models.image.dinomaly.components import vision_transformer as dinomaly_models

logger = logging.getLogger(__name__)

MODEL_FACTORIES: dict[str, object] = {
    "dinov2": dinov2_models,
    "dinov2_reg": dinov2_models,
    "dinomaly": dinomaly_models,
}


class DinoV2Loader:
    """Simple loader for DINOv2 Vision Transformer models.

    Supports loading dinov2, dinov2_reg, and dinomaly model variants across small, base,
    and large architectures.
    """

    DINOV2_BASE_URL: ClassVar[str] = "https://dl.fbaipublicfiles.com/dinov2"

    MODEL_CONFIGS: ClassVar[dict[str, dict[str, int]]] = {
        "small": {"embed_dim": 384, "num_heads": 6},
        "base": {"embed_dim": 768, "num_heads": 12},
        "large": {"embed_dim": 1024, "num_heads": 16},
    }

    def __init__(self, cache_dir: str | Path = "./pre_trained/") -> None:
        """Initialize a model loader instance.

        Args:
            cache_dir: Directory in which downloaded weights will be stored.
        """
        self.cache_dir: Path = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def load(self, model_name: str) -> torch.nn.Module:
        """Load a DINOv2 model by name.

        Args:
            model_name: Model identifier such as "dinov2_vit_base_14".

        Returns:
            A fully constructed and weight-loaded PyTorch module.

        Raises:
            ValueError: If the requested model name is malformed or unsupported.
        """
        model_type, architecture, patch_size = self._parse_name(model_name)
        model = self._create_model(model_type, architecture, patch_size)
        self._load_weights(model, model_type, architecture, patch_size)

        logger.info(f"Loaded model: {model_name}")
        return model

    @classmethod
    def from_name(
        cls,
        model_name: str,
        cache_dir: str | Path = "./pre_trained/",
    ) -> torch.nn.Module:
        """Instantiate a loader and return the requested model."""
        loader = cls(cache_dir)
        return loader.load(model_name)

    def _parse_name(self, name: str) -> tuple[str, str, int]:
        """Parse a model name string into components.

        Args:
            name: Full model name string.

        Returns:
            Tuple of (model_type, architecture_name, patch_size).

        Raises:
            ValueError: If the prefix or architecture is unknown.
        """
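        # Name format: "<prefix>_vit_<architecture>_<patch_size>", e.g. "dinov2_vit_base_14"
        # or "dinov2reg_vit_base_14": prefix first, architecture second-to-last, patch size last.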
        parts = name.split("_")
        prefix = parts[0]
        architecture = parts[-2]
        patch_size = int(parts[-1])

        if prefix == "dinov2reg":
            model_type = "dinov2_reg"
        elif prefix == "dinov2":
            model_type = "dinov2"
        elif prefix == "dinomaly":
            model_type = "dinomaly"
        else:
            msg = f"Unknown model type prefix '{prefix}'."
            raise ValueError(msg)

        if architecture not in self.MODEL_CONFIGS:
            msg = f"Invalid architecture '{architecture}'. Expected one of: {list(self.MODEL_CONFIGS)}"
            raise ValueError(msg)

        return model_type, architecture, patch_size

    @staticmethod
    def _create_model(
        model_type: str,
        architecture: str,
        patch_size: int,
    ) -> torch.nn.Module:
        """Construct a model instance using the configured factory modules.

        Args:
            model_type: Model family, e.g., "dinov2", "dinov2_reg", "dinomaly".
            architecture: Architecture label ("small", "base", "large").
            patch_size: Patch resolution.

        Returns:
            An instantiated PyTorch module.

        Raises:
            ValueError: If the relevant constructor cannot be found.
        """
        model_kwargs: dict[str, object] = {
            "patch_size": patch_size,
            "img_size": 518,
            "block_chunks": 0,
            "init_values": 1e-8,
            "interpolate_antialias": False,
            "interpolate_offset": 0.1,
        }
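
        # The register variant carries 4 extra learned register tokens, matching
        # the official "reg4" DINOv2 checkpoints fetched in _download_weights.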
        if model_type == "dinov2_reg":
            model_kwargs["num_register_tokens"] = 4

        module = MODEL_FACTORIES.get(model_type)
        if module is None:
            msg = f"Unknown model type '{model_type}'."
            raise ValueError(msg)

        ctor = getattr(module, f"vit_{architecture}", None)
        if ctor is None:
            msg = f"No constructor 'vit_{architecture}' in module {module}."
            raise ValueError(msg)

        model: torch.nn.Module = ctor(**model_kwargs)
        return model

    def _load_weights(
        self,
        model: torch.nn.Module,
        model_type: str,
        architecture: str,
        patch_size: int,
    ) -> None:
        """Load pre-trained weights from disk, downloading them if necessary."""
        weight_path = self._get_weight_path(model_type, architecture, patch_size)

        if not weight_path.exists():
            self._download_weights(model_type, architecture, patch_size)

        # Using weights_only=True for safety mitigation (see Anomalib PR #2729)
        state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)  # nosec B614
        model.load_state_dict(state_dict, strict=False)

    def _get_weight_path(
        self,
        model_type: str,
        architecture: str,
        patch_size: int,
    ) -> Path:
        """Return the expected local path for downloaded weights."""
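        # Architecture maps to its first letter ("small" -> "s", "base" -> "b", "large" -> "l"),
        # matching the published DINOv2 file names such as "dinov2_vitb14_pretrain.pth".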
        arch_code = architecture[0]

        if model_type == "dinov2_reg":
            filename = f"dinov2_vit{arch_code}{patch_size}_reg4_pretrain.pth"
        else:
            filename = f"dinov2_vit{arch_code}{patch_size}_pretrain.pth"

        return self.cache_dir / filename

    def _download_weights(
        self,
        model_type: str,
        architecture: str,
        patch_size: int,
    ) -> None:
        """Download DINOv2 weight files using Anomalib's standardized utilities."""
        weight_path = self._get_weight_path(model_type, architecture, patch_size)
        arch_code = architecture[0]

        model_dir = f"dinov2_vit{arch_code}{patch_size}"
        url = f"{self.DINOV2_BASE_URL}/{model_dir}/{weight_path.name}"

        download_info = DownloadInfo(
            name=f"DINOv2 {model_type} {architecture} weights",
            url=url,
            hashsum="",  # DINOv2 publishes no official hash
            filename=weight_path.name,
        )

        logger.info(
            f"Downloading DINOv2 weights: {weight_path.name} to {self.cache_dir}",
        )

        self.cache_dir.mkdir(parents=True, exist_ok=True)

        with DownloadProgressBar(
            unit="B",
            unit_scale=True,
            miniters=1,
            desc=download_info.name,
        ) as progress_bar:
            # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected  # noqa: ERA001, E501
            urlretrieve(  # noqa: S310  # nosec B310
                url=url,
                filename=weight_path,
                reporthook=progress_bar.update_to,
            )
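
For orientation, a minimal usage sketch of the factory entry point above (assumptions: `DinoV2Loader` is re-exported from the `anomalib.models.components.dinov2` package shown earlier; weights land in `./pre_trained/` by default and are fetched on first use):

from anomalib.models.components.dinov2 import DinoV2Loader

# Downloads dinov2_vitb14_pretrain.pth on first use, then loads it with weights_only=True.
model = DinoV2Loader.from_name("dinov2_vit_base_14", cache_dir="./pre_trained/")
model.eval()

# The result is an ordinary torch.nn.Module, so the usual PyTorch API applies.
num_params = sum(p.numel() for p in model.parameters())
print(f"ViT-B/14 parameters: {num_params / 1e6:.1f}M")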
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Layers needed to build DINOv2.

References:
    https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/__init__.py
"""

from .attention import Attention, MemEffAttention
from .block import Block, CausalAttentionBlock
from .dino_head import DINOHead
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNAligned, SwiGLUFFNFused

__all__ = [
    "Attention",
    "Block",
    "CausalAttentionBlock",
    "DINOHead",
    "DropPath",
    "LayerScale",
    "MemEffAttention",
    "Mlp",
    "PatchEmbed",
    "SwiGLUFFN",
    "SwiGLUFFNAligned",
    "SwiGLUFFNFused",
]
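
For completeness, these re-exports make the layer primitives importable directly from the subpackage (the path is an assumption inferred from the package layout implied by this commit):

from anomalib.models.components.dinov2.layers import Block, MemEffAttention, SwiGLUFFNFused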
