@@ -1836,8 +1836,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
18361836 // forward input logical tensor. If the input layout is already fixed, reuse
18371837 // it; otherwise fall back to the canonical acbd format used by sdpa.
18381838 auto get_md_for_diff = [](const logical_tensor_t <) {
1839- if (!ltw (lt).is_any ())
1840- return make_dnnl_memory_desc (lt);
1839+ if (!ltw (lt).is_any ()) return make_dnnl_memory_desc (lt);
18411840 return dnnl::memory::desc {ltw (lt).vdims (),
18421841 static_cast <dnnl::memory::data_type>(ltw (lt).data_type ()),
18431842 dnnl::memory::format_tag::acbd};
@@ -1872,12 +1871,10 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
18721871 const bool with_scale = op->get_attr <bool >(op_attr::with_scale);
18731872 const auto mask_type = static_cast <attn_mask_type_t >(
18741873 op->get_attr <int64_t >(op_attr::mask_type));
1875- const bool is_invert_scale
1876- = op->has_attr (op_attr::is_invert_scale)
1874+ const bool is_invert_scale = op->has_attr (op_attr::is_invert_scale)
18771875 ? op->get_attr <bool >(op_attr::is_invert_scale)
18781876 : false ;
1879- const bool with_explicit_mask
1880- = mask_type == attn_mask_type::buffer;
1877+ const bool with_explicit_mask = mask_type == attn_mask_type::buffer;
18811878
18821879 auto md_q = make_dnnl_memory_desc (op->get_input_logical_tensor (0 ));
18831880 auto md_k = make_dnnl_memory_desc (op->get_input_logical_tensor (1 ));
@@ -1898,8 +1895,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
18981895 md_attn_mask = make_dnnl_memory_desc (
18991896 op->get_input_logical_tensor (idx++));
19001897 if (op->num_outputs () > 4 )
1901- md_dS = make_dnnl_memory_desc (
1902- op->get_output_logical_tensor (4 ));
1898+ md_dS = make_dnnl_memory_desc (op->get_output_logical_tensor (4 ));
19031899 }
19041900
19051901 const auto &sdpa_fusion_info = op->has_attr (op_attr::fusion_info)
@@ -1917,18 +1913,17 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
19171913 vs_attr.set_accumulation_mode (str2accumulation_mode (
19181914 op->get_attr <std::string>(op_attr::vs_acc_mode)));
19191915 attr.set_scratchpad_mode (dnnl::scratchpad_mode::user);
1920- attr.set_fpmath_mode (
1921- static_cast <dnnl::fpmath_mode>(fpmath.mode_ ));
1916+ attr.set_fpmath_mode (static_cast <dnnl::fpmath_mode>(fpmath.mode_ ));
19221917
19231918 dim_t kv_head_number = op->get_input_logical_tensor (1 ).dims [1 ];
1924- const alg_kind_t softmax_alg
1925- = alg_kind::softmax_accurate_inf_as_zero;
1919+ const alg_kind_t softmax_alg = alg_kind::softmax_accurate_inf_as_zero;
19261920
19271921 std::shared_ptr<primitive_desc_t > hint_fwd_pd;
19281922 status = create_sdpa_pd (hint_fwd_pd, p_engine.get (), md_q.get (),
19291923 md_k.get (), md_v.get (), md_dst.get (), md_attn_mask.get (),
19301924 md_scale.get (), is_invert_scale, kv_head_number, mask_type,
1931- softmax_alg, impl::prop_kind::forward_training, attr.get (), qk_attr.get (), vs_attr.get ());
1925+ softmax_alg, impl::prop_kind::forward_training, attr.get (),
1926+ qk_attr.get (), vs_attr.get ());
19321927 VCHECK_LAYOUT_PROPAGATOR (status == status::success, status,
19331928 " failed to create hint fwd pd for sdpa_bwd scratchpad" );
19341929
@@ -1938,8 +1933,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
19381933 md_diff_k.get (), md_diff_v.get (), md_diff_dst.get (),
19391934 md_dS.get (), md_attn_mask.get (), md_scale.get (),
19401935 is_invert_scale, kv_head_number, mask_type, softmax_alg,
1941- attr.get (), hint_fwd_pd.get (), qk_attr.get (),
1942- vs_attr.get ());
1936+ attr.get (), hint_fwd_pd.get (), qk_attr.get (), vs_attr.get ());
19431937 VCHECK_LAYOUT_PROPAGATOR (status == status::success, status,
19441938 " failed to create pd for sdpa_bwd scratchpad" );
19451939
0 commit comments