Hi!
I have noticed that the U-Net in ncsnpp.py handles attention blocks differently for the contracting and expanding paths.
In the contracting path, an attention block follows every ResNet block at the resolutions listed in attn_resolutions, so the number of attention blocks matches the number of ResNet blocks for each channel multiplier:
for i_level in range(num_resolutions):
  # Residual blocks for this resolution
  for i_block in range(num_res_blocks):
    h = ResnetBlock(out_ch=nf * ch_mult[i_level])(hs[-1], temb, train)
    if h.shape[1] in attn_resolutions:
      h = AttnBlock()(h)
    hs.append(h)
In the expanding path, however, attention is applied only once per resolution level, after the inner loop over the ResNet blocks:
# Upsampling block
for i_level in reversed(range(num_resolutions)):
  for i_block in range(num_res_blocks + 1):
    h = ResnetBlock(out_ch=nf * ch_mult[i_level])(
        jnp.concatenate([h, hs.pop()], axis=-1), temb, train
    )
  if h.shape[1] in attn_resolutions:
    h = AttnBlock()(h)
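To make the asymmetry concrete, here is a quick count of attention blocks per path. The hyperparameters are illustrative assumptions on my part (image size 32, four resolution levels, two ResNet blocks, attention at 16x16), not the repo's actual config:

num_resolutions = 4        # assumed, e.g. ch_mult = (1, 2, 2, 2)
num_res_blocks = 2         # assumed
image_size = 32            # assumed
attn_resolutions = (16,)   # assumed

# Spatial size at each level, assuming one Downsample between levels.
resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]  # [32, 16, 8, 4]

# Contracting path: one AttnBlock after every ResnetBlock at matching resolutions.
down_attn = sum(num_res_blocks for r in resolutions if r in attn_resolutions)

# Expanding path: one AttnBlock per resolution level at matching resolutions.
up_attn = sum(1 for r in resolutions if r in attn_resolutions)

print(down_attn, up_attn)  # 2 1

So with these settings the contracting path applies attention twice at 16x16, while the expanding path applies it only once there, even though it runs num_res_blocks + 1 ResNet blocks per level.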
Why is this beneficial?