From 316a8e8f47d98cd4ae06864fcf6a25d329120404 Mon Sep 17 00:00:00 2001
From: Duc-Viet Hoang
Date: Sat, 19 Jul 2025 12:48:12 +0900
Subject: [PATCH 1/2] DaVIT Channel Attention to use fused attention

---
 timm/models/davit.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 22b4a1a05f..690f9eb226 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -136,18 +136,23 @@ def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
     def forward(self, x):
         B, N, C = x.shape
 
-        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 4, 1)
         q, k, v = qkv.unbind(0)
 
         if self.dynamic_scale:
-            q = q * N ** -0.5
+            scale = N ** -0.5
         else:
-            q = q * self.head_dim ** -0.5
-        attn = q.transpose(-1, -2) @ k
-        attn = attn.softmax(dim=-1)
-        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+            scale = self.head_dim ** -0.5
 
-        x = x.transpose(1, 2).reshape(B, N, C)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v, scale=scale)
+        else:
+            q = q * self.scale
+            attn = (q @ k.transpose(-2, -1))
+            attn = self.softmax(attn)
+            x = attn @ v
+
+        x = x.permute(0, 3, 2, 1).reshape(B, N, C)
         x = self.proj(x)
         return x
 

From f9e9cc9e1ee3f3e92657ad604a4b3adde5a5d985 Mon Sep 17 00:00:00 2001
From: Duc-Viet Hoang
Date: Sun, 20 Jul 2025 08:23:51 +0700
Subject: [PATCH 2/2] Update davit.py

---
 timm/models/davit.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 690f9eb226..347750652e 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -129,6 +129,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
         self.groups = num_heads
         self.head_dim = dim // num_heads
         self.dynamic_scale = dynamic_scale
+        self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
@@ -147,12 +148,12 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(q, k, v, scale=scale)
         else:
-            q = q * self.scale
+            q = q * scale
             attn = (q @ k.transpose(-2, -1))
-            attn = self.softmax(attn)
+            attn = attn.softmax(dim=-1)
             x = attn @ v
 
-        x = x.permute(0, 3, 2, 1).reshape(B, N, C)
+        x = x.permute(0, 3, 1, 2).reshape(B, N, C)
         x = self.proj(x)
         return x
 
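For reference, below is a minimal self-contained sketch (not the timm source) of what the channel attention module looks like once both patches are applied. The class name ChannelAttentionSketch, the hasattr(F, 'scaled_dot_product_attention') check (standing in for timm's use_fused_attn() helper), and the smoke test at the bottom are illustrative assumptions, not part of the patched file.

# Sketch only: channel attention with an optional fused path, mirroring the
# shape handling in the patches above. Assumes PyTorch >= 2.1 for the `scale`
# keyword of scaled_dot_product_attention.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelAttentionSketch(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
        super().__init__()
        self.groups = num_heads
        self.head_dim = dim // num_heads
        self.dynamic_scale = dynamic_scale
        # Stand-in for timm's use_fused_attn(); just checks the op exists.
        self.fused_attn = hasattr(F, 'scaled_dot_product_attention')
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, groups, head_dim, N): channel attention attends
        # over the channel (head_dim) axis, so the token axis N sits last.
        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 4, 1)
        q, k, v = qkv.unbind(0)
        scale = N ** -0.5 if self.dynamic_scale else self.head_dim ** -0.5
        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, scale=scale)
        else:
            q = q * scale
            attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
            x = attn @ v
        # (B, groups, head_dim, N) -> (B, N, groups, head_dim) -> (B, N, C)
        x = x.permute(0, 3, 1, 2).reshape(B, N, C)
        return self.proj(x)


if __name__ == '__main__':
    m = ChannelAttentionSketch(dim=64, num_heads=8)
    y = m(torch.randn(2, 49, 64))
    print(y.shape)  # torch.Size([2, 49, 64])

Because the fused and manual branches use the same (B, groups, head_dim, N) layout and the same scale, both paths produce matching outputs up to floating-point tolerance, which is what makes the single permute back to (B, N, C) at the end valid for either branch.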