[main] mlp weight prefetch in Qwen Dense Models #2816
The first file in the diff extends `set_ascend_forward_context` with the prefetch plumbing:
```diff
@@ -66,7 +66,9 @@ def set_ascend_forward_context(
         moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        prefetch_stream: torch.npu.Stream = None,
+        model_instance: torch.nn.Module = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
```
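For context, the two new parameters are created by the caller and threaded through on each forward pass. A minimal, illustrative sketch of a call site — the runner attribute names here are assumptions, not code from this PR:

```python
import torch
import torch_npu  # noqa: F401  # Ascend adapter; provides the torch.npu namespace


def run_forward(runner, input_ids, positions, attn_metadata, num_input_tokens):
    # Create the side stream once and cache it; recreating a stream per step
    # would defeat the purpose of overlapping prefetch with compute.
    if getattr(runner, "prefetch_stream", None) is None:
        runner.prefetch_stream = torch.npu.Stream()

    # set_ascend_forward_context as extended by this PR (import path omitted).
    with set_ascend_forward_context(
            attn_metadata,
            runner.vllm_config,
            num_tokens=num_input_tokens,
            prefetch_stream=runner.prefetch_stream,  # new parameter
            model_instance=runner.model):            # new parameter
        return runner.model(input_ids, positions)
```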
The second hunk additionally gates FlashComm v1 behind the dense-optimization switch:

```diff
@@ -108,7 +110,8 @@ def set_ascend_forward_context(
     # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
     # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
     # the performance may degrade due to the switching of communication methods.
-    flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
+    flashcomm_v1_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+        envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
         tp_world_size > 1 and \
         num_tokens is not None and num_tokens > 1000
```
The third hunk derives `layer_idx` from the model instance and computes the prefetch gate:

```diff
@@ -122,6 +125,26 @@ def set_ascend_forward_context(
     # set this for rope forward_oot using
     forward_context.is_first_layer = True
 
+    # set layer_idx to enable optimization features that depend on this information.
+    # This is only applicable to models that contain these necessary attributes.
+    forward_context.layer_idx = None
+    if model_instance is not None and \
+            hasattr(model_instance, "model") and \
+            hasattr(model_instance.model, "start_layer"):
+        forward_context.layer_idx = model_instance.model.start_layer
+
+    # set for mlp weight prefetch
+    prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
+        envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
+        forward_context.layer_idx is not None and \
+        num_tokens is not None and num_tokens < 500
```

Review comment on lines +137 to +140: Similar to the previous comment, the condition `envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE` is repeated here; the reviewer attached a suggested change hoisting it into a shared variable (see the full comment at the end of this section).
The hunk continues by stashing the prefetch state on the forward context:

```diff
+    if prefetch_mlp_enabled:
+        forward_context.prefetch_stream = prefetch_stream
+        forward_context.model_instance = model_instance
+        forward_context.prefetch_mlp_gate_up_proj = False
+        forward_context.prefetch_mlp_down_proj = False
+    forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
 
     if num_tokens is None and attn_metadata is not None:
         num_tokens = attn_metadata.num_actual_tokens
```
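The two per-projection flags are initialized to False here and are presumably flipped layer by layer as prefetches are issued. As a rough sketch of the consuming side — assuming CUDA-style stream semantics on `torch.npu` and a `torch_npu.npu_prefetch(tensor, dependency, max_size)` primitive, neither of which appears in this diff — an MLP layer could overlap the `gate_up_proj` weight fetch with preceding work like this:

```python
import torch
import torch_npu
import vllm_ascend.envs as envs_ascend
from vllm.forward_context import get_forward_context


def maybe_prefetch_gate_up(mlp, hidden_states):
    """Illustrative only: issue the gate_up_proj weight prefetch on the
    dedicated stream so it overlaps with attention still running on the
    default stream."""
    ctx = get_forward_context()
    if not getattr(ctx, "prefetch_mlp_enabled", False):
        return
    # Order the side stream behind work already queued on the default stream.
    ctx.prefetch_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(ctx.prefetch_stream):
        # Prefetch at most MLP_GATE_UP_PREFETCH_SIZE bytes of the weight,
        # made dependent on hidden_states so it cannot start too early.
        torch_npu.npu_prefetch(mlp.gate_up_proj.weight,
                               hidden_states,
                               envs_ascend.MLP_GATE_UP_PREFETCH_SIZE)
    ctx.prefetch_mlp_gate_up_proj = True
```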
The second file, the environment-variable registry, adds the new knobs:
```diff
@@ -135,6 +135,15 @@
     # This feature will get better performance when concurrency is large.
     "VLLM_ASCEND_ENABLE_FLASHCOMM":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    # Whether to enable MLP weight prefetch, only used in small concurrency.
+    "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
+    # buffer size for gate up prefetch
+    "MLP_GATE_UP_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
+    # buffer size for down proj prefetch
+    "MLP_DOWN_PREFETCH_SIZE":
+    lambda: int(os.getenv("MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
     # Whether to enable dense model and general optimizations for better performance.
     # Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
     # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
```

Review comment on `MLP_GATE_UP_PREFETCH_SIZE`: add …

Review comment on `MLP_DOWN_PREFETCH_SIZE`: ditto
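Taken together with the dense-optimization switch, enabling the feature from the launching process would look roughly like this; the values shown are the defaults from this diff:

```python
import os

# Both gates must be on: the prefetch path is additionally guarded by
# VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE in set_ascend_forward_context.
os.environ["VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_PREFETCH_MLP"] = "1"

# Optional: per-projection prefetch buffer sizes in bytes (18 MiB defaults).
os.environ["MLP_GATE_UP_PREFETCH_SIZE"] = str(18 * 1024 * 1024)
os.environ["MLP_DOWN_PREFETCH_SIZE"] = str(18 * 1024 * 1024)
```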
Review comment: The condition `envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE` is repeated in multiple `if` statements. Consider defining a variable to store this value and reusing it to avoid redundancy and improve readability. This also centralizes the configuration, making it easier to manage.
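A sketch of the refactor being suggested, applied to the two conditions in this diff:

```python
# Hoist the shared gate so the policy lives in one place.
dense_optimize_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE

flashcomm_v1_enabled = dense_optimize_enabled and \
    envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM and \
    tp_world_size > 1 and \
    num_tokens is not None and num_tokens > 1000

prefetch_mlp_enabled = dense_optimize_enabled and \
    envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
    forward_context.layer_idx is not None and \
    num_tokens is not None and num_tokens < 500
```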