Skip to content

Commit 6fc2126

Browse files
feat: enable old py to work
1 parent c87087a commit 6fc2126

File tree

1 file changed

+2
-143
lines changed

1 file changed

+2
-143
lines changed

direct/nn/adain/__init__.py

Lines changed: 2 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,144 +1,3 @@
1-
from enum import Enum
1+
from direct.nn.adain.adain import AdaIN2d, AdaIN3d, NormType
22

3-
import torch
4-
from torch import nn
5-
6-
__all__ = ["AdaIN2d", "AdaIN3d"]
7-
8-
9-
class NormType(str, Enum):
10-
INSTANCE = "instance"
11-
ADAIN = "adain"
12-
13-
14-
import torch
15-
from torch import nn
16-
17-
18-
class AdaIN2d(nn.Module):
19-
"""
20-
Adaptive Instance Normalization for 2D tensors:
21-
x: (B, C, H, W)
22-
y: (B, F) auxiliary vector
23-
Produces per-sample, per-channel affine params from y.
24-
"""
25-
26-
def __init__(
27-
self,
28-
num_channels: int,
29-
aux_in_features: int,
30-
hidden_features: int | tuple[int, ...] | None = None,
31-
activation: nn.Module | None = None,
32-
eps: float = 1e-5,
33-
use_one_plus_gamma: bool = True,
34-
):
35-
super().__init__()
36-
self.num_channels = num_channels
37-
self.eps = eps
38-
self.use_one_plus_gamma = use_one_plus_gamma
39-
40-
if activation is None:
41-
activation = nn.SiLU()
42-
43-
# Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
44-
if hidden_features is None:
45-
hidden = []
46-
elif isinstance(hidden_features, int):
47-
hidden = [hidden_features]
48-
else:
49-
hidden = list(hidden_features)
50-
51-
layers: list[nn.Module] = []
52-
in_f = aux_in_features
53-
for h in hidden:
54-
layers += [nn.Linear(in_f, h), activation]
55-
in_f = h
56-
layers += [nn.Linear(in_f, 2 * num_channels)]
57-
self.mlp = nn.Sequential(*layers)
58-
59-
# Initialize last layer near-zero so AdaIN starts close to plain IN
60-
if isinstance(self.mlp[-1], nn.Linear):
61-
nn.init.zeros_(self.mlp[-1].weight)
62-
nn.init.zeros_(self.mlp[-1].bias)
63-
64-
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
65-
# Instance-style normalization over spatial dims (H,W), per (B,C)
66-
mean = x.mean(dim=(2, 3), keepdim=True)
67-
var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
68-
x_norm = (x - mean) / torch.sqrt(var + self.eps)
69-
70-
# Produce gamma/beta from y
71-
params = self.mlp(y) # (B, 2C)
72-
gamma, beta = params.chunk(2, 1) # each (B, C)
73-
74-
gamma = gamma.view(-1, self.num_channels, 1, 1)
75-
beta = beta.view(-1, self.num_channels, 1, 1)
76-
77-
if self.use_one_plus_gamma:
78-
return x_norm * (1.0 + gamma) + beta
79-
return x_norm * gamma + beta
80-
81-
82-
class AdaIN3d(nn.Module):
83-
"""
84-
Adaptive Instance Normalization for 3D tensors:
85-
x: (B, C, Z, H, W)
86-
y: (B, F) auxiliary vector
87-
Produces per-sample, per-channel affine params from y.
88-
"""
89-
90-
def __init__(
91-
self,
92-
num_channels: int,
93-
aux_in_features: int,
94-
hidden_features: int | tuple[int, ...] | None = None,
95-
activation: nn.Module | None = None,
96-
eps: float = 1e-5,
97-
use_one_plus_gamma: bool = True,
98-
):
99-
super().__init__()
100-
self.num_channels = num_channels
101-
self.eps = eps
102-
self.use_one_plus_gamma = use_one_plus_gamma
103-
104-
if activation is None:
105-
activation = nn.SiLU()
106-
107-
# Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
108-
if hidden_features is None:
109-
hidden = []
110-
elif isinstance(hidden_features, int):
111-
hidden = [hidden_features]
112-
else:
113-
hidden = list(hidden_features)
114-
115-
layers: list[nn.Module] = []
116-
in_f = aux_in_features
117-
for h in hidden:
118-
layers += [nn.Linear(in_f, h), activation]
119-
in_f = h
120-
layers += [nn.Linear(in_f, 2 * num_channels)]
121-
self.mlp = nn.Sequential(*layers)
122-
123-
# Optional: initialize last layer to near-zero so AdaIN starts close to plain IN
124-
if isinstance(self.mlp[-1], nn.Linear):
125-
nn.init.zeros_(self.mlp[-1].weight)
126-
nn.init.zeros_(self.mlp[-1].bias)
127-
128-
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
129-
130-
# Instance-style normalization over spatial dims (Z,H,W), per (B,C)
131-
mean = x.mean(dim=(2, 3, 4), keepdim=True)
132-
var = x.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
133-
x_norm = (x - mean) / torch.sqrt(var + self.eps)
134-
135-
# Produce gamma/beta from y
136-
params = self.mlp(y) # (B, 2C)
137-
gamma, beta = params.chunk(2, dim=-1) # each (B, C)
138-
139-
gamma = gamma.view(-1, self.num_channels, 1, 1, 1)
140-
beta = beta.view(-1, self.num_channels, 1, 1, 1)
141-
142-
if self.use_one_plus_gamma:
143-
return x_norm * (1.0 + gamma) + beta
144-
return x_norm * gamma + beta
3+
__all__ = ["AdaIN2d", "AdaIN3d", "NormType"]

0 commit comments

Comments
 (0)