
improve MoE bias update logic in optimizer #1593


Merged
merged 1 commit into pytorch:main from improve_moe_bias_update on Aug 22, 2025

Conversation

@rakkit (Contributor) commented Aug 19, 2025

We put all experts' usage into a single buffer so that only one reduce is needed, rather than one reduce per MoE layer.

Additionally, handle cases where tokens per expert are counted twice during full recompute.
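
A minimal sketch of the buffering idea, with hypothetical names (moe_layers, tokens_per_expert, reduce_expert_usage) rather than the PR's actual identifiers: stack every layer's token counts into one tensor and issue a single all_reduce instead of one per layer.

    # Hypothetical sketch, not the PR's code: batch all MoE layers' expert-usage
    # counts into one tensor so a single all_reduce replaces per-layer reduces.
    import torch
    import torch.distributed as dist

    def reduce_expert_usage(moe_layers, process_group=None) -> torch.Tensor:
        # assumes every MoE layer has the same number of experts
        buffer = torch.stack([layer.tokens_per_expert for layer in moe_layers])  # [num_layers, num_experts]
        dist.all_reduce(buffer, op=dist.ReduceOp.SUM, group=process_group)  # one collective for all layers
        return buffer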

@meta-cla bot added the CLA Signed label on Aug 19, 2025
@tianyu-l (Contributor) left a comment

Thank you for the PR! I left some comments.

@tianyu-l (Contributor) left a comment

Thanks, had some more comments.

@rakkit force-pushed the improve_moe_bias_update branch from cb4e41a to cb8c9b8 on August 19, 2025 21:50
@rakkit requested a review from tianyu-l on August 19, 2025 21:57
@rakkit (Contributor, Author) commented Aug 20, 2025

For MoE EP usage and/or bias, here we need to do something like:

    expert_usage_metrics = {
        f"moe_ep_usage/L-{layer_id}_EP-{ep_idx}": usage / sum_tokens
        for ep_idx, usage in enumerate(tokens_per_expert)
    }

    model_part._metrics_to_log.update(expert_usage_metrics)

and once we finalize PR #1578,

for the MoE model we can have:

    def get_extra_metrics(self, model_parts: list[nn.Module], *args, **kwargs) -> None | dict[str, Any]:
        # merge per-part metric dicts (model_parts is a list, not a single module)
        return {k: v for part in model_parts for k, v in getattr(part, "_metrics_to_log", {}).items()}

@rakkit requested a review from tianyu-l on August 20, 2025 07:43
@tianyu-l (Contributor) left a comment

Thanks! Had some final comments.

@tianyu-l added the release blocking label on Aug 21, 2025
@rakkit force-pushed the improve_moe_bias_update branch from f340bcb to 9c35bc1 on August 21, 2025 20:35
@rakkit (Contributor, Author) commented Aug 21, 2025

Removed the comment (for EP usage) and added the early exit in the first loop.

@rakkit requested a review from tianyu-l on August 21, 2025 20:37
@tianyu-l (Contributor) left a comment

LGTM, thank you! Please fix linting so we can merge.

We put all experts' usage into a buffer such that we only need one reduce rather than #number-of-layers times.

Additionally, handle cases where tokens per expert are counted twice during full recompute.

(assume all MoE layers have the same number of experts)
@rakkit force-pushed the improve_moe_bias_update branch from 9c35bc1 to 9a623b8 on August 21, 2025 23:27
@rakkit (Contributor, Author) commented Aug 21, 2025

Format fixed; thanks a lot for the discussion.

It's important to note that this PR only improves the communication part: the reduce is now done only once. In practice, from the profiler, the second loop still launches a lot of kernels, [num MoE layers] * (slice, mean, sign, multiply, add, zeros), unless everything there is vectorized.
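
For reference, a vectorized version of that second loop could look roughly like the sketch below (hypothetical names and a sign-of-imbalance bias update rule, not the code in this PR), replacing the per-layer slice/mean/sign/multiply/add/zeros sequence with one batched update over the stacked buffer:

    import torch

    # Hypothetical sketch of a vectorized bias update across all MoE layers.
    # expert_bias and tokens_per_expert are stacked [num_layers, num_experts]
    # buffers (counts already all-reduced); bias_update_speed is a scalar.
    @torch.no_grad()
    def update_expert_bias(expert_bias, tokens_per_expert, bias_update_speed):
        load = tokens_per_expert.float()
        mean_load = load.mean(dim=-1, keepdim=True)  # [num_layers, 1]
        # raise the bias of underloaded experts, lower it for overloaded ones
        expert_bias.add_(bias_update_speed * torch.sign(mean_load - load))
        tokens_per_expert.zero_()  # reset the counts for the next window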

@tianyu-l merged commit 2bfcdd8 into pytorch:main on Aug 22, 2025
5 of 7 checks passed