
Fix: ModernBERT sequence classification num_labels support#4177

Open
adityaghai07 wants to merge 3 commits into unslothai:main from adityaghai07:fix/bert-num-labels

Conversation

@adityaghai07 (Contributor)

Fix: Add support for sequence classification parameters in FastModel

Issue #4163

Problem

When using FastModel.from_pretrained() with AutoModelForSequenceClassification for BERT-based models (including ModernBERT), users encountered a TypeError:

TypeError: ModernBertForSequenceClassification.__init__() got an unexpected keyword argument 'num_labels'

This prevented users from loading sequence classification models with custom label configurations, making it impossible to use unsloth for fine-tuning BERT models on classification tasks.

Reproduction

from unsloth import FastModel
from transformers import AutoModelForSequenceClassification

model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/ModernBERT-large",
    auto_model=AutoModelForSequenceClassification,
    num_labels=6,  # This caused a TypeError
    id2label={0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"},
    label2id={"sadness": 0, "joy": 1, "love": 2, "anger": 3, "fear": 4, "surprise": 5},
)

Solution

The issue was in unsloth/models/vision.py in the FastBaseModel.from_pretrained() method. The parameters num_labels, id2label, and label2id were being passed directly to auto_model.from_pretrained() via **kwargs, but for sequence classification models, these parameters must be set in the model's config before instantiation.
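The failure mode can be reproduced without transformers at all. Below is a minimal sketch with stand-in classes (DummyConfig, DummyClassifier, and broken_from_pretrained are illustrative names, not real APIs): when a model's __init__ accepts only a config object, any kwargs leaking through from_pretrained surface as exactly the TypeError reported in issue #4163.

```python
# Self-contained sketch of the failure mode (dummy classes, no transformers).
class DummyConfig:
    num_labels = 2  # transformers-style default

class DummyClassifier:
    # Mirrors a constructor like ModernBertForSequenceClassification.__init__(self, config)
    def __init__(self, config):
        self.config = config

def broken_from_pretrained(config, **kwargs):
    # Forwarding leftover kwargs straight into the constructor reproduces the bug.
    return DummyClassifier(config, **kwargs)

try:
    broken_from_pretrained(DummyConfig(), num_labels=6)
    failed = False
except TypeError:
    failed = True  # unexpected keyword argument 'num_labels'
```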

Changes Made

File: unsloth/models/vision.py (lines ~784-800)

Added code to extract sequence classification parameters from kwargs and set them on the config before model loading:

# Handle sequence classification parameters (num_labels, id2label, label2id)
# These need to be set in the config before model instantiation
num_labels = kwargs.pop("num_labels", None)
id2label = kwargs.pop("id2label", None)
label2id = kwargs.pop("label2id", None)

if num_labels is not None:
    model_config.num_labels = num_labels
if id2label is not None:
    model_config.id2label = id2label
if label2id is not None:
    model_config.label2id = label2id

This ensures the parameters are in the config when auto_model.from_pretrained() is called, preventing them from being passed as unexpected keyword arguments to the model's __init__ method.

Testing

Validation Tests

  • Python syntax validation passes
  • Config parameter extraction works correctly
  • None parameter handling (backward compatibility)
  • Module imports successfully

Backward Compatibility

  • Models without num_labels continue to work (existing use cases unaffected)
  • Models with num_labels now work correctly (new functionality)
  • Partial parameters (only num_labels without id2label/label2id) handled correctly
  • No breaking changes to existing API
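The compatibility bullets above can be checked in isolation by reusing the patch's pop-then-set logic against a SimpleNamespace stand-in for model_config (set_label_config is a hypothetical wrapper name used only for this sketch):

```python
from types import SimpleNamespace

def set_label_config(model_config, kwargs):  # hypothetical wrapper around the patch logic
    num_labels = kwargs.pop("num_labels", None)
    id2label = kwargs.pop("id2label", None)
    label2id = kwargs.pop("label2id", None)
    if num_labels is not None:
        model_config.num_labels = num_labels
    if id2label is not None:
        model_config.id2label = id2label
    if label2id is not None:
        model_config.label2id = label2id
    return model_config

# Partial parameters: num_labels alone is handled.
cfg1 = set_label_config(SimpleNamespace(), {"num_labels": 6})
assert cfg1.num_labels == 6 and not hasattr(cfg1, "id2label")

# Backward compatibility: unrelated kwargs pass through untouched.
kwargs = {"load_in_4bit": False}
cfg2 = set_label_config(SimpleNamespace(), kwargs)
assert vars(cfg2) == {} and kwargs == {"load_in_4bit": False}
```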

Usage After Fix

from unsloth import FastModel
from transformers import AutoModelForSequenceClassification
import os

os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"

# Now works! 
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/ModernBERT-large",
    auto_model=AutoModelForSequenceClassification,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=False,
    load_in_16bit=True,
    num_labels=6,
    id2label={0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"},
    label2id={"sadness": 0, "joy": 1, "love": 2, "anger": 3, "fear": 4, "surprise": 5},
)
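Once loaded this way, predictions map back to label strings through the config's id2label table. A framework-free sketch of that lookup, where the plain list of scores stands in for a model's per-class logits:

```python
# Map an argmax over per-class scores back to a string label via id2label.
id2label = {0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"}
logits = [0.1, 2.3, 0.2, 0.05, 0.4, 1.1]  # stand-in values for illustration
pred_id = max(range(len(logits)), key=logits.__getitem__)
pred_label = id2label[pred_id]  # highest score is at index 1 -> "joy"
```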

Additional Notes

  • This fix follows the same pattern used in unsloth/models/llama.py for handling num_labels in LLaMA-based models
  • The fix is minimal and surgical: it only affects models that explicitly pass these parameters
  • Applies to all transformer models using AutoModelForSequenceClassification through FastModel.from_pretrained()
  • Works with both full finetuning and LoRA/QLoRA

Related

Issue #4163

Checklist

  • Code follows the project's style guidelines
  • Changes are backward compatible
  • No breaking changes to existing functionality
  • Tested with ModernBERT sequence classification

Thank you for reviewing! This fix enables ModernBERT sequence classification while maintaining full backward compatibility. 🦥

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug that prevented users from loading BERT-based models for sequence classification with custom label configurations using FastModel.from_pretrained. By correctly handling num_labels, id2label, and label2id parameters and applying them to the model's configuration before instantiation, this change enables seamless fine-tuning of classification models, expanding the utility of the Unsloth library for various NLP tasks.

Highlights

  • Sequence Classification Parameter Handling: Resolved a TypeError that occurred when num_labels, id2label, or label2id were passed to FastModel.from_pretrained() for AutoModelForSequenceClassification with BERT-based models.
  • Model Configuration Update: Modified the FastBaseModel.from_pretrained() method to extract sequence classification parameters (num_labels, id2label, label2id) from kwargs and set them directly on the model's configuration (model_config) prior to model instantiation.
  • Enhanced Compatibility: Ensured that AutoModelForSequenceClassification can now correctly load and utilize custom label configurations, enabling fine-tuning of BERT models for classification tasks within the Unsloth framework.
Changelog
  • unsloth/models/vision.py
    • Implemented logic to extract and apply sequence classification parameters (num_labels, id2label, label2id) to the model configuration.
Activity
  • No human activity has been recorded for this pull request yet.

@gemini-code-assist bot left a comment
Code Review

This pull request correctly addresses an issue where sequence classification parameters like num_labels were not being passed to the model's configuration, causing errors when loading certain BERT-based models. The fix involves extracting these parameters from kwargs and setting them on the model_config before model instantiation. My review includes a suggestion to refactor the new code to be more concise and maintainable by using a loop, which reduces code duplication.

Comment on lines +787 to +796

num_labels = kwargs.pop("num_labels", None)
id2label = kwargs.pop("id2label", None)
label2id = kwargs.pop("label2id", None)

if num_labels is not None:
    model_config.num_labels = num_labels
if id2label is not None:
    model_config.id2label = id2label
if label2id is not None:
    model_config.label2id = label2id
Severity: medium

While the current implementation is correct, it can be made more concise and maintainable by using a loop to handle the sequence classification parameters. This avoids repetition and makes it easier to add more parameters in the future.

Suggested change

Replace:

num_labels = kwargs.pop("num_labels", None)
id2label = kwargs.pop("id2label", None)
label2id = kwargs.pop("label2id", None)

if num_labels is not None:
    model_config.num_labels = num_labels
if id2label is not None:
    model_config.id2label = id2label
if label2id is not None:
    model_config.label2id = label2id

With:

for param in ("num_labels", "id2label", "label2id"):
    value = kwargs.pop(param, None)
    if value is not None:
        setattr(model_config, param, value)
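A quick standalone equivalence check (SimpleNamespace standing in for the real model_config) confirms the loop form skips absent or None values exactly as the explicit if-blocks did:

```python
from types import SimpleNamespace

def loop_form(model_config, kwargs):
    # Reviewer-suggested refactor: one loop instead of three if-blocks.
    for param in ("num_labels", "id2label", "label2id"):
        value = kwargs.pop(param, None)
        if value is not None:
            setattr(model_config, param, value)
    return model_config

cfg = loop_form(SimpleNamespace(), {"num_labels": 3, "id2label": {0: "neg", 1: "neu", 2: "pos"}})
assert cfg.num_labels == 3
assert cfg.id2label == {0: "neg", 1: "neu", 2: "pos"}
assert not hasattr(cfg, "label2id")  # absent values are skipped, as before
```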

@adityaghai07 (Contributor, Author)
done! thanks!

Follow-up commit 779c4b06: "Using a loop to handle the sequence classification parameters."