Add XAttention reference operation #31864
The reference tests should be stored in:
src/plugins/template/tests/functional/op_reference/
like the other reference tests.
I am aware of these tests. All of them rely on unnecessary abstractions of OV opset ops and involve the sample plugin just to run unit tests. Since this operation will never have a counterpart in the OV opset, there is no reason to proliferate the bad design decision of testing all reference ops through the template plugin just for the sake of consistency.
OPENVINO_ASSERT(input_shape.size() == 3);  // [num_heads, num_tokens, head_size]
OPENVINO_ASSERT(out_shape.size() == 3);
OPENVINO_ASSERT(input_shape[0] == out_shape[0]);
OPENVINO_ASSERT(input_shape[1] % m_stride == 0);
OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]);
OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]);
Remove these asserts. Shape and input validation is done during shape inference in the operator implementation.
Not relevant, since there is no shape inference involved by design.
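For readers following this thread, the shape contract that the assertions above encode can be sketched in standalone C++. This is an illustration only: `strided_reshape_shape` is a hypothetical helper that does not appear in the PR, and it throws standard exceptions instead of using `OPENVINO_ASSERT` so that it compiles without OpenVINO.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of the PR): given an input shape
// [num_heads, num_tokens, head_size] and a stride, return the output shape
// that the assertions in the diff require:
// [num_heads, num_tokens / stride, head_size * stride].
std::vector<std::size_t> strided_reshape_shape(const std::vector<std::size_t>& input_shape,
                                               std::size_t stride) {
    if (input_shape.size() != 3) {
        throw std::invalid_argument("expected a rank-3 input shape");
    }
    if (stride == 0 || input_shape[1] % stride != 0) {
        throw std::invalid_argument("num_tokens must be divisible by a non-zero stride");
    }
    return {input_shape[0], input_shape[1] / stride, input_shape[2] * stride};
}
```

With a stride of 8, for example, an input of shape [4, 16, 64] maps to [4, 2, 512]; a token count that is not a multiple of the stride is rejected, which mirrors the `input_shape[1] % m_stride == 0` check.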
 * @param out_shape Shape of the output tensor data. Expected shape is strictly equal to
 * `reshaped_qk_product_shape`.
 */
void softmax(const T* reshaped_qk_product_data,
Why is this required? It is just a call to ov::reference::softmax(reshaped_qk_product_data, out, reshaped_qk_product_shape, {2});
This is a reference operation. It is designed to be understandable and testable first and foremost. I provide a set of functions that each directly map to a phase or sub-phase of the intended HW-accelerated kernel. Softmax is one of these phases and therefore has a separate function. It also has an extra shape check, so having a function wrapper has some utility.
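To make the point about the wrapper concrete, here is a minimal standalone sketch of a softmax phase over the last axis of a rank-3 tensor, with the extra shape-equality check performed up front. It is an assumption-laden illustration, not the PR's code: `softmax_last_axis` is a hypothetical name, the loop is written out by hand rather than delegating to `ov::reference::softmax`, and errors are reported with standard exceptions.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical sketch (not the PR's implementation): softmax along axis 2 of a
// rank-3 tensor stored contiguously in row-major order.
template <typename T>
void softmax_last_axis(const T* in,
                       const std::vector<std::size_t>& in_shape,
                       T* out,
                       const std::vector<std::size_t>& out_shape) {
    // The extra check the wrapper adds: output shape must equal input shape.
    if (in_shape.size() != 3 || in_shape != out_shape) {
        throw std::invalid_argument("input and output must share the same rank-3 shape");
    }
    const std::size_t rows = in_shape[0] * in_shape[1];
    const std::size_t cols = in_shape[2];
    if (cols == 0) {
        return;  // nothing to normalize
    }
    for (std::size_t r = 0; r < rows; ++r) {
        const T* src = in + r * cols;
        T* dst = out + r * cols;
        // Subtract the row maximum before exponentiating for numerical stability.
        const T max_val = *std::max_element(src, src + cols);
        T sum = T(0);
        for (std::size_t c = 0; c < cols; ++c) {
            dst[c] = std::exp(src[c] - max_val);
            sum += dst[c];
        }
        for (std::size_t c = 0; c < cols; ++c) {
            dst[c] /= sum;
        }
    }
}
```

Keeping this as a separate function gives the phase its own name and its own unit test, which is the trade-off being argued for in the reply above.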
}
}

/** Selects the elements of the input tensor along the last two dimensions, indepently along the first dimension, so
- /** Selects the elements of the input tensor along the last two dimensions, indepently along the first dimension, so
+ /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension, so
Done