1- from enum import Enum
1+ from direct . nn . adain . adain import AdaIN2d , AdaIN3d , NormType
22
3- import torch
4- from torch import nn
5-
6- __all__ = ["AdaIN2d" , "AdaIN3d" ]
7-
8-
class NormType(str, Enum):
    """String-valued enum naming the normalization variants this module offers.

    Subclassing ``str`` lets members compare equal to their plain string
    values (e.g. ``NormType.ADAIN == "adain"``), which is convenient for
    config files and CLI arguments.
    """

    INSTANCE = "instance"  # plain (unconditional) instance normalization
    ADAIN = "adain"  # adaptive instance norm conditioned on an auxiliary vector
12-
13-
14- import torch
15- from torch import nn
16-
17-
18- class AdaIN2d (nn .Module ):
19- """
20- Adaptive Instance Normalization for 2D tensors:
21- x: (B, C, H, W)
22- y: (B, F) auxiliary vector
23- Produces per-sample, per-channel affine params from y.
24- """
25-
26- def __init__ (
27- self ,
28- num_channels : int ,
29- aux_in_features : int ,
30- hidden_features : int | tuple [int , ...] | None = None ,
31- activation : nn .Module | None = None ,
32- eps : float = 1e-5 ,
33- use_one_plus_gamma : bool = True ,
34- ):
35- super ().__init__ ()
36- self .num_channels = num_channels
37- self .eps = eps
38- self .use_one_plus_gamma = use_one_plus_gamma
39-
40- if activation is None :
41- activation = nn .SiLU ()
42-
43- # Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
44- if hidden_features is None :
45- hidden = []
46- elif isinstance (hidden_features , int ):
47- hidden = [hidden_features ]
48- else :
49- hidden = list (hidden_features )
50-
51- layers : list [nn .Module ] = []
52- in_f = aux_in_features
53- for h in hidden :
54- layers += [nn .Linear (in_f , h ), activation ]
55- in_f = h
56- layers += [nn .Linear (in_f , 2 * num_channels )]
57- self .mlp = nn .Sequential (* layers )
58-
59- # Initialize last layer near-zero so AdaIN starts close to plain IN
60- if isinstance (self .mlp [- 1 ], nn .Linear ):
61- nn .init .zeros_ (self .mlp [- 1 ].weight )
62- nn .init .zeros_ (self .mlp [- 1 ].bias )
63-
64- def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
65- # Instance-style normalization over spatial dims (H,W), per (B,C)
66- mean = x .mean (dim = (2 , 3 ), keepdim = True )
67- var = x .var (dim = (2 , 3 ), keepdim = True , unbiased = False )
68- x_norm = (x - mean ) / torch .sqrt (var + self .eps )
69-
70- # Produce gamma/beta from y
71- params = self .mlp (y ) # (B, 2C)
72- gamma , beta = params .chunk (2 , 1 ) # each (B, C)
73-
74- gamma = gamma .view (- 1 , self .num_channels , 1 , 1 )
75- beta = beta .view (- 1 , self .num_channels , 1 , 1 )
76-
77- if self .use_one_plus_gamma :
78- return x_norm * (1.0 + gamma ) + beta
79- return x_norm * gamma + beta
80-
81-
82- class AdaIN3d (nn .Module ):
83- """
84- Adaptive Instance Normalization for 3D tensors:
85- x: (B, C, Z, H, W)
86- y: (B, F) auxiliary vector
87- Produces per-sample, per-channel affine params from y.
88- """
89-
90- def __init__ (
91- self ,
92- num_channels : int ,
93- aux_in_features : int ,
94- hidden_features : int | tuple [int , ...] | None = None ,
95- activation : nn .Module | None = None ,
96- eps : float = 1e-5 ,
97- use_one_plus_gamma : bool = True ,
98- ):
99- super ().__init__ ()
100- self .num_channels = num_channels
101- self .eps = eps
102- self .use_one_plus_gamma = use_one_plus_gamma
103-
104- if activation is None :
105- activation = nn .SiLU ()
106-
107- # Build an MLP: aux_in_features -> ... -> 2*num_channels (gamma, beta)
108- if hidden_features is None :
109- hidden = []
110- elif isinstance (hidden_features , int ):
111- hidden = [hidden_features ]
112- else :
113- hidden = list (hidden_features )
114-
115- layers : list [nn .Module ] = []
116- in_f = aux_in_features
117- for h in hidden :
118- layers += [nn .Linear (in_f , h ), activation ]
119- in_f = h
120- layers += [nn .Linear (in_f , 2 * num_channels )]
121- self .mlp = nn .Sequential (* layers )
122-
123- # Optional: initialize last layer to near-zero so AdaIN starts close to plain IN
124- if isinstance (self .mlp [- 1 ], nn .Linear ):
125- nn .init .zeros_ (self .mlp [- 1 ].weight )
126- nn .init .zeros_ (self .mlp [- 1 ].bias )
127-
128- def forward (self , x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
129-
130- # Instance-style normalization over spatial dims (Z,H,W), per (B,C)
131- mean = x .mean (dim = (2 , 3 , 4 ), keepdim = True )
132- var = x .var (dim = (2 , 3 , 4 ), keepdim = True , unbiased = False )
133- x_norm = (x - mean ) / torch .sqrt (var + self .eps )
134-
135- # Produce gamma/beta from y
136- params = self .mlp (y ) # (B, 2C)
137- gamma , beta = params .chunk (2 , dim = - 1 ) # each (B, C)
138-
139- gamma = gamma .view (- 1 , self .num_channels , 1 , 1 , 1 )
140- beta = beta .view (- 1 , self .num_channels , 1 , 1 , 1 )
141-
142- if self .use_one_plus_gamma :
143- return x_norm * (1.0 + gamma ) + beta
144- return x_norm * gamma + beta
3+ __all__ = ["AdaIN2d" , "AdaIN3d" , "NormType" ]
0 commit comments