File tree Expand file tree Collapse file tree 3 files changed +4
-6
lines changed
Expand file tree Collapse file tree 3 files changed +4
-6
lines changed Original file line number Diff line number Diff line change @@ -55,9 +55,6 @@ status_t sdp_primitive_config_t::initial_check(
5555 VCHECK_SDP_PRIMITIVE (opk != graph::op_kind::Dequantize
5656 && opk != graph::op_kind::Quantize,
5757 status::unimplemented, " Not support quantized SDPA" );
58- // SDPA with Dropout is currently unsupported in the ukernel.
59- VCHECK_SDP_PRIMITIVE (opk != graph::op_kind::Dropout, status::unimplemented,
60- " Not support SDPA with Dropout" );
6158 if (opk == graph::op_kind::GenIndex) { has_genindex = true ; }
6259 }
6360
Original file line number Diff line number Diff line change @@ -669,8 +669,9 @@ status_t insert_reshape_for_sdpa(std::shared_ptr<subgraph_t> &sg) {
669669 // Insert reshape for optional stats output (output 2)
670670 if (cur_op->get_attr <bool >(op_attr::is_training)) {
671671 auto stats_dims = ltw (cur_op->get_output_logical_tensor (2 )).vdims ();
672- dims expected_stats_dims = stats_dims;
673- op_ptr reshape_stats = std::make_shared<op_t >(op_kind::_reshape);
672+ const dims &expected_stats_dims = stats_dims;
673+ op_ptr reshape_stats
674+ = std::make_shared<op_t >(op_kind::dnnl_reshape);
674675 reshape_stats->set_attr <bool >(op_attr::special_zero, false );
675676 reshape_stats->set_attr <std::vector<int64_t >>(
676677 op_attr::shape, expected_stats_dims);
Original file line number Diff line number Diff line change @@ -4773,7 +4773,7 @@ status_t fuse_sdpa_bwd(std::shared_ptr<subgraph_t> &sg) {
47734773 if (cur_op->get_kind () != op_kind::dnnl_matmul) continue ;
47744774
47754775 // Step 1 – walk matmul_qk → [scale_pre] → [mask] → sub → exp
4776- op_ptr matmul_qk = cur_op;
4776+ const op_ptr & matmul_qk = cur_op;
47774777 op_ptr scale_pre = nullptr , mask_op = nullptr ;
47784778 op_ptr sub_op = nullptr , exp_op = nullptr ;
47794779
You can’t perform that action at this time.
0 commit comments