Work around a jit post-op issue with layout not matching dst #4864
else if (col) {
    emul(1, tempQ0, j0, ld, strategy, state);
    eadd(1, offset, offset, tempQ0.reinterpret(0, offset.getType()), strategy, state);
}
These changes don't look right. The intent is that we pass state.inputs.binaryLDs only when we have 2D post-ops. For either row or column (1D) post-ops, we assume packed data, and here we scale by the data type size to get the stride.
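To make the distinction in the comment above concrete, here is a minimal host-side sketch of the addressing rule it describes: a 1D post-op operand is treated as packed and only scaled by the element size, while a 2D operand also needs a leading dimension. Everything here (`BinaryKind`, `binary_byte_offset`, the `AlongI`/`AlongJ` naming, and the column-major convention for the 2D case) is a hypothetical illustration, not the kernel generator's API, and it makes no claim about which of the two 1D cases the generator calls "row" or "column".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical classification of a binary post-op source operand.
enum class BinaryKind {
    AlongI, // 1D operand indexed by i only (assumed packed)
    AlongJ, // 1D operand indexed by j only (assumed packed)
    Full2D  // 2D operand, needs a leading dimension
};

// Byte offset of the operand element paired with output element (i, j).
// For the 1D kinds, ld is ignored: packed data means the element stride is 1,
// so the byte stride is just elem_size. Only Full2D consumes ld
// (column-major layout is an assumption of this sketch).
inline std::int64_t binary_byte_offset(BinaryKind kind, std::int64_t i,
        std::int64_t j, std::int64_t ld, std::size_t elem_size) {
    assert(elem_size > 0);
    std::int64_t elems = 0;
    switch (kind) {
        case BinaryKind::AlongI: elems = i; break;
        case BinaryKind::AlongJ: elems = j; break;
        case BinaryKind::Full2D: elems = i + j * ld; break;
    }
    return elems * static_cast<std::int64_t>(elem_size);
}
```

With 4-byte elements, `binary_byte_offset(BinaryKind::AlongJ, 3, 5, 100, 4)` gives 20 regardless of `ld`, whereas the `Full2D` case with the same arguments gives 2012.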
// For col-only binary, check whether the col direction has unit stride.
// Col direction corresponds to dim[ndims-2] (no swap) or dim[ndims-1] (swap_ab).
bool col_only = is_multi_col && !is_multi_row;
int rmd_ndims = src_rmd.ndims();
bool col_unit = !col_only
        || (swap_ab ? src_rmd.is_inner_dim(rmd_ndims - 1, rmd_ndims)
                    : src_rmd.is_inner_dim(rmd_ndims - 2, rmd_ndims));
// Non-unit stride requires Scattered access with correct multi-block offsets,
// which is not yet supported. Reject at PD time to avoid silent wrong results.
if (!col_unit) return status::unimplemented;
This fix may work, but only this block of code is needed. Everything else looks wrong. Please also move this to before the if (swap_ab) block above, and assume swap_ab = false, then inside of the if (swap_ab) block do any necessary "fixing up". Eventually this logic will be moved somewhere else, so we need to keep it separate from the base logic.
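One way to read the requested structure is sketched below: the check is written first as if `swap_ab` were false, and the swap case only adjusts which dimension is inspected. This is a self-contained toy, assuming a descriptor modeled as a plain stride vector; `ToyDesc`, `has_unit_stride`, and `col_only_binary_supported` are hypothetical names and do not reproduce the semantics of the real `is_inner_dim` helper.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a memory descriptor: one element stride per dim.
struct ToyDesc {
    std::vector<std::int64_t> strides;
    int ndims() const { return static_cast<int>(strides.size()); }
    bool has_unit_stride(int dim) const { return strides[dim] == 1; }
};

// Returns false where the primitive descriptor would report "unimplemented":
// a column-only binary operand whose column direction is not unit stride.
inline bool col_only_binary_supported(const ToyDesc &src, bool is_multi_row,
        bool is_multi_col, bool swap_ab) {
    assert(src.ndims() >= 2);
    const int nd = src.ndims();

    // Base logic, written as if swap_ab == false.
    int col_dim = nd - 2;

    // Fix-up for the swapped case, kept apart from the base logic.
    if (swap_ab) col_dim = nd - 1;

    const bool col_only = is_multi_col && !is_multi_row;
    return !col_only || src.has_unit_stride(col_dim);
}
```

For strides `{1, 8}`, a column-only operand is accepted without the swap (dimension 0 has unit stride) and rejected with it (dimension 1 has stride 8). A 2D operand is never rejected by this check.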
Could we force 2D in these cases? E.g.
    if (!is_multi_row && is_multi_col && !src_rmd.is_inner_dim(rmd_ndims - 2, rmd_ndims))
        is_multi_row = true;
Edit: never mind, I thought the problem shape had N=1.
34180a5 to df1c7e4
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
df1c7e4 to e4543be
@rjoursler Please comment on the logic related to
make test
Fixes MFDNN-14794. Backport to v3.11 (to allow ref path in Graph SDPA training work).