Skip to content

Commit 70d711f

Browse files
committed
graph: backend: dnnl: fix code format
1 parent 3308143 commit 70d711f

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ status_t sdp_primitive_config_t::initial_check(
5555
VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dequantize
5656
&& opk != graph::op_kind::Quantize,
5757
status::unimplemented, "Not support quantized SDPA");
58-
// SDPA with Dropout is currently unsupported in the ukernel.
59-
VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dropout, status::unimplemented,
60-
"Not support SDPA with Dropout");
6158
if (opk == graph::op_kind::GenIndex) { has_genindex = true; }
6259
}
6360

src/graph/backend/dnnl/passes/insert_ops.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,8 +669,9 @@ status_t insert_reshape_for_sdpa(std::shared_ptr<subgraph_t> &sg) {
669669
// Insert reshape for optional stats output (output 2)
670670
if (cur_op->get_attr<bool>(op_attr::is_training)) {
671671
auto stats_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims();
672-
dims expected_stats_dims = stats_dims;
673-
op_ptr reshape_stats = std::make_shared<op_t>(op_kind::_reshape);
672+
const dims &expected_stats_dims = stats_dims;
673+
op_ptr reshape_stats
674+
= std::make_shared<op_t>(op_kind::dnnl_reshape);
674675
reshape_stats->set_attr<bool>(op_attr::special_zero, false);
675676
reshape_stats->set_attr<std::vector<int64_t>>(
676677
op_attr::shape, expected_stats_dims);

src/graph/backend/dnnl/passes/transform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4773,7 +4773,7 @@ status_t fuse_sdpa_bwd(std::shared_ptr<subgraph_t> &sg) {
47734773
if (cur_op->get_kind() != op_kind::dnnl_matmul) continue;
47744774

47754775
// Step 1 – walk matmul_qk → [scale_pre] → [mask] → sub → exp
4776-
op_ptr matmul_qk = cur_op;
4776+
const op_ptr &matmul_qk = cur_op;
47774777
op_ptr scale_pre = nullptr, mask_op = nullptr;
47784778
op_ptr sub_op = nullptr, exp_op = nullptr;
47794779

0 commit comments

Comments
 (0)