@@ -3637,7 +3637,7 @@ struct test_flash_attn_ext : public test_case {
3637
3637
3638
3638
ggml_tensor * m = nullptr ;
3639
3639
if (mask) {
3640
- m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[1 ], 1 );
3640
+ m = ggml_new_tensor_4d (ctx, GGML_TYPE_F16, kv, GGML_PAD (nb, GGML_KQ_MASK_PAD), nr23[0 ], nr23[ 1 ] );
3641
3641
ggml_set_name (m, " m" );
3642
3642
}
3643
3643
@@ -4751,7 +4751,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
4751
4751
test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {1 , 1 }, scale, max_bias));
4752
4752
4753
4753
if (ne0 <= 32 && ne1 <= 32 ) {
4754
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, {3 , 1 }, scale, max_bias));
4754
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 3 }, mask, m_prec, {3 , 1 }, scale, max_bias));
4755
4755
test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, {2 , 3 }, scale, max_bias));
4756
4756
}
4757
4757
}
0 commit comments