From d2333fce47e1e1134cd8a3397e57ac0cef17422f Mon Sep 17 00:00:00 2001
From: Pan Daoxin
Date: Wed, 24 Apr 2024 12:11:31 +0000
Subject: [PATCH 1/2] Enable dynamic shape for the 70B get_qkv stage.
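
The old generate_sym_int helper only understood "s0", "s0 + 2" and
"s0 - 2" style expressions. The get_qkv stage of the 70B model
produces SymInt shape expressions such as "(s5*s7)//16", so
process_dynamic_shape now tokenizes the whole expression and folds
it one operator at a time, highest priority first, using a small
disjoint-set over operand indices to track what has already been
merged, emitting Ascend Const/Add/Sub/Mul/Div ops.

Rough sketch of what the conversion emits (proxy names are
illustrative only, assuming s5 comes from dim 0 of input arg0):

    # shape = [1, (s5 // 8) - (s5 // 16)]
    const_1 = Const([1])                  # plain-number block
    s5      = GatherV2(Shape(arg0), 0)    # replace_elem_proxy
    lhs     = Div(s5, Const([8]))         # '//' folds first (priority 4)
    rhs     = Div(s5, Const([16]))
    diff    = Sub(lhs, rhs)               # '-' folds last (priority 0)
    result  = ConcatD([const_1, diff], 0)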
---
 dicp/dicp/dynamo_bridge/utils.py              |   7 +
 .../dicp/vendor/AscendGraph/codegen/ascend.py |  15 ++
 dicp/dicp/vendor/AscendGraph/conversion.py    | 179 ++++++++++++++----
 3 files changed, 167 insertions(+), 34 deletions(-)

diff --git a/dicp/dicp/dynamo_bridge/utils.py b/dicp/dicp/dynamo_bridge/utils.py
index 050102ad4..46b6857d8 100644
--- a/dicp/dicp/dynamo_bridge/utils.py
+++ b/dicp/dicp/dynamo_bridge/utils.py
@@ -14,6 +14,13 @@ def symint_in_shape(shape):
     return False
 
 
+def not_all_num_shape(shape):
+    for elem in shape:
+        if not isinstance(elem, int):
+            return True
+    return False
+
+
 def save_cpu_gm(gm: torch.fx.GraphModule, folder: str):
     Path(folder).mkdir(exist_ok=True)
     cpu_gm = copy_gm_to_cpu(gm)
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index 9b9fc24f4..b5b85bbd7 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -264,6 +264,21 @@ def process_sym_name(self, st):
                 return self.sym_to_inputs[sp[0]] + '*' + sp[1]
             else:
                 return self.process_sym_name(sp[0]) + '*' + sp[1]
+        elif '//' in st:
+            sp = st.strip('()').split('//')
+            if len(sp) > 2:
+                sp = [sp[0], '//'.join(sp[1:])]
+            assert (len(sp) == 2)
+            sp = [elem.strip() for elem in sp]
+            if sp[0].isdigit():
+                (sp[1], sp[0]) = (sp[0], sp[1])
+            if sp[0] in self.sym_in_args:
+                arg, idx = self.sym_in_args[sp[0]]
+                return "{}.shape[{}]".format(arg, idx) + '//' + sp[1]
+            if sp[0] in self.sym_to_inputs.keys():
+                return self.sym_to_inputs[sp[0]] + '//' + sp[1]
+            else:
+                return self.process_sym_name(sp[0]) + '//' + sp[1]
         else:
             if st in self.sym_in_args:
                 arg, idx = self.sym_in_args[st]
diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index 11e72be2e..adfce1c8c 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -14,7 +14,7 @@
 from torch.fx.immutable_collections import immutable_list
 from torch._subclasses import FakeTensor
 import dicp.vendor.AscendGraph.ascend_op as ascend_op
-from dicp.dynamo_bridge.utils import symint_in_shape
+from dicp.dynamo_bridge.utils import symint_in_shape, not_all_num_shape
 from dicp.vendor.AscendGraph.codegen.utils import (
     get_ascend_dtype,
     get_cpp_dtype
@@ -79,14 +79,38 @@ def generate_digits_op(shapes):
                 ascend_op.Const, (shapes, torch.int32, [len(shapes)]))
             x_names.append(const_op)
 
-        def generate_sym_int(elem):
-            elem = elem.node.str()
-            elems = elem.strip().split(' ')
+        def find_root_num(set_num, num):
+            while set_num[num] != num:
+                num = set_num[num]
+            return num
+
+        def merge_disjoint_set(set_num, idx_a, idx_b):
+            root_a = find_root_num(set_num, idx_a)
+            root_b = find_root_num(set_num, idx_b)
+            # example for (s5 // 8) - (s5 // 16)
+            # operand:            0 1 2 3
+            # step 1 -> set_num:  0 1 2 3
+            # step 2 -> set_num:  0 0 2 2
+            # step 3 -> set_num:  0 0 0 0
+
+            # return the merged set, relabeling root_b as root_a
+            return [root_a if find_root_num(set_num, s) == root_b else s for s in set_num]
+
+        def replace_elem_proxy(elem_str):
+            # return right away if it is already a proxy
+            if isinstance(elem_str, torch.fx.proxy.Proxy):
+                return elem_str
+            assert elem_str not in ['+', '-', '*', '//', '(', ')']
+
+            # handle integer literals
+            if elem_str.isdigit():
+                const_op = self.get_proxy(
+                    ascend_op.Const, ([int(elem_str)], torch.int32, [1]))
+                return const_op
 
-            arg = None
-            # dynamic shape feature
-            if elems[0] in self.sym_in_args:
-                arg, idx = self.sym_in_args[elems[0]]
+            # handle symbols taken from the shape of an input arg
+            if elem_str in self.sym_in_args:
+                arg, idx = self.sym_in_args[elem_str]
                 shape = self.get_proxy(ascend_op.Shape, (arg,))
                 axis = self.get_proxy(
                     ascend_op.Const, ([0], torch.int32, [1]))
@@ -94,50 +118,131 @@ def generate_sym_int(elem):
                     ascend_op.Const, ([idx], torch.int32, [1]))
                 gather = self.get_proxy(
                     ascend_op.GatherV2, (shape, indice, axis))
+                return gather
 
+            # otherwise the symbol maps directly to a SymInt input
+            return self.sym_to_inputs[elem_str]
+
+        def generate_not_num(elem):
+            # a proxy node can be appended as-is
+            if isinstance(elem, torch.fx.proxy.Proxy):
+                x_names.append(elem)
+                return
+
+            elem_str = elem.node.str()
+            elem_str = elem_str.replace('+', ' + ')
+            elem_str = elem_str.replace('-', ' - ')
+            elem_str = elem_str.replace('*', ' * ')
+            elem_str = elem_str.replace('//', ' // ')
+            elem_str = elem_str.replace('(', ' ( ')
+            elem_str = elem_str.replace(')', ' ) ')
+            elems = elem_str.split(' ')
+            elems = [e for e in elems if e != '']
+
+            # dynamic shape feature
             if len(elems) > 1:
-                assert len(elems) == 3
-                assert elems[2].isdigit()
-                assert elems[1] == '+' or elems[1] == '-'
-                const_op = self.get_proxy(
-                    ascend_op.Const, ([int(elems[2])], torch.int32, [1]))
-                if arg is not None:
-                    args = (gather, const_op)
-                else:
-                    args = (self.sym_to_inputs[elems[0]], const_op)
-                if elems[1] == '+':
-                    x_names.append(self.get_proxy(ascend_op.Add, args))
-                else:
-                    x_names.append(self.get_proxy(ascend_op.Sub, args))
+                set_num = []
+                priority = []
+                nest = 0
+
+                # calculate a priority for each operator and
+                # assign the initial set numbers
+                for idx, e in enumerate(elems):
+                    if e == '+' or e == '-':
+                        priority.append(nest * 3 + 0)
+                    elif e == '*' or e == '//':
+                        priority.append(nest * 3 + 1)
+                    else:
+                        if e == '(':
+                            nest += 1
+                        elif e == ')':
+                            nest -= 1
+                        priority.append(-1)
+
+                    # init set number
+                    if e not in ['+', '-', '*', '//', '(', ')']:
+                        set_num.append(idx)
+                    else:
+                        set_num.append(-1)
+
+                # start merging the disjoint sets
+                if len(set_num) > 1:
+                    while len(set(set_num)) > 2:
+                        # seek the highest-priority operator
+                        max_prio = -1
+                        m_idx = -1
+                        for idx, prio in enumerate(priority):
+                            if prio > max_prio:
+                                max_prio = prio
+                                m_idx = idx
+
+                        # fold the two operands of that operator:
+                        # find the left & right operand
+                        left_idx = m_idx - 1
+                        while left_idx > 0 and str(elems[left_idx]) in ['(', ')']:
+                            left_idx -= 1
+                        right_idx = m_idx + 1
+                        while right_idx < len(elems) - 1 and str(elems[right_idx]) in ['(', ')']:
+                            right_idx += 1
+                        left_idx = find_root_num(set_num, set_num[left_idx])
+                        right_idx = find_root_num(set_num, set_num[right_idx])
+                        left_elem = replace_elem_proxy(elems[left_idx])
+                        right_elem = replace_elem_proxy(elems[right_idx])
+
+                        # generate the arithmetic op
+                        if elems[m_idx] == '+':
+                            elems[left_idx] = self.get_proxy(ascend_op.Add, (left_elem, right_elem))
+                        elif elems[m_idx] == '-':
+                            elems[left_idx] = self.get_proxy(ascend_op.Sub, (left_elem, right_elem))
+                        elif elems[m_idx] == '*':
+                            elems[left_idx] = self.get_proxy(ascend_op.Mul, (left_elem, right_elem))
+                        else:
+                            elems[left_idx] = self.get_proxy(ascend_op.Div, (left_elem, right_elem))
+
+                        # merge the set numbers and retire this operator
+                        set_num = merge_disjoint_set(set_num, left_idx, right_idx)
+                        priority[m_idx] = -1
+
+                # append the final element proxy
+                final_idx = 0
+                while final_idx < len(elems) - 1 and str(elems[final_idx]) in ['(', ')']:
+                    final_idx += 1
+                final_elem = replace_elem_proxy(elems[final_idx])
+                x_names.append(final_elem)
             else:
-                if arg is not None:
-                    x_names.append(gather)
-                else:
-                    x_names.append(self.sym_to_inputs[elems[0]])
+                # only a single non-numeric element
+                node = replace_elem_proxy(elems[0])
+                x_names.append(node)
 
         dims = []
         for elem in shape:
-            if not isinstance(elem, torch.SymInt):
+            # pass plain numbers straight through
+            if isinstance(elem, int):
                 dims.append(elem)
                 continue
-            st = elem.node.str()
+            st = str(elem)
             if st.isdigit():
                 dims.append(int(st))
                 continue
+
+            # flush the pending number block
             if len(dims) > 0:
                 generate_digits_op(dims)
                 dims = []
-            generate_sym_int(elem)
+            generate_not_num(elem)
+
+        # flush the last number block
         if len(dims) > 0:
             generate_digits_op(dims)
+
+        # concat all the pieces
         return self.get_proxy(ascend_op.ConcatD, (x_names, 0))
 
     def get_shape_proxy(self, shape):
         if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor):
             return shape
-        elif isinstance(shape, list) and symint_in_shape(shape):
+        elif isinstance(shape, list) and not_all_num_shape(shape):
+            # the shape may hold both SymInt and proxy elements
            return self.process_dynamic_shape(shape)
         else:
             return self.get_proxy(
@@ -307,12 +412,16 @@ def inge(self, x, y):
         y = self.get_const_proxy(y, torch.int32)
         return self.get_proxy(ascend_op.GreaterEqual, (x, y))
 
-    @register_conversion(aten.div)
+    @register_conversion([aten.div, _operator.floordiv])
     def div(self, x, y):
         if isinstance(y, torch.fx.proxy.Proxy):
             return self.get_proxy(ascend_op.DivNoNan, (x, y))
         assert y != 0
-        out_dtype = fx_traceback.get_current_meta()['val'].dtype
+        out = fx_traceback.get_current_meta()['val']
+        if not isinstance(out, torch.SymInt):
+            out_dtype = out.dtype
+        else:
+            out_dtype = torch.int32
         y_op = self.get_const_proxy(y, out_dtype)
         return self.get_proxy(ascend_op.Div, (x, y_op), {})
 
@@ -332,10 +441,12 @@ def slice(self, x, dim=0, start=None, end=None, step=1):
         x_shape = list(x.node.meta['val'].shape)
         y_shape = list(fx_traceback.get_current_meta()['val'].shape)
         dim = int(dim)
-        start = int(start) if start is not None else 0
-        start = start if start >= 0 else x_shape[dim] + start
+        if not isinstance(start, torch.fx.proxy.Proxy):
+            start = int(start) if start is not None else 0
+            start = start if start >= 0 else x_shape[dim] + start
+            assert start is None or start >= 0 and start < x_shape[dim]
+        assert dim == -1 or dim >= 0 and dim < len(x_shape)
 
-        assert start is None or start >= 0 and start < x_shape[dim]
         offset = [0] * len(x_shape)
         offset[dim] = start
         offset = self.get_shape_proxy(offset)

From 3dec4354d60573bba049693dc0eb55f09042d4d9 Mon Sep 17 00:00:00 2001
From: Pan Daoxin
Date: Thu, 25 Apr 2024 07:25:02 +0000
Subject: [PATCH 2/2] Change load_and_run input/output shape assignment.
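
process_sym_name used to rewrite each symbolic dim string
("s0 + 1", "s0*s1", ...) into an expression over the input tensors,
one duplicated branch per operator. It now simply returns the string
form of the dim, and gen_call_func instead emits one binding per
plain symbol so that Python itself evaluates the composed
expressions at call time.

Sketch of the generated call body (names are illustrative only,
assuming sym_in_args maps s0 to dim 0 of arg0 and s1 is a direct
SymInt input):

    (arg0, _, s1) = args
    s0 = arg0.shape[0]
    ...
    out_shape = [[s0, (s0 + 1)//2, s1]]

Digit keys and composite keys containing an operator are skipped via
operator_in_str, since only plain symbol names need a binding.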
---
 .../dicp/vendor/AscendGraph/codegen/ascend.py | 95 +++++--------------
 1 file changed, 26 insertions(+), 69 deletions(-)

diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index b5b85bbd7..46ea900e0 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -217,73 +217,19 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2):
         )
         return self.import_code.getvalue()
 
+    def operator_in_str(self, st):
+        for op in ['+', '-', '*', '/']:
+            if op in st:
+                return True
+        return False
+
     def process_sym_name(self, st):
         # dynamic shape feature
-        if st.isdigit():
-            return st
-        elif '+' in st:
-            sp = st.split('+')
-            if len(sp) > 2:
-                sp = [sp[0], '+'.join(sp[1:])]
-            assert (len(sp) == 2)
-            sp = [elem.strip() for elem in sp]
-            if sp[0].isdigit():
-                (sp[1], sp[0]) = (sp[0], sp[1])
-            if sp[0] in self.sym_in_args:
-                arg, idx = self.sym_in_args[sp[0]]
-                return "{}.shape[{}]".format(arg, idx) + '+' + sp[1]
-            if sp[0] in self.sym_to_inputs.keys():
-                return self.sym_to_inputs[sp[0]] + '+' + sp[1]
-            else:
-                return self.process_sym_name(sp[0]) + '+' + sp[1]
-        elif '-' in st:
-            sp = st.split('-')
-            if len(sp) > 2:
-                sp = [sp[0], '-'.join(sp[1:])]
-            assert (len(sp) == 2)
-            sp = [elem.strip() for elem in sp]
-            if sp[0] in self.sym_in_args:
-                arg, idx = self.sym_in_args[sp[0]]
-                return "{}.shape[{}]".format(arg, idx) + '-' + sp[1]
-            if sp[0] in self.sym_to_inputs.keys():
-                return self.sym_to_inputs[sp[0]] + '-' + sp[1]
-            else:
-                return self.process_sym_name(sp[0]) + '-' + sp[1]
-        elif '*' in st:
-            sp = st.split('*')
-            if len(sp) > 2:
-                sp = [sp[0], '*'.join(sp[1:])]
-            assert (len(sp) == 2)
-            sp = [elem.strip() for elem in sp]
-            if sp[0].isdigit():
-                (sp[1], sp[0]) = (sp[0], sp[1])
-            if sp[0] in self.sym_in_args:
-                arg, idx = self.sym_in_args[sp[0]]
-                return "{}.shape[{}]".format(arg, idx) + '*' + sp[1]
-            if sp[0] in self.sym_to_inputs.keys():
-                return self.sym_to_inputs[sp[0]] + '*' + sp[1]
-            else:
-                return self.process_sym_name(sp[0]) + '*' + sp[1]
-        elif '//' in st:
-            sp = st.strip('()').split('//')
-            if len(sp) > 2:
-                sp = [sp[0], '//'.join(sp[1:])]
-            assert (len(sp) == 2)
-            sp = [elem.strip() for elem in sp]
-            if sp[0].isdigit():
-                (sp[1], sp[0]) = (sp[0], sp[1])
-            if sp[0] in self.sym_in_args:
-                arg, idx = self.sym_in_args[sp[0]]
-                return "{}.shape[{}]".format(arg, idx) + '//' + sp[1]
-            if sp[0] in self.sym_to_inputs.keys():
-                return self.sym_to_inputs[sp[0]] + '//' + sp[1]
-            else:
-                return self.process_sym_name(sp[0]) + '//' + sp[1]
-        else:
-            if st in self.sym_in_args:
-                arg, idx = self.sym_in_args[st]
-                return "{}.shape[{}]".format(arg, idx)
-            return self.sym_to_inputs[st]
+        # only return the string form now; node.str() keeps the symbolic
+        # form, which the bindings emitted in gen_call_func resolve later
+        if isinstance(st, torch.SymInt):
+            return st.node.str()
+        return str(st)
 
     def gen_call_func(self):
         # TODO check scalar input
 
         # dynamic shape feature
         if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
+            # mask out args that the bindings below do not need
             args = ['_' if not arg in shape_symint and not arg in self.sym_to_inputs.values() else arg for arg in self.args]
             call_body.writeline(f"({','.join(args)}) = args")
 
+            # bind each SymInt name to the shape of its input arg
+            if len(self.sym_in_args) > 0:
+                for key in self.sym_in_args.keys():
+                    if not key.isdigit() and not self.operator_in_str(key):
+                        call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]")
+            if len(self.sym_to_inputs) > 0:
+                for key in self.sym_to_inputs.keys():
+                    if not key.isdigit() and not self.operator_in_str(key):
+                        call_body.writeline(f"{key} = {self.sym_to_inputs[key]}")
+
         # generate input dims
         if len(self.dynamic_inputs) > 0:
             dim_len = 0
 
             shape = list(elem.shape)
             if len(shape) == 0:
                 raise RuntimeError("Error handling empty output_shape")
-            shape = [self.process_sym_name(str(dim)) for dim in shape]
+            shape = [self.process_sym_name(dim) for dim in shape]
             shape_str += "[" + ','.join(map(str, shape)) + "],"
 
         # process output_shape with modified args
             shape = list(self.input_args[elem[1]].meta['val'].shape)
             if len(shape) == 0:
                 raise RuntimeError("Error handling empty output_shape")
-            shape = [self.process_sym_name(str(dim)) for dim in shape]
+            shape = [self.process_sym_name(dim) for dim in shape]
             shape_str += "[" + ','.join(map(str, shape)) + "],"
             stride = list(self.input_args[elem[1]].meta['val'].stride())
             if len(stride) == 0:
                 raise RuntimeError("Error handling empty output_stride")
-            stride = [self.process_sym_name(str(dim)) for dim in stride]
+            stride = [self.process_sym_name(dim) for dim in stride]
             extra_stride_str += '[' + ','.join(map(str, stride)) + '],'
             extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ','
         shape_str = shape_str[:-1] + f''']'''
 
                 out_storage_offsets.append('0')
                 continue
             stride = list(elem.stride())
-            stride = [self.process_sym_name(str(dim)) for dim in stride]
+            stride = [self.process_sym_name(dim) for dim in stride]
             out_strides.append(str(stride))
             out_storage_offsets.append(elem.storage_offset())
         call_body.writeline(f'out_stride = {out_strides}')