Fix: ModernBERT sequence classification num_labels support #4177
adityaghai07 wants to merge 3 commits into unslothai:main from fix/bert-num-labels
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a critical bug that prevented users from loading BERT-based models for sequence classification with custom label configurations using `FastModel.from_pretrained`. By correctly handling the `num_labels`, `id2label`, and `label2id` parameters and applying them to the model's configuration before instantiation, this change enables seamless fine-tuning of classification models, expanding the utility of the Unsloth library for various NLP tasks.
Code Review
This pull request correctly addresses an issue where sequence classification parameters like num_labels were not being passed to the model's configuration, causing errors when loading certain BERT-based models. The fix involves extracting these parameters from kwargs and setting them on the model_config before model instantiation. My review includes a suggestion to refactor the new code to be more concise and maintainable by using a loop, which reduces code duplication.
unsloth/models/vision.py
```python
num_labels = kwargs.pop("num_labels", None)
id2label = kwargs.pop("id2label", None)
label2id = kwargs.pop("label2id", None)

if num_labels is not None:
    model_config.num_labels = num_labels
if id2label is not None:
    model_config.id2label = id2label
if label2id is not None:
    model_config.label2id = label2id
```
While the current implementation is correct, it can be made more concise and maintainable by using a loop to handle the sequence classification parameters. This avoids repetition and makes it easier to add more parameters in the future.
```diff
- num_labels = kwargs.pop("num_labels", None)
- id2label = kwargs.pop("id2label", None)
- label2id = kwargs.pop("label2id", None)
- if num_labels is not None:
-     model_config.num_labels = num_labels
- if id2label is not None:
-     model_config.id2label = id2label
- if label2id is not None:
-     model_config.label2id = label2id
+ for param in ("num_labels", "id2label", "label2id"):
+     value = kwargs.pop(param, None)
+     if value is not None:
+         setattr(model_config, param, value)
```
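A quick, self-contained check that the suggested loop behaves the same as the explicit version — using `types.SimpleNamespace` as a stand-in for the real `model_config`:

```python
from types import SimpleNamespace

model_config = SimpleNamespace()
kwargs = {"num_labels": 3, "id2label": {0: "neg", 1: "neu", 2: "pos"}}

# Suggested refactor: pop each classification parameter and, if present,
# set it on the config instead of forwarding it to the model's __init__.
for param in ("num_labels", "id2label", "label2id"):
    value = kwargs.pop(param, None)
    if value is not None:
        setattr(model_config, param, value)

assert model_config.num_labels == 3
assert "num_labels" not in kwargs             # no longer leaks into __init__
assert not hasattr(model_config, "label2id")  # absent parameters stay unset
```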
Commit 779c4b06: Using a loop to handle the sequence classification parameters.
Fix: Add support for sequence classification parameters in FastModel
Issue #4163
Problem
When using `FastModel.from_pretrained()` with `AutoModelForSequenceClassification` for BERT-based models (including ModernBERT), users encountered a `TypeError`. This prevented them from loading sequence classification models with custom label configurations, making it impossible to use Unsloth for fine-tuning BERT models on classification tasks.
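The failure mode can be reproduced without Unsloth: when classification kwargs are forwarded blindly to a constructor that only accepts a config, Python raises a `TypeError`. A minimal sketch (the `TinyModel` class and `load_model` helper are hypothetical, for illustration only):

```python
class TinyModel:
    """Stand-in for a model class whose __init__ only accepts a config."""
    def __init__(self, config):
        self.config = config

def load_model(config, **kwargs):
    # Before the fix, extra kwargs such as num_labels were forwarded blindly.
    return TinyModel(config, **kwargs)

try:
    load_model(config = {}, num_labels = 2)
except TypeError as exc:
    # e.g. "__init__() got an unexpected keyword argument 'num_labels'"
    print(type(exc).__name__)  # prints "TypeError"
```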
Reproduction
Solution
The issue was in
unsloth/models/vision.pyin theFastBaseModel.from_pretrained()method. The parametersnum_labels,id2label, andlabel2idwere being passed directly toauto_model.from_pretrained()via**kwargs, but for sequence classification models, these parameters must be set in the model's config before instantiation.Changes Made
File: `unsloth/models/vision.py` (lines ~784-800)

Added code to extract the sequence classification parameters from kwargs and set them on the config before model loading.
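The added logic can be sketched as a small helper (the function name is illustrative — in the PR, the statements live inline in `FastBaseModel.from_pretrained`, and `model_config` is the real pretrained config object):

```python
from types import SimpleNamespace

def apply_classification_params(model_config, kwargs):
    """Pop sequence classification parameters from kwargs and set them
    on the config, so they never reach the model's __init__."""
    num_labels = kwargs.pop("num_labels", None)
    id2label = kwargs.pop("id2label", None)
    label2id = kwargs.pop("label2id", None)
    if num_labels is not None:
        model_config.num_labels = num_labels
    if id2label is not None:
        model_config.id2label = id2label
    if label2id is not None:
        model_config.label2id = label2id
    return model_config, kwargs

config, rest = apply_classification_params(SimpleNamespace(), {"num_labels": 2})
```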
This ensures the parameters are already in the config when `auto_model.from_pretrained()` is called, preventing them from being passed as unexpected keyword arguments to the model's `__init__` method.

Testing
Validation Tests
Backward Compatibility
- Models loaded without `num_labels` continue to work (existing use cases unaffected)
- Models loaded with `num_labels` now work correctly (new functionality)
- Partial parameters (e.g. `num_labels` without `id2label`/`label2id`) are handled correctly

Usage After Fix

After this fix, `FastModel.from_pretrained()` accepts `num_labels`, `id2label`, and `label2id` and applies them to the model config before the model is instantiated.
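As a concrete illustration of the call shape after the fix — the model id and label maps below are assumptions, and a stub stands in for Unsloth so the sketch runs standalone:

```python
from types import SimpleNamespace

class StubFastModel:
    """Stub mimicking the fixed FastBaseModel.from_pretrained flow."""
    @staticmethod
    def from_pretrained(model_name, **kwargs):
        model_config = SimpleNamespace(name_or_path = model_name)
        # The fix: move classification parameters from kwargs onto the config.
        for param in ("num_labels", "id2label", "label2id"):
            value = kwargs.pop(param, None)
            if value is not None:
                setattr(model_config, param, value)
        # Remaining kwargs would be forwarded to auto_model.from_pretrained().
        return SimpleNamespace(config = model_config), "tokenizer"

model, tokenizer = StubFastModel.from_pretrained(
    "answerdotai/ModernBERT-base",  # illustrative model id
    num_labels = 2,
    id2label = {0: "NEGATIVE", 1: "POSITIVE"},
    label2id = {"NEGATIVE": 0, "POSITIVE": 1},
)
assert model.config.num_labels == 2  # applied to the config, not __init__
```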
Additional Notes
- A similar pattern exists in `unsloth/models/llama.py` for handling `num_labels` in LLaMA-based models
- This fix enables using `AutoModelForSequenceClassification` through `FastModel.from_pretrained()`

Related
Checklist
Thank you for reviewing! This fix enables ModernBERT sequence classification while maintaining full backward compatibility. 🦥