Skip to content

Conversation

iugoood
Copy link
Contributor

@iugoood iugoood commented Aug 13, 2025

Add

1 add mobilevit,mobilevitv2,deepseek_v3,xlnet models
2 add UT

Usage

mobilevit

    >>> import requests
    >>> import torch
    >>> from PIL import Image
    >>> from transformers import AutoImageProcessor
    >>> from mindone.transformers import MobileViTForSemanticSegmentation
    >>> import numpy as np
    >>> import mindspore as ms
    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
    >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> for key,value in inputs.items():
    >>>     if isinstance(value, np.ndarray):
    >>>         inputs[key] = ms.tensor(value)
    >>>     elif isinstance(value, list):
    >>>         inputs[key] = ms.tensor(value)
    >>>     if key == "pixel_values":
    >>>         inputs[key] = inputs[key].to(ms.float32)
    >>> outputs = model(**inputs)
    >>> # logits are of shape (batch_size, num_labels, height, width)
    >>> logits = outputs.logits

mobilevitv2

    >>> import requests
    >>> import torch
    >>> from PIL import Image
    >>> from transformers import AutoImageProcessor
    >>> from mindone.transformers import MobileViTV2ForSemanticSegmentation
    >>> import numpy as np
    >>> import mindspore as ms
    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
    >>> model = MobileViTV2ForSemanticSegmentation.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> for key,value in inputs.items():
    >>>     if isinstance(value, np.ndarray):
    >>>         inputs[key] = ms.tensor(value)
    >>>     elif isinstance(value, list):
    >>>         inputs[key] = ms.tensor(value)
    >>>     if key == "pixel_values":
    >>>         inputs[key] = inputs[key].to(ms.float32)
    >>> outputs = model(**inputs)
    >>> # logits are of shape (batch_size, num_labels, height, width)
    >>> logits = outputs.logits

deepseek_v3

    >>> from transformers import AutoTokenizer
    >>> from mindone.transformers import DeepseekV3ForCausalLM
    >>> import mindspore as ms
    >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="np")
    >>> # Generate
    >>> generate_ids = model.generate(ms.tensor(inputs.input_ids), max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

xlnet

    >>> from transformers import AutoTokenizer
    >>> from mindone.transformers import XLNetLMHeadModel
    >>> import mindspore as ms
    >>> tokenizer = AutoTokenizer.from_pretrained("xlnet/xlnet-large-cased")
    >>> model = XLNetLMHeadModel.from_pretrained("xlnet/xlnet-large-cased")
    >>> # We show how to setup inputs to predict a next token using a bi-directional context.
    >>> input_ids = ms.tensor(
    ...     tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)
    ... ).unsqueeze(
    ...     0
    ... )  # We will predict the masked token
    >>> perm_mask = mint.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=ms.float32)
    >>> perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
    >>> target_mapping = mint.zeros(
    ...     (1, 1, input_ids.shape[1]), dtype=ms.float32
    ... )  # Shape [1, 1, seq_length] => let's predict one token
    >>> target_mapping[
    ...     0, 0, -1
    ... ] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)
    >>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
    >>> next_token_logits = outputs[
    ...     0
    ... ]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
    >>> # The same way can the XLNetLMHeadModel be used to be trained by standard auto-regressive language modeling.
    >>> input_ids = ms.tensor(
    ...     tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)
    ... ).unsqueeze(
    ...     0
    ... )  # We will predict the masked token
    >>> labels = ms.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
    >>> assert labels.shape[0] == 1, "only one word will be predicted"
    >>> perm_mask = mint.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=ms.float32)
    >>> perm_mask[
    ...     :, :, -1
    ... ] = 1.0  # Previous tokens don't see last token as is done in standard auto-regressive lm training
    >>> target_mapping = mint.zeros(
    ...     (1, 1, input_ids.shape[1]), dtype=ms.float32
    ... )  # Shape [1, 1, seq_length] => let's predict one token
    >>> target_mapping[
    ...     0, 0, -1
    ... ] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)
    >>> outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
    >>> loss = outputs.loss
    >>> next_token_logits = (
    ...     outputs.logits
    ... )  # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

Performance

Experiments are tested on Ascend Atlas 800T A2 machines with mindspore 2.6.0.

model mode speed
mobilevit pynative 32.63 token/s
mobilevitv2 pynative 45.82 token/s
deepseek_v3 pynative 3.32 token/s
xlnet pynative 2.71 token/s

@iugoood iugoood requested a review from vigo999 as a code owner August 13, 2025 03:14
Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @iugoood, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the mindone/transformers library by integrating four new transformer models: MobileViT, MobileViTV2, DeepseekV3, and XLNet. The changes involve adding the complete model architectures, updating the library's import mechanisms, and registering these models for automatic configuration and usage. Crucially, comprehensive unit tests have been added for each new model to validate their implementation and ensure stability. A minor code fix was also included for a padding operation.

Highlights

  • New Model: MobileViT: Introduced the MobileViT model, including its core components for image classification and semantic segmentation.
  • New Model: MobileViTV2: Added the MobileViTV2 model, an updated version of MobileViT, also supporting image classification and semantic segmentation.
  • New Model: DeepseekV3: Integrated the DeepseekV3 model, a causal language model with a Mixture-of-Experts (MoE) architecture.
  • New Model: XLNet: Included the XLNet model, a generalized autoregressive pretraining method for natural language understanding.
  • Module Imports and Exposure: Updated the __init__.py files across the mindone/transformers and mindone/transformers/models directories to properly import and expose the new models and their classes.
  • Auto-Configuration and Auto-Modeling: Modified the auto configuration and modeling files to enable automatic loading and usage of the new MobileViT, MobileViTV2, DeepseekV3, and XLNet models.
  • Unit Test Coverage: Added comprehensive unit tests for MobileViT, MobileViTV2, DeepseekV3, and XLNet models to ensure their correct functionality and integration.
  • Minor Code Fix: Applied a minor fix in mindone/transformers/modeling_utils.py to use mindspore.mint.nn.functional.pad instead of ops.pad for consistency and correctness.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for several new models: mobilevit, mobilevitv2, deepseek_v3, and xlnet. The implementation for these models and their corresponding tests are included.

My review focuses on a few key areas:

  • Performance: The deepseek_v3 model contains some inefficient operations, particularly in the Mixture of Experts (MoE) layer and rotary position embeddings, which are acknowledged with TODO comments in the code. These could become performance bottlenecks and should be addressed.
  • Code Style and Efficiency: I've pointed out minor issues in mobilevit related to inefficient tensor creation and the use of in-place dropout, which could be improved for better performance and code clarity.
  • Completeness: It appears that the deepseek_v3 model, while its files are added, has not been fully integrated into the library's __init__.py and auto modules. This might be intentional for a follow-up PR, but as it stands, the model is not usable through the standard factory functions.

Overall, this is a substantial contribution. Addressing the identified points will help improve the quality and performance of the newly added models.

Comment on lines +174 to +184
for expert_idx in range(len(self.experts)):
expert = self.experts[expert_idx]
mask = expert_mask[expert_idx]
token_indices, weight_indices = mindspore.mint.where(mask)

if token_indices.numel() > 0:
expert_weights = topk_weights[token_indices, weight_indices]
expert_input = hidden_states[token_indices]
expert_output = expert(expert_input)
weighted_output = expert_output * expert_weights.unsqueeze(-1)
final_hidden_states.index_add_(0, token_indices, weighted_output)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of the Mixture of Experts (MoE) layer iterates through experts using a for loop. This is inefficient, especially for models with a large number of experts. The docstring even includes a 'CALL FOR CONTRIBUTION' to optimize this. This loop is a significant performance bottleneck and should be vectorized. Consider using batched matrix multiplication (bmm) or other techniques to process all experts in parallel. Since this is a generated file, this change should likely be applied to the source file src/transformers/models/deepseek_v3/modular_deepseek_v3.py.

Comment on lines +300 to +304
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The apply_rotary_pos_emb_interleave function uses multiple view, transpose, and reshape operations, which is inefficient. The TODO comment in the docstring acknowledges this. This implementation should be refactored to avoid these expensive data layout transformations. Since this is a generated file, this change should likely be applied to the source file src/transformers/models/deepseek_v3/modular_deepseek_v3.py.

# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_deepseek_v3 import *

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This file is missing a final newline character. According to PEP 8, all files should end with a single newline. Please add one for consistency with the rest of the codebase.

Suggested change
from .modeling_deepseek_v3 import *
from .modeling_deepseek_v3 import *

Comment on lines +421 to +422
new_height = int(mindspore.mint.ceil(mindspore.tensor(orig_height / patch_height)) * patch_height)
new_width = int(mindspore.mint.ceil(mindspore.tensor(orig_width / patch_width)) * patch_width)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Creating a mindspore.tensor just to use mindspore.mint.ceil is inefficient. You can achieve the same result more efficiently by using Python's math.ceil, which avoids the overhead of tensor creation and device transfer.

Suggested change
new_height = int(mindspore.mint.ceil(mindspore.tensor(orig_height / patch_height)) * patch_height)
new_width = int(mindspore.mint.ceil(mindspore.tensor(orig_width / patch_width)) * patch_width)
new_height = int(math.ceil(orig_height / patch_height) * patch_height)
new_width = int(math.ceil(orig_width / patch_width) * patch_width)

self.mobilevit = MobileViTModel(config)

# Classifier head
self.dropout = mindspore.mint.nn.Dropout(config.classifier_dropout_prob, inplace=True)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using inplace=True for dropout can be problematic as it modifies tensors in-place, which can lead to unexpected side effects and make debugging difficult. It's generally safer to avoid in-place operations. The MobileViTV2 implementation in this same PR does not use in-place dropout, suggesting it's not a necessary pattern. Please remove inplace=True for better code clarity and safety.

Suggested change
self.dropout = mindspore.mint.nn.Dropout(config.classifier_dropout_prob, inplace=True)
self.dropout = mindspore.mint.nn.Dropout(config.classifier_dropout_prob)

@iugoood iugoood force-pushed the mobilevit branch 2 times, most recently from 87128f9 to 474c47c Compare August 13, 2025 07:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant