Skip to content

Commit 8f35785

Browse files
feat: adain norm, mod conv for unet3d/vsharp3d
1 parent 2169d7f commit 8f35785

File tree

9 files changed

+1802
-79
lines changed

9 files changed

+1802
-79
lines changed

direct/nn/adain/__init__.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from enum import Enum
2+
3+
import torch
4+
from torch import nn
5+
6+
__all__ = ["AdaIN2d", "AdaIN3d"]
7+
8+
9+
class NormType(str, Enum):
    """Normalization flavors selectable by models in this module.

    NOTE(review): not listed in ``__all__``; presumably consumed by the
    unet3d/vsharp3d builders mentioned in the commit — TODO confirm.
    """

    # Plain instance normalization (no conditioning input).
    INSTANCE = "instance"
    # Adaptive instance norm: affine params predicted from an auxiliary vector.
    ADAIN = "adain"
12+
13+
14+
# NOTE(review): removed duplicate `import torch` / `from torch import nn`
# statements here — both are already imported at the top of this module.
16+
17+
18+
class AdaIN2d(nn.Module):
19+
"""
20+
Adaptive Instance Normalization for 2D tensors:
21+
x: (B, C, H, W)
22+
y: (B, F) auxiliary vector
23+
Produces per-sample, per-channel affine params from y.
24+
"""
25+
26+
def __init__(
27+
self,
28+
num_channels: int,
29+
aux_in_features: int,
30+
hidden_features: int | tuple[int, ...] | None = None,
31+
activation: nn.Module | None = None,
32+
eps: float = 1e-5,
33+
use_one_plus_gamma: bool = True,
34+
):
35+
super().__init__()
36+
self.num_channels = num_channels
37+
self.eps = eps
38+
self.use_one_plus_gamma = use_one_plus_gamma
39+
40+
if activation is None:
41+
activation = nn.SiLU()
42+
43+
# Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
44+
if hidden_features is None:
45+
hidden = []
46+
elif isinstance(hidden_features, int):
47+
hidden = [hidden_features]
48+
else:
49+
hidden = list(hidden_features)
50+
51+
layers: list[nn.Module] = []
52+
in_f = aux_in_features
53+
for h in hidden:
54+
layers += [nn.Linear(in_f, h), activation]
55+
in_f = h
56+
layers += [nn.Linear(in_f, 2 * num_channels)]
57+
self.mlp = nn.Sequential(*layers)
58+
59+
# Initialize last layer near-zero so AdaIN starts close to plain IN
60+
if isinstance(self.mlp[-1], nn.Linear):
61+
nn.init.zeros_(self.mlp[-1].weight)
62+
nn.init.zeros_(self.mlp[-1].bias)
63+
64+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
65+
# Instance-style normalization over spatial dims (H,W), per (B,C)
66+
mean = x.mean(dim=(2, 3), keepdim=True)
67+
var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
68+
x_norm = (x - mean) / torch.sqrt(var + self.eps)
69+
70+
# Produce gamma/beta from y
71+
params = self.mlp(y) # (B, 2C)
72+
gamma, beta = params.chunk(2, 1) # each (B, C)
73+
74+
gamma = gamma.view(-1, self.num_channels, 1, 1)
75+
beta = beta.view(-1, self.num_channels, 1, 1)
76+
77+
if self.use_one_plus_gamma:
78+
return x_norm * (1.0 + gamma) + beta
79+
return x_norm * gamma + beta
80+
81+
82+
class AdaIN3d(nn.Module):
83+
"""
84+
Adaptive Instance Normalization for 3D tensors:
85+
x: (B, C, Z, H, W)
86+
y: (B, F) auxiliary vector
87+
Produces per-sample, per-channel affine params from y.
88+
"""
89+
90+
def __init__(
91+
self,
92+
num_channels: int,
93+
aux_in_features: int,
94+
hidden_features: int | tuple[int, ...] | None = None,
95+
activation: nn.Module | None = None,
96+
eps: float = 1e-5,
97+
use_one_plus_gamma: bool = True,
98+
):
99+
super().__init__()
100+
self.num_channels = num_channels
101+
self.eps = eps
102+
self.use_one_plus_gamma = use_one_plus_gamma
103+
104+
if activation is None:
105+
activation = nn.SiLU()
106+
107+
# Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
108+
if hidden_features is None:
109+
hidden = []
110+
elif isinstance(hidden_features, int):
111+
hidden = [hidden_features]
112+
else:
113+
hidden = list(hidden_features)
114+
115+
layers: list[nn.Module] = []
116+
in_f = aux_in_features
117+
for h in hidden:
118+
layers += [nn.Linear(in_f, h), activation]
119+
in_f = h
120+
layers += [nn.Linear(in_f, 2 * num_channels)]
121+
self.mlp = nn.Sequential(*layers)
122+
123+
# Optional: initialize last layer to near-zero so AdaIN starts close to plain IN
124+
if isinstance(self.mlp[-1], nn.Linear):
125+
nn.init.zeros_(self.mlp[-1].weight)
126+
nn.init.zeros_(self.mlp[-1].bias)
127+
128+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
129+
130+
# Instance-style normalization over spatial dims (Z,H,W), per (B,C)
131+
mean = x.mean(dim=(2, 3, 4), keepdim=True)
132+
var = x.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
133+
x_norm = (x - mean) / torch.sqrt(var + self.eps)
134+
135+
# Produce gamma/beta from y
136+
params = self.mlp(y) # (B, 2C)
137+
gamma, beta = params.chunk(2, dim=-1) # each (B, C)
138+
139+
gamma = gamma.view(-1, self.num_channels, 1, 1, 1)
140+
beta = beta.view(-1, self.num_channels, 1, 1, 1)
141+
142+
if self.use_one_plus_gamma:
143+
return x_norm * (1.0 + gamma) + beta
144+
return x_norm * gamma + beta

direct/nn/adain/adain.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from enum import Enum
2+
3+
import torch
4+
from torch import nn
5+
6+
__all__ = ["AdaIN2d", "AdaIN3d"]
7+
8+
9+
class NormType(str, Enum):
    """Normalization flavors selectable by models in this module.

    NOTE(review): not listed in ``__all__``; presumably consumed by the
    unet3d/vsharp3d builders mentioned in the commit — TODO confirm.
    """

    # Plain instance normalization (no conditioning input).
    INSTANCE = "instance"
    # Adaptive instance norm: affine params predicted from an auxiliary vector.
    ADAIN = "adain"
12+
13+
14+
# NOTE(review): removed duplicate `import torch` / `from torch import nn`
# statements here — both are already imported at the top of this module.
16+
17+
18+
class AdaIN2d(nn.Module):
19+
"""
20+
Adaptive Instance Normalization for 2D tensors:
21+
x: (B, C, H, W)
22+
y: (B, F) auxiliary vector
23+
Produces per-sample, per-channel affine params from y.
24+
"""
25+
26+
def __init__(
27+
self,
28+
num_channels: int,
29+
aux_in_features: int,
30+
hidden_features: int | tuple[int, ...] | None = None,
31+
activation: nn.Module | None = None,
32+
eps: float = 1e-5,
33+
use_one_plus_gamma: bool = True,
34+
):
35+
super().__init__()
36+
self.num_channels = num_channels
37+
self.eps = eps
38+
self.use_one_plus_gamma = use_one_plus_gamma
39+
40+
if activation is None:
41+
activation = nn.SiLU()
42+
43+
# Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
44+
if hidden_features is None:
45+
hidden = []
46+
elif isinstance(hidden_features, int):
47+
hidden = [hidden_features]
48+
else:
49+
hidden = list(hidden_features)
50+
51+
layers: list[nn.Module] = []
52+
in_f = aux_in_features
53+
for h in hidden:
54+
layers += [nn.Linear(in_f, h), activation]
55+
in_f = h
56+
layers += [nn.Linear(in_f, 2 * num_channels)]
57+
self.mlp = nn.Sequential(*layers)
58+
59+
# Initialize last layer near-zero so AdaIN starts close to plain IN
60+
if isinstance(self.mlp[-1], nn.Linear):
61+
nn.init.zeros_(self.mlp[-1].weight)
62+
nn.init.zeros_(self.mlp[-1].bias)
63+
64+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
65+
# Instance-style normalization over spatial dims (H,W), per (B,C)
66+
mean = x.mean(dim=(2, 3), keepdim=True)
67+
var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
68+
x_norm = (x - mean) / torch.sqrt(var + self.eps)
69+
70+
# Produce gamma/beta from y
71+
params = self.mlp(y) # (B, 2C)
72+
gamma, beta = params.chunk(2, 1) # each (B, C)
73+
74+
gamma = gamma.view(-1, self.num_channels, 1, 1)
75+
beta = beta.view(-1, self.num_channels, 1, 1)
76+
77+
if self.use_one_plus_gamma:
78+
return x_norm * (1.0 + gamma) + beta
79+
return x_norm * gamma + beta
80+
81+
82+
class AdaIN3d(nn.Module):
83+
"""
84+
Adaptive Instance Normalization for 3D tensors:
85+
x: (B, C, Z, H, W)
86+
y: (B, F) auxiliary vector
87+
Produces per-sample, per-channel affine params from y.
88+
"""
89+
90+
def __init__(
91+
self,
92+
num_channels: int,
93+
aux_in_features: int,
94+
hidden_features: int | tuple[int, ...] | None = None,
95+
activation: nn.Module | None = None,
96+
eps: float = 1e-5,
97+
use_one_plus_gamma: bool = True,
98+
):
99+
super().__init__()
100+
self.num_channels = num_channels
101+
self.eps = eps
102+
self.use_one_plus_gamma = use_one_plus_gamma
103+
104+
if activation is None:
105+
activation = nn.SiLU()
106+
107+
# Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
108+
if hidden_features is None:
109+
hidden = []
110+
elif isinstance(hidden_features, int):
111+
hidden = [hidden_features]
112+
else:
113+
hidden = list(hidden_features)
114+
115+
layers: list[nn.Module] = []
116+
in_f = aux_in_features
117+
for h in hidden:
118+
layers += [nn.Linear(in_f, h), activation]
119+
in_f = h
120+
layers += [nn.Linear(in_f, 2 * num_channels)]
121+
self.mlp = nn.Sequential(*layers)
122+
123+
# Optional: initialize last layer to near-zero so AdaIN starts close to plain IN
124+
if isinstance(self.mlp[-1], nn.Linear):
125+
nn.init.zeros_(self.mlp[-1].weight)
126+
nn.init.zeros_(self.mlp[-1].bias)
127+
128+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
129+
130+
# Instance-style normalization over spatial dims (Z,H,W), per (B,C)
131+
mean = x.mean(dim=(2, 3, 4), keepdim=True)
132+
var = x.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
133+
x_norm = (x - mean) / torch.sqrt(var + self.eps)
134+
135+
# Produce gamma/beta from y
136+
params = self.mlp(y) # (B, 2C)
137+
gamma, beta = params.chunk(2, dim=-1) # each (B, C)
138+
139+
gamma = gamma.view(-1, self.num_channels, 1, 1, 1)
140+
beta = beta.view(-1, self.num_channels, 1, 1, 1)
141+
142+
if self.use_one_plus_gamma:
143+
return x_norm * (1.0 + gamma) + beta
144+
return x_norm * gamma + beta

direct/nn/conv/conv.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
"""direct.nn.conv module."""
44

5-
65
from typing import List
76

87
import torch

0 commit comments

Comments
 (0)