| 1 | +# Preprocessing for dynamic shape execution |
| 2 | +As described in the basic flow of primitive execution for dynamic shape in [Overall flow for dynamic shape](overall_flow.md), several preprocessing steps are performed before arguments are set to the kernel and the selected impl is executed. |
| 3 | + |
| 4 | +* `update_shape` - when the input shape changes, shape inference recalculates the output shape so that the new shape is propagated to the next node. |
| 5 | +* `update_impl` - depending on the changed shape, `primitive_impl` is retrieved from the in-memory cache or a new impl is selected. |
| 6 | +* `realloc_if_needed` - allocates new output memory if necessary. |
| 7 | + |
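The ordering of the three steps listed above can be illustrated with a small, self-contained sketch. Everything in it (`ToyPrimitive`, the `impl#N` names, the element-wise shape rule) is a hypothetical stand-in chosen for illustration and is not the GPU plugin's actual code; it only shows why each step depends on the result of the previous one.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration; not the GPU plugin's real classes.
using Shape = std::vector<std::size_t>;

struct ToyPrimitive {
    Shape input_shape;                        // last input shape seen
    Shape output_shape;                       // result of shape inference
    std::map<Shape, std::string> impl_cache;  // in-memory cache: shape -> impl name
    std::string impl;                         // currently selected impl
    std::vector<float> output;                // output buffer
    bool shape_changed = false;
    int impls_created = 0;
    int allocations = 0;

    // Step 1: recompute the output shape only when the input shape changed.
    void update_shape(const Shape& new_input) {
        shape_changed = (new_input != input_shape);
        if (!shape_changed) return;
        input_shape = new_input;
        output_shape = new_input;  // toy rule: element-wise op keeps the shape
    }

    // Step 2: reuse a cached impl for this shape, or select a new one.
    void update_impl() {
        if (!shape_changed && !impl.empty()) return;
        auto it = impl_cache.find(output_shape);
        if (it == impl_cache.end()) {
            ++impls_created;
            it = impl_cache.emplace(output_shape, "impl#" + std::to_string(impls_created)).first;
        }
        impl = it->second;
    }

    // Step 3: allocate output memory only if the current buffer is too small.
    void realloc_if_needed() {
        std::size_t count = 1;
        for (std::size_t d : output_shape) count *= d;
        if (count > output.size()) {
            output.resize(count);
            ++allocations;
        }
        assert(output.size() >= count);
    }

    // Preprocessing that runs before kernel arguments are set.
    void prepare(const Shape& new_input) {
        update_shape(new_input);
        update_impl();
        realloc_if_needed();
    }
};
```

Calling `prepare()` twice with the same shape does no extra work in this sketch, while a previously seen shape hits the impl cache and a smaller shape keeps the existing buffer.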
| 8 | +The following sections describe some of the representative preprocessing steps for dynamic shape execution. |
| 9 | + |
| 10 | +## primitive_inst::update_shape |
| 11 | +### Dynamic shape inference |
| 12 | +To support dynamic shape in GPU plugin, `cldnn::layout` uses `ov::PartialShape` to express shape. The existing `cldnn::tensor` does not support dynamic shape and is limited in rank, whereas `ov::PartialShape` supports both static and dynamic dimensions and has no rank limitation. When `cldnn::primitive` is created from `ov::op`, the `ov::PartialShape` that `ov::op` already holds is used directly. |
| 13 | + |
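To make the difference between static and dynamic dimensions concrete, here is a minimal toy model of a partial shape. It is not `ov::PartialShape` and does not use any OpenVINO API; `ToyPartialShape`, `kDynamic`, and `compatible()` are hypothetical names introduced only for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical marker for a dimension whose size is unknown until runtime.
constexpr std::int64_t kDynamic = -1;

// Toy partial shape: any rank, each dimension either static (>= 0) or dynamic.
struct ToyPartialShape {
    std::vector<std::int64_t> dims;

    bool is_static() const {
        for (std::int64_t d : dims) {
            assert(d >= kDynamic);  // values below -1 are invalid in this toy model
            if (d == kDynamic) return false;
        }
        return true;
    }

    bool is_dynamic() const { return !is_static(); }

    // A concrete shape matches if the rank is equal and every static dimension agrees.
    bool compatible(const std::vector<std::int64_t>& concrete) const {
        if (concrete.size() != dims.size()) return false;
        for (std::size_t i = 0; i < dims.size(); ++i) {
            if (dims[i] != kDynamic && dims[i] != concrete[i]) return false;
        }
        return true;
    }
};
```

In this model a shape such as `{kDynamic, 3, 224, 224}` is dynamic in its first dimension only, and any concrete batch size is compatible with it.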
| 14 | +> **Note**: In the existing static shape execution flow of GPU plugin, the shape of `ov::op` may be transformed into `ov::tensor` before use, so `cldnn::primitive` creation from `ov::op` is kept separate from the dynamic shape execution flow. When building `cldnn::program`, if at least one node is dynamic, the `ov::intel_gpu::allow_new_shape_infer` property is set [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/plugin/program_builder.cpp#L139), and this property selects between the static and dynamic shape paths during `cldnn::primitive` creation. The two paths will be integrated in the future when GPU plugin fully supports dynamic shape. |
| 15 | + |
| 16 | +When the input shape of the model changes, each primitive checks whether its own input shape has changed and updates it, calculates the output shape from the new input shape, and propagates that shape to the next primitive. This is the shape inference stage. |
| 17 | +The steps that `primitive_inst::update_shape` performs for shape inference when GPU plugin executes a primitive with dynamic shape are as follows: |
| 18 | + |
| 19 | +1. In the basic flow that executes a primitive, a runtime optimization stage (i.e. `primitive_inst::do_runtime_in_place_concat` [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L720)) runs before `update_shape()`. If `update_shape()` has already been executed for this primitive by another primitive during that stage, `update_shape_done_by_other` is set to TRUE, and `update_shape()` is then skipped. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L247) |
| 20 | +2. The output layouts of `kernel_impl_params` from the dependencies of `primitive_inst` are compared with the input layouts of `kernel_impl_params` of the current primitive. If they differ, the input layouts of `kernel_impl_params` are updated with the changed shape. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L254) |
| 21 | +3. Set `_shape_changed` to TRUE if the input shape has changed. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L266) |
| 22 | +4. If the current node is `shape_of` and the input shape has not changed, reset `_shape_changed` to FALSE and skip `update_shape()`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L270) |
| 23 | +5. If the current node is *`shape_of` subgraph*, check *dependent `shape_of` primitives* and skip `update_shape()` if the shape has not changed. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L276) |
| 24 | +6. `update_shape()` is skipped if all of the following conditions hold: the input shape has not changed, the node does not generate dynamic output (unlike e.g. `Nonzero` or `Unique`, whose output shape depends on the input data), and the output layouts of `kernel_impl_params` are already static. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L313) |
| 25 | +7. In static shape execution, data for additional inputs that determine the output shape are set as attributes when creating `cldnn::primitive`. In dynamic shape execution, if that data is stored in the output memory of a preceding node, execution waits until those dependent nodes complete. To determine which input nodes have memory dependencies, most `program_node`s define `get_shape_infer_dependencies()`. The dependency information (index and memory for each dependent input node) is collected from the current node, stored in a `map`, and the corresponding primitive events are added to an event list to await completion. Finally, the populated map is saved in `memory_deps` of `kernel_impl_params`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L319) |
| 26 | +8. There are two APIs for output shape calculation on `program_node`: `calc_output_layout()` for static shape execution and `calc_output_layouts()` for dynamic shape execution. In this step, `calc_output_layouts()` is called. It invokes the `shape_infer()` API of `ov::op` with the updated input layouts from `kernel_impl_params`, the primitive's attributes, and `memory_deps`, and returns the output layouts as a vector. The newly calculated output layouts are then written back to `output_layouts` in `kernel_impl_params`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L366) |
| 27 | + ```cpp |
| 28 | + struct program_node { |
| 29 | + // ... other members omitted |
| 30 | + public: |
| 31 | + layout calc_output_layout() const;                // static shape execution |
| 32 | + std::vector<layout> calc_output_layouts() const;  // dynamic shape execution |
| 33 | + }; |
| 34 | + ``` |
| 35 | +9. If there are fused operations in `kernel_impl_params`, the output layout of each fused operation descriptor is also updated with the `ov::PartialShape` of the updated output layout. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L379) |
| 36 | + |
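Steps 2, 3, 7, and 8 above can be sketched with a toy reshape-like operation whose output shape depends on the content of its second input. The types and functions below (`ToyImplParams`, `MemoryDeps`, `reshape_shape_infer`) are hypothetical and only mimic the roles of `kernel_impl_params`, `memory_deps`, and `shape_infer()`; the real code also handles events, multiple outputs, and the skip conditions, all of which are left out here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

using Shape = std::vector<std::int64_t>;
// Hypothetical: index of a dependent input -> values read from its output memory.
using MemoryDeps = std::map<std::size_t, std::vector<std::int64_t>>;

struct ToyImplParams {
    std::vector<Shape> input_shapes;
    std::vector<Shape> output_shapes;
    MemoryDeps memory_deps;
};

// Toy shape inference: the target shape is the content of input 1,
// and a -1 entry is resolved from the element count of input 0.
std::vector<Shape> reshape_shape_infer(const ToyImplParams& p) {
    assert(p.memory_deps.count(1) == 1);
    const Shape& data = p.input_shapes.at(0);
    Shape out = p.memory_deps.at(1);
    std::int64_t total = 1;
    for (std::int64_t d : data) total *= d;
    std::int64_t known = 1;
    for (std::int64_t d : out) if (d != -1) known *= d;
    for (std::int64_t& d : out) if (d == -1 && known != 0) d = total / known;
    return std::vector<Shape>{out};
}

// Returns true when shape inference actually ran.
bool update_shape(ToyImplParams& p, const std::vector<Shape>& dep_output_shapes, const MemoryDeps& deps) {
    bool shape_changed = false;
    if (p.input_shapes.size() != dep_output_shapes.size()) {
        p.input_shapes.resize(dep_output_shapes.size());
        shape_changed = true;
    }
    // Compare dependency output shapes with the stored input shapes.
    for (std::size_t i = 0; i < dep_output_shapes.size(); ++i) {
        if (p.input_shapes[i] != dep_output_shapes[i]) {
            p.input_shapes[i] = dep_output_shapes[i];
            shape_changed = true;
        }
    }
    const bool deps_changed = (p.memory_deps != deps);
    if (!shape_changed && !deps_changed) return false;
    p.memory_deps = deps;                      // save the collected dependencies
    p.output_shapes = reshape_shape_infer(p);  // write back the new output shape
    return true;
}
```

Note that this sketch reruns shape inference when only the dependent data changes, which is an assumption of the toy model rather than a statement about the plugin's exact conditions.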
| 37 | +## primitive_inst::update_weights |
| 38 | +If `primitive_impl` is created or updated through `update_impl()` and the node is weightable (e.g. `convolution`, `deconvolution`, `fc`), the weights are reordered to the layout required by the kernel when needed. The following describes the processes performed in `update_weights()`. |
| 39 | + |
| 40 | +1. If impl is nullptr or the current node is not a weightable node, `update_weights()` is skipped. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L1168) |
| 41 | +2. Create *reorder kernel params* (i.e. `kernel_impl_params` for weights reorder) from `WeightsReorderParams` of `primitive_inst`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L1172) |
| 42 | +3. If weights reorder is not necessary but the weights were reordered previously, an incorrect memory buffer may remain assigned, so reset *reordered weights cache* to the original weights memory layout. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L1181) |
| 43 | +4. If weights reorder is necessary, update the weight layout of `kernel_impl_params` to the output layout of *reorder kernel params*. This is the expected layout. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L1186) |
| 44 | + - If the expected layout hits *reordered weights cache*, it is reused. |
| 45 | + - If the expected layout is compatible with the original layout, the original weights memory is reinterpreted and added to *reordered weights cache* without the need for reordering. |
| 46 | + - If the expected layout misses *reordered weights cache*, a reorder is executed: (1) retrieve a cached reorder impl from *implementations cache* using *reorder kernel params*, or create a new reorder impl through `WeightsReordersFactory`, set the compiled kernel on it, and add it to *implementations cache*; (2) reuse the weights memory held in *reordered weights cache* if possible, otherwise allocate a new buffer, and update *reordered weights cache* accordingly; (3) set kernel arguments for the reorder impl with `kernel_arguments_data()` and execute the kernel. |
| 47 | + |
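The three outcomes in step 4 (cache hit, compatible layout, cache miss) can be sketched as follows. `ToyWeights`, `ToyWeightsCache`, and the layout names are hypothetical; the "reorder" here is a plain copy, whereas a real reorder kernel would permute elements, and the compatibility check is passed in as a flag rather than computed.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical weights handle: a layout label plus a shared data buffer.
struct ToyWeights {
    std::string layout;
    std::shared_ptr<std::vector<float>> buffer;
};

struct ToyWeightsCache {
    std::map<std::string, ToyWeights> cache;  // expected layout -> reordered weights
    int reorders_executed = 0;

    ToyWeights get(const ToyWeights& original, const std::string& expected, bool is_compatible) {
        assert(original.buffer != nullptr);
        // Case 1: the expected layout is already cached, reuse it.
        auto it = cache.find(expected);
        if (it != cache.end()) return it->second;

        ToyWeights result;
        result.layout = expected;
        if (is_compatible) {
            // Case 2: reinterpret the original memory, no reorder needed.
            result.buffer = original.buffer;
        } else {
            // Case 3: cache miss, run a (toy) reorder into a new buffer.
            result.buffer = std::make_shared<std::vector<float>>(*original.buffer);
            ++reorders_executed;
        }
        cache.emplace(expected, result);
        return result;
    }
};
```

In this sketch only the third case increments `reorders_executed`, which reflects why the cache lookups come first: the reorder is the expensive path.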
| 48 | +## primitive_inst::realloc_if_needed |
| 49 | +In static shape execution, output memory is allocated when `primitive_inst` is created. In dynamic shape execution, output memory is allocated at execution time, before arguments are set to the kernel and the kernel is executed. The following describes the processes performed in `realloc_if_needed()`. |
| 50 | + |
| 51 | +1. If the current node has exactly one user and that user is a `concat` whose `can_be_optimized()` is TRUE, the node shares the concat's output memory. If the concat's `allocation_done_by_other` is FALSE (i.e. its memory has not yet been allocated by another node), execute the concat's `realloc_if_needed()` and set `allocation_done_by_other` to TRUE. Then use the concat's output memory as the output memory of the current node and skip the rest of `realloc_if_needed()`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L390) |
| 52 | +2. For better performance, if *fake aligned shape* is used when executing the kernel (e.g. `fully_connected`), the input and output shapes of `kernel_impl_params` are updated accordingly. A more detailed explanation will be added as a separate section later (TBD). [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L403) |
| 53 | +3. If the node is `input_layout`, `realloc_if_needed()` is skipped because it is assumed to always use external memory. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L408) |
| 54 | +4. Check whether output memory is already allocated and the requested buffer size is smaller than the current buffer size, and store the result in `can_reuse_buffer`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L421) |
| 55 | +5. If the current node is `concat` and both `can_be_optimized()` and `allocation_done_by_other` are TRUE, `realloc_if_needed()` is skipped. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L424) |
| 56 | +6. `ShapePredictor` predicts a preallocation shape from the current shape and data type, and updates the output layout shape of `kernel_impl_params` accordingly. A more detailed explanation will be added as a separate section later (TBD). [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L429) |
| 57 | +7. If `can_reuse_buffer` is TRUE, `reused` of output memory is set to TRUE and output memory is updated with reinterpreted buffer. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L439) |
| 58 | +8. If `can_reuse_buffer` is FALSE, reallocate with `allocate_outputs()` to set the output memory and update `max_output_layout_size`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L448) |
| 59 | +9. Get internal buffer layouts from the current `primitive_impl`. [(link)](https://github.com/openvinotoolkit/openvino/blob/eea49f3c9e6bba5463460fdc126c2df38a4a5215/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L458) |
| 60 | + - If the previously allocated intermediate memory can be reused, the intermediate memory is updated with reinterpreted buffer. |
| 61 | + - If it cannot be reused, allocate a new buffer through `allocate_internal_buffer()`, which either replaces the previously allocated intermediate memory or is added as new intermediate memory. |
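The reuse decision in steps 4, 7, and 8 can be sketched as below. `ToyOutput` is a hypothetical stand-in for the output memory of a primitive. The sketch assumes that a request equal to the current capacity counts as reusable, and it leaves out the shape prediction of step 6, where a larger size could be requested in advance.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of one output buffer managed across inferences.
struct ToyOutput {
    std::vector<float> buffer;
    std::size_t max_output_layout_size = 0;  // capacity in elements
    bool reused = false;
    int allocations = 0;

    void realloc_if_needed(std::size_t requested) {
        // Reuse is possible if memory exists and the request fits (assumed <=).
        const bool can_reuse_buffer = !buffer.empty() && requested <= max_output_layout_size;
        if (can_reuse_buffer) {
            reused = true;  // the existing buffer is reinterpreted for the new layout
        } else {
            buffer.assign(requested, 0.0f);  // stands in for allocate_outputs()
            max_output_layout_size = requested;
            reused = false;
            ++allocations;
        }
        assert(buffer.size() >= requested);
    }
};
```

With this logic a sequence of shrinking shapes never reallocates, and only a request larger than every previous one triggers a new allocation.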