`UNet2DConditionModel`: add support for QK Normalization by propagating `qk_norm` value from config through to child attention modules #12051
For the tests I find myself writing code like this:
Is there a canonical way of obtaining the list of block types so it doesn't have to be hardcoded in the test?
…sers into feat/qk_norm_propagate
I wasn't able to run the full test suite; significant components seem to be broken on macOS/mps.
@damian0815 you mention what you changed in the code.
The Attention modules get q_norm and k_norm (known as "QK Normalization" in the literature). In the SD3 paper the authors state that they had problems with attention calculations growing larger and larger during training, which I also saw when using flow matching. QK Normalization brings that back under control.
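To make the effect described above concrete, here is a minimal pure-Python sketch of RMS-style QK normalization applied to a single query/key pair. It is not the diffusers implementation: the helper names `rms_norm` and `attention_logit` are hypothetical, the learnable gain is omitted, and reading "attention calculations" as the scaled dot-product logits is an assumption.

```python
import math

def rms_norm(x, eps=1e-6):
    """Scale a vector so its root-mean-square is ~1 (learnable gain omitted)."""
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

def attention_logit(q, k, qk_norm=False):
    """Scaled dot-product logit for one query/key pair (hypothetical helper)."""
    if qk_norm:
        # Normalize query and key independently before the dot product.
        q, k = rms_norm(q), rms_norm(k)
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))

q = [0.5, -1.0, 2.0, 1.5]
k = [1.0, -0.5, 1.5, 2.0]

# Simulate activations growing during training by scaling q and k.
for scale in (1, 10, 100):
    big_q = [scale * v for v in q]
    big_k = [scale * v for v in k]
    raw = attention_logit(big_q, big_k)
    normed = attention_logit(big_q, big_k, qk_norm=True)
    print(f"scale={scale}: raw={raw:.1f} normalized={normed:.4f}")
```

Without normalization the logit grows with the square of the activation scale; with normalization its magnitude stays bounded by the square root of the head dimension, regardless of how large the inputs become.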
@damian0815 Is this an SD3 only kind of thing?
What does this PR do?
Fixes #12050
QK Normalization was already implemented in `Attention.__init__`, but adding e.g. `"qk_norm": "rms_norm"` to the config.json for a `UNet2DConditionModel` had no effect.
This PR makes the config `qk_norm` value take effect by propagating it through the various `UNet2DConditionModel` block initialization logic.
Without this PR:
With this PR:
(I have successfully finetuned a model after making this change)
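As a rough illustration of the propagation pattern described above, the sketch below uses toy stand-in classes rather than the real diffusers code. `ToyAttention`, `ToyBlock`, `ToyUNet`, and the `propagate` flag are all hypothetical names invented for this example; only the `"qk_norm": "rms_norm"` config entry comes from the PR description.

```python
from typing import List, Optional

class ToyAttention:
    """Hypothetical stand-in for an attention module that accepts qk_norm."""
    def __init__(self, qk_norm: Optional[str] = None):
        self.qk_norm = qk_norm

class ToyBlock:
    """Hypothetical block that owns several attention modules."""
    def __init__(self, num_layers: int, qk_norm: Optional[str] = None):
        # The block must forward qk_norm, or its attention modules never see it.
        self.attentions = [ToyAttention(qk_norm=qk_norm) for _ in range(num_layers)]

class ToyUNet:
    """Hypothetical top-level model built from a config dict."""
    def __init__(self, block_layers=(1, 2), qk_norm: Optional[str] = None,
                 propagate: bool = True):
        self.config = {"qk_norm": qk_norm}
        # propagate=False mimics the old behaviour: the value is stored in the
        # config but never handed down to the child blocks.
        child_qk_norm = qk_norm if propagate else None
        self.blocks = [ToyBlock(n, qk_norm=child_qk_norm) for n in block_layers]

    def attention_qk_norms(self) -> List[Optional[str]]:
        return [a.qk_norm for b in self.blocks for a in b.attentions]

config = {"qk_norm": "rms_norm"}
before = ToyUNet(**config, propagate=False)
after = ToyUNet(**config, propagate=True)
print("config only:", before.attention_qk_norms())
print("propagated: ", after.attention_qk_norms())
```

In the toy model the fix amounts to threading one keyword argument through each constructor between the top-level model and the attention modules.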
Who can review?
Pings based on `git blame`: @yiyixuxu @gnobitab @sayakpaul