Skip to content

Commit 3b5afff

Browse files
authored
Merge pull request #842 from gongchensu/Issue/791
Issue/791 增加add_rms_norm融合算子
2 parents 2d9d5c3 + 7712471 commit 3b5afff

File tree

21 files changed

+1403
-0
lines changed

21 files changed

+1403
-0
lines changed

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "ops/add.hpp"
4+
#include "ops/add_rms_norm.hpp"
45
#include "ops/attention.hpp"
56
#include "ops/causal_softmax.hpp"
67
#include "ops/matmul.hpp"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
#include <utility>
6+
7+
namespace infinicore::op {
8+
class AddRMSNorm {
9+
public:
10+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
11+
static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
12+
static common::OpDispatcher<schema> &dispatcher();
13+
};
14+
15+
// Fused Add and RMS Normalization
16+
// Returns: (normalized_result, add_result)
17+
// The add_result can be used as residual for subsequent layers
18+
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
19+
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
20+
} // namespace infinicore::op

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "infiniop/handle.h"
55
#include "infiniop/ops/add.h"
6+
#include "infiniop/ops/add_rms_norm.h"
67
#include "infiniop/ops/attention.h"
78
#include "infiniop/ops/causal_softmax.h"
89
#include "infiniop/ops/clip.h"
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef __INFINIOP_ADD_RMS_NORM_API_H__
2+
#define __INFINIOP_ADD_RMS_NORM_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopAddRMSNormDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t y_desc,
12+
infiniopTensorDescriptor_t a_desc,
13+
infiniopTensorDescriptor_t b_desc,
14+
infiniopTensorDescriptor_t weight_desc,
15+
float epsilon,
16+
infiniopTensorDescriptor_t residual_out_desc);
17+
18+
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
19+
20+
__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
21+
void *workspace,
22+
size_t workspace_size,
23+
void *y,
24+
const void *a,
25+
const void *b,
26+
const void *weight,
27+
void *residual_out,
28+
void *stream);
29+
30+
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
31+
32+
#endif

python/infinicore/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
uint8,
4141
)
4242
from infinicore.ops.add import add
43+
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
4344
from infinicore.ops.attention import attention
4445
from infinicore.ops.matmul import matmul
4546
from infinicore.ops.mul import mul
@@ -105,6 +106,8 @@
105106
"uint8",
106107
# Operations.
107108
"add",
109+
"add_rms_norm",
110+
"add_rms_norm_",
108111
"attention",
109112
"matmul",
110113
"mul",
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
6+
"""
7+
Fused Add and RMS Normalization.
8+
9+
Args:
10+
a: First input tensor
11+
b: Second input tensor
12+
weight: Scale weights
13+
epsilon: Small constant for numerical stability, default is 1e-5
14+
out: Optional output tuple (y, residual_out) for in-place operation
15+
16+
Returns:
17+
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
18+
The add_result can be used as residual for subsequent layers.
19+
"""
20+
if out is None:
21+
result = _infinicore.add_rms_norm(
22+
a._underlying, b._underlying, weight._underlying, epsilon
23+
)
24+
return (Tensor(result[0]), Tensor(result[1]))
25+
26+
y, residual_out = out
27+
_infinicore.add_rms_norm_(
28+
y._underlying,
29+
residual_out._underlying,
30+
a._underlying,
31+
b._underlying,
32+
weight._underlying,
33+
epsilon,
34+
)
35+
return (y, residual_out)
36+
37+
38+
def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
39+
"""In-place Fused Add and RMS Normalization."""
40+
_infinicore.add_rms_norm_(
41+
y._underlying,
42+
residual_out._underlying,
43+
a._underlying,
44+
b._underlying,
45+
weight._underlying,
46+
epsilon,
47+
)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include "infinicore/ops/add_rms_norm.hpp"
2+
3+
#include "../../utils.hpp"
4+
5+
namespace infinicore::op {
6+
7+
common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
8+
static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
9+
return dispatcher_;
10+
};
11+
12+
void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
13+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
14+
infinicore::context::setDevice(y->device());
15+
dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
16+
}
17+
18+
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
19+
auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
20+
auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
21+
add_rms_norm_(y, residual_out, a, b, weight, epsilon);
22+
return std::make_pair(y, residual_out);
23+
}
24+
25+
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
26+
AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
27+
}
28+
29+
} // namespace infinicore::op
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/add_rms_norm.hpp"
4+
#include "infinicore/ops/common/cache.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::add_rms_norm_impl::infiniop {
8+
9+
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
10+
100, // capacity
11+
[](infiniopAddRMSNormDescriptor_t &desc) {
12+
if (desc != nullptr) {
13+
INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
14+
desc = nullptr;
15+
}
16+
});
17+
18+
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
19+
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
20+
21+
auto device = context::getDevice();
22+
auto &cache = caches.getCache(device);
23+
24+
auto desc_opt = cache.get(seed);
25+
infiniopAddRMSNormDescriptor_t desc = nullptr;
26+
27+
if (!desc_opt) {
28+
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
29+
context::getInfiniopHandle(device), &desc,
30+
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
31+
cache.put(seed, desc);
32+
} else {
33+
desc = *desc_opt;
34+
}
35+
36+
size_t workspace_size = 0;
37+
INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
38+
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
39+
40+
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
41+
desc, workspace->data(), workspace_size,
42+
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
43+
}
44+
45+
static bool registered = []() {
46+
AddRMSNorm::dispatcher().registerAll(&calculate, false);
47+
return true;
48+
}();
49+
50+
} // namespace infinicore::op::add_rms_norm_impl::infiniop

src/infinicore/pybind11/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <pybind11/pybind11.h>
44

55
#include "ops/add.hpp"
6+
#include "ops/add_rms_norm.hpp"
67
#include "ops/attention.hpp"
78
#include "ops/causal_softmax.hpp"
89
#include "ops/embedding.hpp"
@@ -24,6 +25,7 @@ namespace infinicore::ops {
2425

2526
inline void bind(py::module &m) {
2627
bind_add(m);
28+
bind_add_rms_norm(m);
2729
bind_attention(m);
2830
bind_causal_softmax(m);
2931
bind_random_sample(m);
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#pragma once
2+
3+
#include <pybind11/pybind11.h>
4+
5+
#include "infinicore/ops/add_rms_norm.hpp"
6+
7+
namespace py = pybind11;
8+
9+
namespace infinicore::ops {
10+
11+
inline void bind_add_rms_norm(py::module &m) {
12+
m.def("add_rms_norm",
13+
&op::add_rms_norm,
14+
py::arg("a"),
15+
py::arg("b"),
16+
py::arg("weight"),
17+
py::arg("epsilon") = 1e-5f,
18+
R"doc(Fused Add and RMS Normalization.
19+
20+
Args:
21+
a: First input tensor
22+
b: Second input tensor
23+
weight: Scale weights
24+
epsilon: Small constant for numerical stability, default is 1e-5
25+
26+
Returns:
27+
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
28+
The add_result can be used as residual for subsequent layers.
29+
)doc");
30+
31+
m.def("add_rms_norm_",
32+
&op::add_rms_norm_,
33+
py::arg("y"),
34+
py::arg("residual_out"),
35+
py::arg("a"),
36+
py::arg("b"),
37+
py::arg("weight"),
38+
py::arg("epsilon") = 1e-5f,
39+
R"doc(In-place Fused Add and RMS Normalization.
40+
41+
Args:
42+
y: Output tensor for normalized result
43+
residual_out: Output tensor for add result (a + b) before normalization
44+
a: First input tensor
45+
b: Second input tensor
46+
weight: Scale weights
47+
epsilon: Small constant for numerical stability, default is 1e-5
48+
)doc");
49+
}
50+
51+
} // namespace infinicore::ops

0 commit comments

Comments
 (0)