-
Notifications
You must be signed in to change notification settings - Fork 400
Migrate DotProductAttention to NNX #2198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Migrate DotProductAttention to NNX #2198
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @hsuan-lun-chiang. Could you check why the GPU integration test is failing?
dummy_query_prefill = jnp.zeros((1, self.max_target_length, self.num_query_heads, config.head_dim), dtype=self.dtype) | ||
dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype) | ||
dummy_value_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cgarciae are zeros the right value here?
bd06740
to
16ad76e
Compare
Sure! It was causes by None being casting to uint8, fixed it. Thank you. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you @hsuan-lun-chiang. Did you run the description test on a GPU VM? Let's get @cgarciae's thoughts as well
Happy to help! Yes, I ran the test on a GPU VM with A100 80GB. |
Description
Migrate DotProductAttention to NNX.
Tests
Train Gemma-2b with attention=cudnn_flash_te, which use DotProductAttention:
Logs - After Migration
Logs - Before Migration
Checklist
Before submitting this PR, please make sure (put X in square brackets):