Commit 97f924a

refactor: extend NPU kernel parameter set and adjust conditional compilation structure.

1 parent 67a2090

File tree

18 files changed: +75 -100 lines

xllm/core/common/global_flags.cpp

Lines changed: 1 addition & 1 deletion
@@ -390,4 +390,4 @@ DEFINE_string(reasoning_parser,
 // --- qwen3 reranker config ---
 DEFINE_bool(enable_qwen3_reranker, false, "Whether to enable qwen3 reranker.");
 
-DEFINE_bool(enable_native_npu, true, "Whether to enable native NPU support.");
+DEFINE_bool(enable_npu_torch, true, "Whether to enable native NPU support.");
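Deployments that passed the old flag on the command line must switch to the new name. A hedged example with standard gflags syntax (the binary name is illustrative, not from this commit):

./xllm --enable_npu_torch=false

gflags also accepts the --noenable_npu_torch shorthand for booleans; the default stays true, so only opt-outs need changing.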

xllm/core/common/global_flags.h

Lines changed: 1 addition & 1 deletion
@@ -203,4 +203,4 @@ DECLARE_string(reasoning_parser);
 
 DECLARE_bool(enable_shm);
 
-DECLARE_bool(enable_native_npu);
+DECLARE_bool(enable_npu_torch);

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ void WorkerServer::create_server(
 #elif defined(USE_NPU)
   // TODO: Refactor to use model_type or other appropriate enumeration for
   // condition checking
-  if (FLAGS_enable_native_npu) {
+  if (FLAGS_enable_npu_torch) {
     comm.create_process_groups(master_node_addr, device);
   }
 #endif

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 0 additions & 2 deletions
@@ -18,8 +18,6 @@ limitations under the License.
 #include "mapping_npu.h"
 
 #if defined(USE_NPU)
-#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>
-
 #include "npu_process_group.h"
 #include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h"
 #include "xllm_kernels/core/include/atb_speed/utils/singleton.h"

xllm/core/framework/parallel_state/npu_process_group.h

Lines changed: 0 additions & 4 deletions
@@ -18,10 +18,6 @@ limitations under the License.
 #include "hccl/hccl.h"
 #include "process_group.h"
 
-namespace c10d_npu {
-class ProcessGroupHCCL;
-}
-
 namespace xllm {
 
 class ProcessGroupHCCL : public ProcessGroup {

xllm/core/framework/parallel_state/process_group.h

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,9 @@ class ProcessGroup {
 
  protected:
 #if defined(USE_NPU)
+  // Using ProcessGroupHCCL for NPU devices
+  // Note: torch_npu uses an older torch version where c10d::Backend lacks
+  // shutdown() method
   std::unique_ptr<c10d_npu::ProcessGroupHCCL> pg_{nullptr};
 #else
   std::unique_ptr<c10d::Backend> pg_{nullptr};
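The new comment records why the member type diverges: torch_npu pins an older libtorch whose c10d::Backend has no shutdown(), so the NPU branch must hold the concrete HCCL type. A hedged sketch, not from this commit, of how teardown code could stay portable across both branches (teardown is a hypothetical helper):

#if defined(USE_NPU)
// torch_npu's ProcessGroupHCCL predates c10d::Backend::shutdown(); destroying
// the object is the only portable teardown on this branch.
void teardown(std::unique_ptr<c10d_npu::ProcessGroupHCCL>& pg) {
  pg.reset();
}
#else
// Newer torch: c10d::Backend exposes shutdown() for a graceful stop first.
void teardown(std::unique_ptr<c10d::Backend>& pg) {
  if (pg) {
    pg->shutdown();
  }
  pg.reset();
}
#endif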

xllm/core/kernels/npu/active.cpp

Lines changed: 5 additions & 1 deletion
@@ -20,7 +20,11 @@ limitations under the License.
 
 namespace xllm::kernel::npu {
 
-torch::Tensor active(const torch::Tensor& input) {
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode) {
+  if (act_mode != "silu" && act_mode != "swiglu") {
+    throw std::runtime_error(
+        "Only silu and swiglu activations are supported in NPU active");
+  }
   return at_npu::native::custom_ops::npu_swiglu(input);
 }
 }  // namespace xllm::kernel::npu
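Call sites now pass the activation mode explicitly, and both accepted modes dispatch to the same npu_swiglu kernel. A minimal sketch of an updated caller (the helper, tensor name, and include path are illustrative, not from the commit):

#include "kernels/npu/npu_ops_api.h"  // assumed include path

// Hypothetical MLP helper: gate_up is the fused gate/up projection that
// npu_swiglu expects; "silu" and "swiglu" are the only accepted modes.
torch::Tensor mlp_activation(const torch::Tensor& gate_up) {
  return xllm::kernel::npu::active(gate_up, "swiglu");
}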

xllm/core/kernels/npu/attention.cpp

Lines changed: 9 additions & 6 deletions
@@ -31,31 +31,34 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& mask,
                    const torch::Tensor& seq_len,
                    float scale,
-                   int num_heads,
-                   int num_kv_heads,
                    torch::Tensor& output) {
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = key.size(-2);
   atb::_npu_flash_attention(
       query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
 }
 
 void batch_decode(const torch::Tensor& query,
                   const torch::Tensor& k_cache,
                   const torch::Tensor& v_cache,
-                  int num_kv_heads,
-                  int num_heads,
                   float scale,
                   const torch::Tensor& block_table,
                   const torch::Tensor& seq_lens,
                   torch::Tensor& output) {
-  atb::_npu_paged_attention(query,
+  auto head_size = query.size(-1);
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = k_cache.size(-2);
+  auto q = query.view({-1, num_heads, head_size});
+  auto o = output.view({-1, num_heads, head_size});
+  atb::_npu_paged_attention(q,
                             k_cache,
                             v_cache,
                             num_kv_heads,
                             num_heads,
                             scale,
                             block_table,
                             seq_lens,
-                            output);
+                            o);
 }
 
 }  // namespace xllm::kernel::npu
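Head counts are no longer parameters: batch_prefill reads them from query.size(-2) and key.size(-2), and batch_decode additionally flattens query and output to [tokens, num_heads, head_size] views before calling the paged-attention kernel. A hedged sketch of an updated decode call site (the wrapper and its assumed shapes are illustrative, not from the commit):

#include <cmath>

// Hypothetical wrapper: query is [num_tokens, num_heads, head_size] and
// k_cache's second-to-last dimension is num_kv_heads, as the kernel assumes.
void decode_step(const torch::Tensor& query,
                 const torch::Tensor& k_cache,
                 const torch::Tensor& v_cache,
                 const torch::Tensor& block_table,
                 const torch::Tensor& seq_lens,
                 torch::Tensor& output) {
  // Conventional attention scaling; the commit does not fix a value.
  const float scale = 1.0f / std::sqrt(static_cast<float>(query.size(-1)));
  // num_heads and num_kv_heads are now derived inside batch_decode.
  xllm::kernel::npu::batch_decode(
      query, k_cache, v_cache, scale, block_table, seq_lens, output);
}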

xllm/core/kernels/npu/fused_layernorm.cpp

Lines changed: 6 additions & 1 deletion
@@ -21,7 +21,12 @@ namespace xllm::kernel::npu {
 
 torch::Tensor fused_layernorm(const torch::Tensor& input,
                               const torch::Tensor& weight,
-                              double eps) {
+                              double eps,
+                              const std::string& mode) {
+  if (mode != "rmsnorm") {
+    throw std::runtime_error(
+        "Only rmsnorm mode is supported in NPU fused_layernorm");
+  }
   std::tuple<at::Tensor, at::Tensor> result =
       at_npu::native::custom_ops::npu_rms_norm(input, weight, eps);
   auto normalized_input = std::get<0>(result);
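The new mode argument is validated but only one value is accepted; anything other than "rmsnorm" throws. A minimal usage sketch (the helper name and epsilon are illustrative, not from the commit):

// Hypothetical helper: "rmsnorm" is the only mode the NPU kernel accepts.
torch::Tensor rms_norm(const torch::Tensor& hidden,
                       const torch::Tensor& weight) {
  constexpr double kEps = 1e-6;  // illustrative epsilon, not from the commit
  return xllm::kernel::npu::fused_layernorm(hidden, weight, kEps, "rmsnorm");
}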

xllm/core/kernels/npu/npu_ops_api.h

Lines changed: 3 additions & 6 deletions
@@ -34,15 +34,11 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& mask,
                    const torch::Tensor& seq_len,
                    float scale,
-                   int num_heads,
-                   int num_kv_heads,
                    torch::Tensor& output);
 
 void batch_decode(const torch::Tensor& query,
                   const torch::Tensor& k_cache,
                   const torch::Tensor& v_cache,
-                  int num_kv_heads,
-                  int num_heads,
                   float scale,
                   const torch::Tensor& block_table,
                   const torch::Tensor& seq_lens,
@@ -52,11 +48,12 @@ torch::Tensor matmul(const torch::Tensor& a,
                      const torch::Tensor& b,
                      const std::optional<torch::Tensor>& bias);
 
-torch::Tensor active(const torch::Tensor& input);
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode);
 
 torch::Tensor fused_layernorm(const torch::Tensor& input,
                               const torch::Tensor& weight,
-                              double eps);
+                              double eps,
+                              const std::string& mode);
 
 void apply_rotary(torch::Tensor& q,
                   torch::Tensor& k,
