Skip to content

Commit 20240af

Browse files
committed
graph: backend: dnnl: fix code format
1 parent 6d6ade6 commit 20240af

File tree

8 files changed

+74
-103
lines changed

8 files changed

+74
-103
lines changed

src/graph/backend/dnnl/executables/sdpa.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ sdpa_bwd_executable_t::sdpa_bwd_executable_t(std::shared_ptr<op_t> &op,
404404
status_t s = create_sdpa_pd(hint_fwd_pd, p_engine.get(), md_q.get(),
405405
md_k.get(), md_v.get(), md_dst.get(), md_attn_mask.get(),
406406
md_scale.get(), is_invert_scale_, kv_head_number, mask_type_,
407-
softmax_alg, impl::prop_kind::forward_training, attr.get(), qk_attr.get(), vs_attr.get());
407+
softmax_alg, impl::prop_kind::forward_training, attr.get(),
408+
qk_attr.get(), vs_attr.get());
408409
if (s != dnnl::impl::status::success) {
409410
is_initialized_ = false;
410411
return;
@@ -423,7 +424,6 @@ sdpa_bwd_executable_t::sdpa_bwd_executable_t(std::shared_ptr<op_t> &op,
423424
s = sdpa_bwd_pd_->create_primitive(sdpa_bwd_prim_, p_engine.get());
424425
is_initialized_ = s == status::success;
425426
}
426-
427427
}
428428

429429
void sdpa_bwd_executable_t::execute(const stream &stream,
@@ -466,8 +466,7 @@ void sdpa_bwd_executable_t::execute(const stream &stream,
466466
// Set up scratchpad grantor required by the primitive's execute
467467
const memory_storage_t *mem_storage = nullptr;
468468
memory_t *scratchpad_memory = ctx.output(DNNL_ARG_SCRATCHPAD);
469-
if (scratchpad_memory)
470-
mem_storage = scratchpad_memory->memory_storage();
469+
if (scratchpad_memory) mem_storage = scratchpad_memory->memory_storage();
471470
const void *host_ptr
472471
= ctx.host_ptr(mem_storage, /* require_host_ptr = */ true);
473472
auto *scratchpad_grantor

src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ status_t sdp_bwd_primitive_kernel_t::initial_check(
5151
const std::vector<logical_tensor_t> &inputs,
5252
const std::vector<logical_tensor_t> &outputs) {
5353
const bool is_f32 = inputs[0].data_type == data_type::f32;
54-
VCHECK_SDP_BWD_PRIMITIVE(!is_f32,
55-
status::unimplemented,
54+
VCHECK_SDP_BWD_PRIMITIVE(!is_f32, status::unimplemented,
5655
"SDPA bwd primitive doesn't support f32 because of performance");
5756

5857
bool has_dropout = false;
@@ -63,8 +62,7 @@ status_t sdp_bwd_primitive_kernel_t::initial_check(
6362
break;
6463
}
6564
}
66-
VCHECK_SDP_BWD_PRIMITIVE(!has_dropout,
67-
status::unimplemented,
65+
VCHECK_SDP_BWD_PRIMITIVE(!has_dropout, status::unimplemented,
6866
"SDPA bwd primitive doesn't support Dropout for now");
6967

7068
bool has_host_scalar = false;
@@ -74,8 +72,7 @@ status_t sdp_bwd_primitive_kernel_t::initial_check(
7472
break;
7573
}
7674
}
77-
VCHECK_SDP_BWD_PRIMITIVE(!has_host_scalar,
78-
status::unimplemented,
75+
VCHECK_SDP_BWD_PRIMITIVE(!has_host_scalar, status::unimplemented,
7976
"SDPA bwd primitive doesn't support host scalar inputs for now");
8077

8178
return status::success;

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ status_t sdp_primitive_config_t::initial_check(
5656
&& opk != graph::op_kind::Quantize,
5757
status::unimplemented, "Not support quantized SDPA");
5858
// SDPA with Dropout is currently unsupported in the ukernel.
59-
VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dropout, status::unimplemented,
60-
"Not support SDPA with Dropout");
59+
VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dropout,
60+
status::unimplemented, "Not support SDPA with Dropout");
6161
if (opk == graph::op_kind::GenIndex) { has_genindex = true; }
6262
}
6363

src/graph/backend/dnnl/layout_propagator.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,8 +1836,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
18361836
// forward input logical tensor. If the input layout is already fixed, reuse
18371837
// it; otherwise fall back to the canonical acbd format used by sdpa.
18381838
auto get_md_for_diff = [](const logical_tensor_t &lt) {
1839-
if (!ltw(lt).is_any())
1840-
return make_dnnl_memory_desc(lt);
1839+
if (!ltw(lt).is_any()) return make_dnnl_memory_desc(lt);
18411840
return dnnl::memory::desc {ltw(lt).vdims(),
18421841
static_cast<dnnl::memory::data_type>(ltw(lt).data_type()),
18431842
dnnl::memory::format_tag::acbd};
@@ -1872,12 +1871,10 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
18721871
const bool with_scale = op->get_attr<bool>(op_attr::with_scale);
18731872
const auto mask_type = static_cast<attn_mask_type_t>(
18741873
op->get_attr<int64_t>(op_attr::mask_type));
1875-
const bool is_invert_scale
1876-
= op->has_attr(op_attr::is_invert_scale)
1874+
const bool is_invert_scale = op->has_attr(op_attr::is_invert_scale)
18771875
? op->get_attr<bool>(op_attr::is_invert_scale)
18781876
: false;
1879-
const bool with_explicit_mask
1880-
= mask_type == attn_mask_type::buffer;
1877+
const bool with_explicit_mask = mask_type == attn_mask_type::buffer;
18811878

18821879
auto md_q = make_dnnl_memory_desc(op->get_input_logical_tensor(0));
18831880
auto md_k = make_dnnl_memory_desc(op->get_input_logical_tensor(1));
@@ -1898,8 +1895,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
18981895
md_attn_mask = make_dnnl_memory_desc(
18991896
op->get_input_logical_tensor(idx++));
19001897
if (op->num_outputs() > 4)
1901-
md_dS = make_dnnl_memory_desc(
1902-
op->get_output_logical_tensor(4));
1898+
md_dS = make_dnnl_memory_desc(op->get_output_logical_tensor(4));
19031899
}
19041900

19051901
const auto &sdpa_fusion_info = op->has_attr(op_attr::fusion_info)
@@ -1917,18 +1913,17 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
19171913
vs_attr.set_accumulation_mode(str2accumulation_mode(
19181914
op->get_attr<std::string>(op_attr::vs_acc_mode)));
19191915
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
1920-
attr.set_fpmath_mode(
1921-
static_cast<dnnl::fpmath_mode>(fpmath.mode_));
1916+
attr.set_fpmath_mode(static_cast<dnnl::fpmath_mode>(fpmath.mode_));
19221917

19231918
dim_t kv_head_number = op->get_input_logical_tensor(1).dims[1];
1924-
const alg_kind_t softmax_alg
1925-
= alg_kind::softmax_accurate_inf_as_zero;
1919+
const alg_kind_t softmax_alg = alg_kind::softmax_accurate_inf_as_zero;
19261920

19271921
std::shared_ptr<primitive_desc_t> hint_fwd_pd;
19281922
status = create_sdpa_pd(hint_fwd_pd, p_engine.get(), md_q.get(),
19291923
md_k.get(), md_v.get(), md_dst.get(), md_attn_mask.get(),
19301924
md_scale.get(), is_invert_scale, kv_head_number, mask_type,
1931-
softmax_alg, impl::prop_kind::forward_training, attr.get(), qk_attr.get(), vs_attr.get());
1925+
softmax_alg, impl::prop_kind::forward_training, attr.get(),
1926+
qk_attr.get(), vs_attr.get());
19321927
VCHECK_LAYOUT_PROPAGATOR(status == status::success, status,
19331928
"failed to create hint fwd pd for sdpa_bwd scratchpad");
19341929

@@ -1938,8 +1933,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr<op_t> &op,
19381933
md_diff_k.get(), md_diff_v.get(), md_diff_dst.get(),
19391934
md_dS.get(), md_attn_mask.get(), md_scale.get(),
19401935
is_invert_scale, kv_head_number, mask_type, softmax_alg,
1941-
attr.get(), hint_fwd_pd.get(), qk_attr.get(),
1942-
vs_attr.get());
1936+
attr.get(), hint_fwd_pd.get(), qk_attr.get(), vs_attr.get());
19431937
VCHECK_LAYOUT_PROPAGATOR(status == status::success, status,
19441938
"failed to create pd for sdpa_bwd scratchpad");
19451939

src/graph/backend/dnnl/passes/compile_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ status_t compile_ops(std::shared_ptr<subgraph_t> &sg) {
6969
"failed to create executable for op %s",
7070
op->get_name().c_str());
7171
}
72-
72+
7373
sg->execs_.emplace_back(exec);
7474

7575
sg->is_constant_.push_back(op->has_attr(op_attr::is_constant)

src/graph/backend/dnnl/passes/insert_ops.cpp

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,8 @@ status_t insert_reshape_for_sdpa(std::shared_ptr<subgraph_t> &sg) {
666666
// Insert reshape for optional stats output (output 2)
667667
if (cur_op->get_attr<bool>(op_attr::is_training)) {
668668
auto stats_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims();
669-
dims expected_stats_dims = stats_dims;
670-
op_ptr reshape_stats
671-
= std::make_shared<op_t>(op_kind::_reshape);
669+
const dims &expected_stats_dims = stats_dims;
670+
op_ptr reshape_stats = std::make_shared<op_t>(op_kind::_reshape);
672671
reshape_stats->set_attr<bool>(op_attr::special_zero, false);
673672
reshape_stats->set_attr<std::vector<int64_t>>(
674673
op_attr::shape, expected_stats_dims);
@@ -728,8 +727,7 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr<subgraph_t> &sg) {
728727
size_t index = 6;
729728
// Insert reshape for scale (optional)
730729
if (cur_op->get_attr<bool>(op_attr::with_scale)) {
731-
int32_t scale_ndims
732-
= cur_op->get_input_logical_tensor(index).ndims;
730+
int32_t scale_ndims = cur_op->get_input_logical_tensor(index).ndims;
733731
if (scale_ndims == 5) {
734732
auto scale_dims
735733
= ltw(cur_op->get_input_logical_tensor(index)).vdims();
@@ -741,9 +739,8 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr<subgraph_t> &sg) {
741739
// Insert reshape for mask (optional)
742740
if (cur_op->get_attr<int64_t>(op_attr::mask_type)
743741
== static_cast<int64_t>(attn_mask_type::buffer)) {
744-
int32_t mask_ndims
745-
= cur_op->get_input_logical_tensor(index).ndims;
746-
742+
int32_t mask_ndims = cur_op->get_input_logical_tensor(index).ndims;
743+
747744
if (mask_ndims == 5) {
748745
auto mask_dims
749746
= ltw(cur_op->get_input_logical_tensor(index)).vdims();
@@ -753,10 +750,10 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr<subgraph_t> &sg) {
753750
}
754751

755752
// Insert reshape for diff_query output (output 0) -> 4D to 5D
756-
auto diff_query_dims = ltw(cur_op->get_output_logical_tensor(0)).vdims();
753+
auto diff_query_dims
754+
= ltw(cur_op->get_output_logical_tensor(0)).vdims();
757755
const dims &expected_diff_query_dims = diff_query_dims;
758-
op_ptr reshape_diff_query
759-
= std::make_shared<op_t>(op_kind::_reshape);
756+
op_ptr reshape_diff_query = std::make_shared<op_t>(op_kind::_reshape);
760757
reshape_diff_query->set_attr<bool>(op_attr::special_zero, false);
761758
reshape_diff_query->set_attr<std::vector<int64_t>>(
762759
op_attr::shape, expected_diff_query_dims);
@@ -772,26 +769,26 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr<subgraph_t> &sg) {
772769
rewriter.insert_op_after(reshape_diff_key, cur_op, 1);
773770

774771
// Insert reshape for diff_value output (output 2) -> 4D to 5D
775-
auto diff_value_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims();
772+
auto diff_value_dims
773+
= ltw(cur_op->get_output_logical_tensor(2)).vdims();
776774
const dims &expected_diff_value_dims = diff_value_dims;
777-
op_ptr reshape_diff_value
778-
= std::make_shared<op_t>(op_kind::_reshape);
775+
op_ptr reshape_diff_value = std::make_shared<op_t>(op_kind::_reshape);
779776
reshape_diff_value->set_attr<bool>(op_attr::special_zero, false);
780777
reshape_diff_value->set_attr<std::vector<int64_t>>(
781778
op_attr::shape, expected_diff_value_dims);
782779
rewriter.insert_op_after(reshape_diff_value, cur_op, 2);
783780

784781
// Insert reshape for diff_mask output (output 4) -> 4D to 5D
785782
if (cur_op->num_outputs() > 4) {
786-
auto diff_mask_dims
787-
= ltw(cur_op->get_output_logical_tensor(4)).vdims();
788-
const dims &expected_diff_mask_dims = diff_mask_dims;
789-
op_ptr reshape_diff_mask
790-
= std::make_shared<op_t>(op_kind::_reshape);
791-
reshape_diff_mask->set_attr<bool>(op_attr::special_zero, false);
792-
reshape_diff_mask->set_attr<std::vector<int64_t>>(
793-
op_attr::shape, expected_diff_mask_dims);
794-
rewriter.insert_op_after(reshape_diff_mask, cur_op, 4);
783+
auto diff_mask_dims
784+
= ltw(cur_op->get_output_logical_tensor(4)).vdims();
785+
const dims &expected_diff_mask_dims = diff_mask_dims;
786+
op_ptr reshape_diff_mask
787+
= std::make_shared<op_t>(op_kind::_reshape);
788+
reshape_diff_mask->set_attr<bool>(op_attr::special_zero, false);
789+
reshape_diff_mask->set_attr<std::vector<int64_t>>(
790+
op_attr::shape, expected_diff_mask_dims);
791+
rewriter.insert_op_after(reshape_diff_mask, cur_op, 4);
795792
}
796793
}
797794

0 commit comments

Comments
 (0)