[WS] Support double buffering of scales in TMEM for tmem_store #8795
Conversation
I wonder if we need/want this case in practice. Having scales go through tmem_store will still break the pipeline, since we need to load from smem to registers. Do we have cases where we cannot swizzle?
Based on discussion with @ptillet, we think we don't want this optimization. The result would still be very suboptimal, and the user should always be able to swizzle data in HBM.
> The unswizzled path is expected to be more performant for small inputs for which […]

This might be the case for SWP, but I don't think it presents any difficulty for WS.
+1, if local_load + tmem_store is done in a dedicated partition, then I don't think it would be an issue with WS; that partition can run in parallel with the other partitions.
This is true for the pipeliner as well.
Are you saying that because we need to pad scales to 32 bits along K, it is more efficient to not do tmem_copy? tmem_store also works at 32-bit granularity, right?
No. In the low-latency case, where inputs are small and we use a small block size like 8 (with A/B swap), tmem_store working on an unpadded, small tensor might copy just one column into TMEM, while tmem_copy ends up over-copying four columns.
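As a rough, hedged illustration of that granularity argument (the four-column figure is taken from the comment above; treating it as a fixed tmem_copy chunk width is an assumption of this sketch, not a statement about the hardware):

```python
# Back-of-the-envelope comparison only; not Triton code.
# Assumption for this sketch: tmem_copy moves scale data in chunks of
# 4 TMEM columns (the "four columns" mentioned above), while tmem_store
# can write exactly the columns the tensor needs.
TMEM_COPY_CHUNK_COLS = 4

def columns_touched(needed_cols: int, chunk_cols: int) -> int:
    """Round the needed column count up to a whole number of chunks."""
    return ((needed_cols + chunk_cols - 1) // chunk_cols) * chunk_cols

needed = 1  # scales for a small block size, e.g. 8 with A/B swap
print("tmem_store columns written:", needed)                                      # 1
print("tmem_copy columns written:", columns_touched(needed, TMEM_COPY_CHUNK_COLS))  # 4
```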
Not sure what you mean by this? Are you saying that WS actually has an issue with local_load? Scales are copied into smem by the load partition, so the local_load in the "tmem copy partition" can work with that.
There are two ways to copy scales into TMEM: tmem_copy and tmem_store. For the latter, making the MMA asynchronous while scales are copied into TMEM requires that the scales be double buffered in TMEM. So far neither SWP nor WS implements such double buffering, so using an MMA op with scales copied via tmem_store forces the MMA to be synchronous.
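To make the double-buffering requirement concrete, here is a minimal Python sketch of the ping-pong pattern, under the assumption of two TMEM scale slots; store_scales and mma are hypothetical stand-ins, not Triton APIs:

```python
# Minimal sketch of why an async MMA needs two TMEM scale buffers (plain Python).
NUM_BUFFERS = 2
tmem_scales = [None] * NUM_BUFFERS  # two TMEM slots for scales

def store_scales(slot: int, scales):
    # Stand-in for the tmem_store that fills one TMEM scale buffer.
    tmem_scales[slot] = scales

def mma(slot: int):
    # Stand-in for an MMA that reads its scale operand from TMEM slot `slot`.
    return tmem_scales[slot]

for i, scales in enumerate(["s0", "s1", "s2", "s3"]):
    slot = i % NUM_BUFFERS
    # With a single buffer, this store would have to wait for the MMA of
    # iteration i - 1 to finish reading it, i.e. the MMA becomes synchronous.
    # With two buffers, the store for iteration i fills `slot` while the
    # previous MMA can still be reading the other slot.
    store_scales(slot, scales)
    mma(slot)
```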
Motivated by applications for which tmem_copy might be difficult to apply, such as MoE with activation quantization, this PR enables double buffering of TMEM scales in WS when the scales are copied by tmem_store. We introduce a new predefined partition in
partition-schedule, which is responsible for storing scales into double-buffered TMEM. MMA can now be made asynchronous when its scale operand lives in double-buffered TMEM, in addition to the tmem_copy case.
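A schematic view of the resulting warp-specialized structure, with the new scale-store partition between the load and MMA partitions; the partition functions, queues, and threads below are illustrative stand-ins for the real synchronization, not how the pass is implemented:

```python
# Illustrative sketch of the three-partition structure (plain Python).
import threading
from queue import Queue

NUM_TILES = 4
smem_scales: Queue = Queue(maxsize=2)  # load partition -> scale-store partition
tmem_slots: Queue = Queue(maxsize=2)   # scale-store partition -> MMA partition

def load_partition():
    for i in range(NUM_TILES):
        smem_scales.put(f"scales[{i}]")      # async copy of scales into smem

def scale_store_partition():
    for i in range(NUM_TILES):
        scales = smem_scales.get()           # local_load from smem
        tmem_slots.put((i % 2, scales))      # tmem_store into TMEM slot i % 2

def mma_partition():
    for _ in range(NUM_TILES):
        slot, scales = tmem_slots.get()      # wait until the slot is filled
        print(f"MMA reads {scales} from TMEM buffer {slot}")

threads = [threading.Thread(target=f)
           for f in (load_partition, scale_store_partition, mma_partition)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```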