@@ -89,32 +89,27 @@ std::vector<paddle::Tensor> GatherNextToken(
8989 return {out};
9090 }
9191
92- if (enc_batch <= 0 ) {
93- out = x.copy_to (x.place (), false );
92+ if (output_padding_offset) {
93+ int r = baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
94+ ctx,
95+ reinterpret_cast <const XPUType*>(x.data <data_t >()),
96+ reinterpret_cast <XPUType*>(out.data <data_t >()),
97+ encoder_seqs_lods_vp,
98+ decoder_seqs_lods_vp,
99+ encoder_batch_map_vp,
100+ decoder_batch_map_vp,
101+ dim);
102+ PD_CHECK (r == 0 , " xpu::plugin::gather_next_token failed." );
94103 } else {
95- if (output_padding_offset) {
96- int r =
97- baidu::xpu::api::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
98- ctx,
99- reinterpret_cast <const XPUType*>(x.data <data_t >()),
100- reinterpret_cast <XPUType*>(out.data <data_t >()),
101- encoder_seqs_lods_vp,
102- decoder_seqs_lods_vp,
103- encoder_batch_map_vp,
104- decoder_batch_map_vp,
105- dim);
106- PD_CHECK (r == 0 , " xpu::plugin::gather_next_token failed." );
107- } else {
108- int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
109- ctx,
110- reinterpret_cast <const XPUType*>(x.data <data_t >()),
111- reinterpret_cast <XPUType*>(out.data <data_t >()),
112- encoder_seqs_lods_vp,
113- encoder_batch_map_vp,
114- decoder_batch_map_vp,
115- dim);
116- PD_CHECK (r == 0 , " xpu::plugin::gather_next_token failed." );
117- }
104+ int r = baidu::xpu::api::plugin::eb_gather_next_token<XPUType, XPUType>(
105+ ctx,
106+ reinterpret_cast <const XPUType*>(x.data <data_t >()),
107+ reinterpret_cast <XPUType*>(out.data <data_t >()),
108+ encoder_seqs_lods_vp,
109+ encoder_batch_map_vp,
110+ decoder_batch_map_vp,
111+ dim);
112+ PD_CHECK (r == 0 , " xpu::plugin::gather_next_token failed." );
118113 }
119114 return {out};
120115}
0 commit comments