Skip to content

Commit d14fe8a

Browse files
authored
fuj/rms-norm (DeepLink-org#861)
* impl rms_norm_backward function, fix rms_norm forward function, fix rms_norm case config * fix rms norm backward * rms_norm config for ascend
1 parent bc77f24 commit d14fe8a

File tree

5 files changed

+53
-10
lines changed

5 files changed

+53
-10
lines changed

diopi_test/python/configs/diopi_configs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8189,10 +8189,12 @@
81898189
args=[
81908190
{
81918191
"ins": ['input'],
8192+
"requires_grad": [True],
81928193
"shape": ((5, 5), (35, 125, 32), (16, 64, 64), (1, 32, 32, 8)),
81938194
},
81948195
{
81958196
"ins": ['weight'],
8197+
"requires_grad": [True],
81968198
"shape": ((5, ), (32, ), (64, ), (8, )),
81978199
},
81988200
{
@@ -8201,6 +8203,9 @@
82018203
},
82028204
],
82038205
),
8206+
# saved_args=dict(grad_outputs=0, inv_rms=1),
8207+
saved_args=dict(inv_rms=1),
8208+
requires_backward=[0],
82048209
),
82058210

82068211
# 'multihead_attention_forward': dict(

diopi_test/python/conformance/diopi_functions.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5151,7 +5151,9 @@ def rms_norm(input, normalized_shape, weight, bias, eps):
51515151
func = check_function(call)
51525152
size = list(input.size().data)
51535153
out = Tensor(size, input.get_dtype())
5154-
inv_rms = Tensor(size, input.get_dtype())
5154+
inv_rms_size = size.copy()
5155+
inv_rms_size[-1] = 1
5156+
inv_rms = Tensor(inv_rms_size, input.get_dtype())
51555157
normalized_shape = Sizes(list(normalized_shape))
51565158
ret = func(
51575159
input.context(),
@@ -5164,7 +5166,21 @@ def rms_norm(input, normalized_shape, weight, bias, eps):
51645166
eps,
51655167
)
51665168
check_returncode(ret)
5167-
return out
5169+
return (out, inv_rms)
5170+
5171+
5172+
def rms_norm_backward(grad_outputs, input, weight, bias, inv_rms, normalized_shape, eps):
5173+
call = "diopiRMSNormBackward"
5174+
func = check_function(call)
5175+
grad_input = Tensor(list(input.size().data), input.get_dtype())
5176+
grad_weight = Tensor(list(weight.size().data), weight.get_dtype())
5177+
grad_bias = Tensor(list(bias.size().data), bias.get_dtype())
5178+
normalized_shape = Sizes(list(normalized_shape))
5179+
5180+
ret = func(input.context(), grad_input, grad_weight, grad_bias, grad_outputs[0], input, weight, bias, inv_rms,
5181+
normalized_shape, eps)
5182+
check_returncode(ret)
5183+
return {'input': grad_input, 'weight': grad_weight}
51685184

51695185

51705186
def multihead_attention_forward(

diopi_test/python/conformance/gen_output.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,12 @@ def rotary_emb(input, cos, sin, conj):
226226
return out
227227

228228
def rms_norm(input, normalized_shape, weight, bias, eps):
229-
variance = input.to(torch.float32).pow(2).mean(-1, keepdim=True)
230-
input = input * torch.rsqrt(variance + eps)
231-
out = weight * input
232-
return out
229+
var = input.to(torch.float32).pow(2).mean(-1, keepdim=True)
230+
inv_rms = torch.rsqrt(var + eps)
231+
inp = input * inv_rms
232+
out = weight * inp
233+
234+
return (out, inv_rms)
233235

234236
def multihead_attention_forward(q, k, v, dropout_p, is_causal, return_debug_mask, scale):
235237
# 为了保证精度,因此在test的时候不使用dropout
@@ -298,8 +300,11 @@ class GenOutputData(object):
298300
db_case_items = {}
299301

300302
@staticmethod
301-
def run(diopi_item_config_path='diopi_case_items.cfg', input_path='data/inputs/',
302-
output_path='data/outputs/', fname='all_ops', model_name='diopi'):
303+
def run(diopi_item_config_path='diopi_case_items.cfg',
304+
input_path='data/inputs/',
305+
output_path='data/outputs/',
306+
fname='all_ops',
307+
model_name='diopi'):
303308
if not os.path.exists(input_path):
304309
logger.error("Input data is not generated!")
305310
sys.exit(0)
@@ -332,9 +337,11 @@ def run(diopi_item_config_path='diopi_case_items.cfg', input_path='data/inputs/'
332337
output, saved_grads = gen_tensor_obj.gen_data(input_)
333338
item['result'] = 'passed'
334339
except Exception as err_msg:
335-
raise GenDataFailedException(f'Generate output data for diopi_functions.{func_name} [{case_name}] failed, cause by \n{err_msg}')
340+
raise GenDataFailedException(
341+
f'Generate output data for diopi_functions.{func_name} [{case_name}] failed, cause by \n{err_msg}')
336342
GenOutputData.db_case_items[case_name] = item
337343
if output is not None:
344+
# import pdb; pdb.set_trace()
338345
with open(os.path.join(output_path, case_name), "wb") as f:
339346
pickle.dump(GenOutputData.to_numpy(output), f, protocol=4)
340347
logger_str = "output"

impl/ascend/device_configs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,14 @@
780780
),
781781
),
782782

783+
'rms_norm': dict(
784+
name=['rms_norm'],
785+
atol=1e-3,
786+
rtol=1e-3,
787+
atol_half=1e-2,
788+
rtol_half=1e-2,
789+
),
790+
783791
'smooth_l1_loss': dict(
784792
name=['smooth_l1_loss'],
785793
tensor_para=dict(

impl/ascend/functions_ext/rms_norm.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@ diopiError_t diopiRMSNormBackward(diopiContextHandle_t ctx, diopiTensorHandle_t
2222
diopiConstTensorHandle_t bias, diopiConstTensorHandle_t invRms, diopiSize_t normalizedShape, double eps) {
2323
AscendTensor inputTensor(input);
2424
ASCEND_CHECK_ABORT(1 == normalizedShape.len && normalizedShape.data[0] == inputTensor.shape()[inputTensor.dim() - 1], "normalized shape error!");
25-
AclOpRunner<4, 2>("RmsNorm", ctx).addInput(gradOutput).addInput(input).addInput(invRms).addInput(weight).addOutput(gradInput).addOutput(gradWeight).run();
25+
AclOpRunner<4, 2>("RmsNormGrad", ctx)
26+
.addInput(gradOutput)
27+
.addInput(input)
28+
.addInput(invRms)
29+
.addInput(weight)
30+
.addOutput(gradInput)
31+
.addOutput(gradWeight)
32+
.run();
2633
return diopiSuccess;
2734
}
2835

0 commit comments

Comments
 (0)