
Conversation

Collaborator

@zhanghuiyao zhanghuiyao commented Aug 8, 2025

What does this PR do?

import mindspore
# `pipeline` is assumed to be the entry point exported by mindone.transformers, which this PR extends
from mindone.transformers import pipeline

pipe = pipeline("text-generation", model="openai/gpt-oss-20b", mindspore_dtype=mindspore.bfloat16)

messages = [
    {"role": "user", "content": "Explain quantum mechanics clearly and concisely."},
]

outputs = pipe(
    messages,
    max_new_tokens=256,
)
print(outputs[0]["generated_text"][-1])

outputs:

{'role': 'assistant', 'content': 'analysisThe user wants an explanation of quantum mechanics: clearly concise. Provide short key ideas: wavefunction, superposition, Schrödinger equation, measurement, uncertainties, etc. The answer should be clear, not too long.assistantfinal## Quick‐Start Guide to Quantum Mechanics\n\n| Concept | What it means | Why it matters |\n|---|---|---|\n| **Wavefunction (ψ)** | Describes the probability amplitude of a particle’s position and momentum. | Encodes every possible outcome—everything a particle *can* exists in the quantum.  |\n| **Schr| | **Explanation to quantum Mechanics. | **QuantumMechanism 19|   |   | 1\n\nThe quantum 7 | | ### The quantum results: ### **Quantum |  \n\n** quantum  \n\n**  \n 3 | **Explanation details  \n  \n  \n\n**Quantum Dynamics Quantum Mechanics\nExplanation to quantum mechanical quantum mechanics physics  \n\nI 29- quantum state quantum mechanics explanation\n\nThe explanation (answer.  \n\nSure.\n\n***QuantumMechanism\n\nI quantum mechanics quantum mechanics quantum states physics of a quantum mechanic quantum mechanics\n\n** quantum mechanics  \n  \n**Quantum mechanics**  \n\n**Quantum Mechanics  \n\n***Quantum Mechanics (Quantum Mechanics to quantum mechanics to quantum mechanics and explain'}
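
For reviewers who prefer to exercise the new auto-class integration directly rather than through pipeline, here is a minimal sketch; it is not taken from this PR and assumes the usual transformers-style entry points (AutoModelForCausalLM in mindone.transformers, AutoTokenizer from transformers) plus the same mindspore_dtype argument as above:

import mindspore
from transformers import AutoTokenizer
from mindone.transformers import AutoModelForCausalLM  # assumed auto-class export

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b", mindspore_dtype=mindspore.bfloat16
)

messages = [{"role": "user", "content": "Explain quantum mechanics clearly and concisely."}]
# Build the chat prompt as numpy ids, then hand a MindSpore tensor to generate().
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="np")
input_ids = mindspore.Tensor(input_ids)

output_ids = model.generate(input_ids, max_new_tokens=256)
# Assumes generate() returns HF-style token ids of shape (batch, seq) as a MindSpore tensor.
new_tokens = output_ids[0][input_ids.shape[-1]:].asnumpy()
print(tokenizer.decode(new_tokens, skip_special_tokens=True))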

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the
    documentation guidelines
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (It is best to record this for data loading, model inference, or training tasks.)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@zhanghuiyao zhanghuiyao requested a review from vigo999 as a code owner August 8, 2025 11:41
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @zhanghuiyao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I'm adding support for the gpt-oss model within the mindone.transformers library. This pull request introduces the full gpt-oss model architecture, including its specific RMSNorm, Mixture-of-Experts (MoE) components, and attention mechanisms. It also includes an example script for text generation with gpt-oss and integrates the model into the existing auto-loading and pipeline infrastructure. Furthermore, I've laid the groundwork for mixed-precision (MXFP4) inference and introduced a base class for gradient checkpointing to enhance performance and memory efficiency.

Highlights

  • New Model Integration: I've integrated the complete gpt-oss model architecture, including its specific RMSNorm, Mixture-of-Experts (MoE) components, rotary embeddings, and attention mechanisms, into the mindone.transformers library.
  • Mixed-Precision (MXFP4) Support: I've added support for mixed-precision FP4 (MXFP4) inference for GPT-OSS experts, including the necessary dequantization logic, to enhance performance and memory efficiency.
  • Gradient Checkpointing Foundation: I've introduced a new base class, GradientCheckpointingLayer, to provide a foundation for enabling gradient checkpointing, which is crucial for memory optimization in deep learning models (a rough sketch follows this list).
  • Enhanced Model Loading: I've improved the model loading process by incorporating no_init_parameters for more efficient initialization and refining the handling of parameter data types during state dict loading.
  • Pipeline and Auto-Class Integration: I've updated the auto-loading and pipeline functionalities within mindone.transformers to correctly register and enable the gpt-oss model, ensuring seamless integration with existing workflows.
  • Text Generation Example: I've included a new example script that demonstrates how to use the gpt-oss model for text generation, making it easier for users to get started.
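
On the gradient-checkpointing item above: the PR's actual base class isn't reproduced in this summary, but a rough MindSpore sketch of the idea, using Cell.recompute() to trade recomputation for activation memory, could look like the following (the class name matches the PR, the flag and method are assumptions):

import mindspore.nn as nn

class GradientCheckpointingLayer(nn.Cell):
    """Base class whose subclasses can opt in to activation recomputation."""

    gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        # Cell.recompute() marks this cell so its forward activations are
        # recomputed during backprop instead of being kept in memory.
        self.gradient_checkpointing = True
        self.recompute()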

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the GPT-OSS model, including its implementation, an example generation script, and integrations for features like MXFP4 quantization. The changes are quite extensive. My review has identified a few critical issues in the model's implementation that could cause runtime errors or incorrect behavior, particularly related to handling past key-values and expert routing logic in the MoE layer. I have also suggested performance improvements for the MXFP4 integration by avoiding unnecessary CPU-device data transfers and pointed out some potentially redundant code for better clarity.

glu = gate * mindspore.mint.sigmoid(gate * self.alpha)
gated_output = (up + 1) * glu
out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]

critical

There appears to be a bug in this line. By using out[0], you are taking the output of the first token routed to this expert and broadcasting it to all other tokens for this expert. This means all tokens will receive the same output from the expert (scaled by their routing weight), which is likely incorrect. You should probably use the entire out tensor to compute the weighted output for each token individually.

Suggested change
weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
weighted_output = out * routing_weights[token_idx, expert_idx, None]
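
A quick NumPy toy (not from the PR) makes the difference concrete: with three tokens routed to the same expert, out[0] broadcasts the first token's activation to every row, while out keeps each token's own result.

import numpy as np

out = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])  # (tokens_for_this_expert, hidden)
routing_weights = np.array([[0.5], [0.25], [0.25]])    # this expert's weight per token

print(out[0] * routing_weights)  # every row derived from token 0 -> wrong
print(out * routing_weights)     # each token keeps its own expert output -> intended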

past_seen_tokens = 0
elif isinstance(past_key_values, Cache):
past_seen_tokens = past_key_values.get_seq_length()
elif isinstance(past_key_value, tuple):

critical

There is a typo in the variable name. past_key_value is used, but it should be past_key_values. This will cause a NameError at runtime because past_key_value is not defined in this scope.

Suggested change
elif isinstance(past_key_value, tuple):
elif isinstance(past_key_values, tuple):

Comment on lines +77 to +78
idx_lo = mindspore.tensor(blk.numpy() & 0x0F, mindspore.int64)
idx_hi = mindspore.tensor(blk.numpy() >> 4, mindspore.int64)

medium

Using .numpy() for bitwise operations can be inefficient as it involves data transfer between the device (e.g., GPU/NPU) and CPU, and then back to a tensor. For better performance, you should use MindSpore's device-side bitwise operations.

Suggested change
idx_lo = mindspore.tensor(blk.numpy() & 0x0F, mindspore.int64)
idx_hi = mindspore.tensor(blk.numpy() >> 4, mindspore.int64)
idx_lo = mindspore.ops.bitwise_and(blk, 0x0F).to(mindspore.int64)
idx_hi = mindspore.ops.right_shift(blk, 4).to(mindspore.int64)
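
For context, these two lines unpack the two 4-bit indices stored in each packed uint8 block; a small NumPy check of that nibble layout (independent of which MindSpore op is used):

import numpy as np

blk = np.array([0xAB, 0x3C], dtype=np.uint8)  # each byte packs two FP4 indices
print(blk & 0x0F)  # [11 12] -> low nibbles
print(blk >> 4)    # [10  3] -> high nibbles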

blocks_attr = f"{proj}_blocks"
scales_attr = f"{proj}_scales"
dequantized = convert_moe_packed_tensors(getattr(self, blocks_attr), getattr(self, scales_attr))
dequantized = dequantized.transpose(1, 2).contiguous()

medium

The .contiguous() call is likely a remnant from PyTorch and is often not necessary in MindSpore. The transpose operation in MindSpore returns a new contiguous tensor. Removing this call would make the code cleaner and avoid potential confusion.

Suggested change
dequantized = dequantized.transpose(1, 2).contiguous()
dequantized = dequantized.transpose(1, 2)

@zhanghuiyao zhanghuiyao force-pushed the gpt_oss branch 2 times, most recently from 8f01067 to 46f5bc2 Compare August 13, 2025 08:18
@vigo999 vigo999 added the new model add new model to mindone label Sep 29, 2025
@vigo999 vigo999 added this to mindone Sep 29, 2025
@vigo999 vigo999 moved this to In Progress in mindone Sep 29, 2025
