Skip to content

Commit 3149aed

Browse files
authored
fix_gather_next_token (#5311)
1 parent 0925d44 commit 3149aed

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

custom_ops/xpu_ops/src/ops/gather_next_token.cc

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -89,32 +89,27 @@ std::vector<paddle::Tensor> GatherNextToken(
8989
return {out};
9090
}
9191

92-
if (enc_batch <= 0) {
93-
out = x.copy_to(x.place(), false);
92+
if (output_padding_offset) {
93+
int r = baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
94+
ctx,
95+
reinterpret_cast<const XPUType*>(x.data<data_t>()),
96+
reinterpret_cast<XPUType*>(out.data<data_t>()),
97+
encoder_seqs_lods_vp,
98+
decoder_seqs_lods_vp,
99+
encoder_batch_map_vp,
100+
decoder_batch_map_vp,
101+
dim);
102+
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
94103
} else {
95-
if (output_padding_offset) {
96-
int r =
97-
baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
98-
ctx,
99-
reinterpret_cast<const XPUType*>(x.data<data_t>()),
100-
reinterpret_cast<XPUType*>(out.data<data_t>()),
101-
encoder_seqs_lods_vp,
102-
decoder_seqs_lods_vp,
103-
encoder_batch_map_vp,
104-
decoder_batch_map_vp,
105-
dim);
106-
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
107-
} else {
108-
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
109-
ctx,
110-
reinterpret_cast<const XPUType*>(x.data<data_t>()),
111-
reinterpret_cast<XPUType*>(out.data<data_t>()),
112-
encoder_seqs_lods_vp,
113-
encoder_batch_map_vp,
114-
decoder_batch_map_vp,
115-
dim);
116-
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
117-
}
104+
int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
105+
ctx,
106+
reinterpret_cast<const XPUType*>(x.data<data_t>()),
107+
reinterpret_cast<XPUType*>(out.data<data_t>()),
108+
encoder_seqs_lods_vp,
109+
encoder_batch_map_vp,
110+
decoder_batch_map_vp,
111+
dim);
112+
PD_CHECK(r == 0, "xpu::plugin::gather_next_token failed.");
118113
}
119114
return {out};
120115
}

0 commit comments

Comments
 (0)