add FIM dataset support #2066
Conversation
Signed-off-by: dimapihtar <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
/ok to test c129e61
📝 Walkthrough

Introduces support for Fill-In-the-Middle (FIM) dataset configuration by adding a GPTFIMDatasetConfig class, integrating it into the dataset provider registry, and updating dataset selection logic to recognize FIM-specific attributes during provisioning.
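For orientation, here is a minimal sketch of what attribute-based selection could look like on the provider side; the helper name, the `GPTFIMDataset` class and its import path, and the wiring are assumptions for illustration, not the PR's actual code. The only detail taken from the diff is the `fim_data` flag set in `GPTFIMDatasetConfig.__init__`.

```python
# Illustrative sketch only: choose the dataset class from FIM-specific config attributes.
# GPTFIMDataset and its import path are assumed; the real provisioning code may differ.
from megatron.core.datasets.gpt_dataset import GPTDataset


def pick_dataset_class(dataset_config):
    """Return the dataset class to build for a given dataset config (hypothetical helper)."""
    if getattr(dataset_config, "fim_data", False):
        from megatron.training.datasets.fim_dataset import GPTFIMDataset  # assumed location
        return GPTFIMDataset
    return GPTDataset
```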
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ❌ 2 failed (2 warnings) | ✅ 2 passed
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/training/config.py`:
- Around lines 453-486: The dataset config GPTFIMDatasetConfig can bypass
deferred finalization and seq-length validation because ConfigContainer.validate
only finalizes GPTDatasetConfig; update ConfigContainer.validate to explicitly
handle GPTFIMDatasetConfig by calling dataset.finalize() for instances of
(GPTDatasetConfig, GPTFIMDatasetConfig) and include GPTFIMDatasetConfig in the
seq-length check alongside GPTDatasetConfig and FinetuningDatasetConfig (i.e.,
ensure data_seq_length reads from self.dataset.seq_length for
GPTFIMDatasetConfig as well), or alternatively make GPTFIMDatasetConfig inherit
from GPTDatasetConfig so existing checks apply.
In `@tests/unit_tests/training/test_config.py`:
- Around lines 313-352: The file defines two classes with the same name
TestMockGPTDatasetConfig causing the latter to overwrite the former and hide the
GPTFIM tests; rename the second TestMockGPTDatasetConfig class to a unique name
(e.g., TestMockGPTDatasetConfigFallback or TestMockGPTDatasetConfigAlternate) so
both classes are distinct, update any internal references if present, and run
the tests to ensure the GPTFIM tests (the class that instantiates
GPTFIMDatasetConfig and asserts its mixins) are collected.
```python
@dataclass
class GPTFIMDatasetConfig(MCoreGPTFIMDatasetConfig, DataloaderConfig):
    """Megatron Core GPTFIMDatasetConfig with deferred post-init.

    This class inherits from MCore's GPTFIMDatasetConfig and DataloaderConfig but defers the
    execution of post_init() until finalize() is explicitly called. This allows
    for field modifications after construction but before computed fields are calculated.
    """

    def __init__(
        self,
        seq_length: int | None = None,
        skip_getting_attention_mask_from_dataset: bool = True,
        *args,
        **kwargs,
    ):
        """
        Args:
            seq_length (int | None): the sequence length. If not provided, `sequence_length` must be in kwargs.
            skip_getting_attention_mask_from_dataset (bool): if set, the dataset will pass a None attention mask
                and the attention mask is autogenerated from the attn backend.
        """
        self.fim_data = True
        self.skip_getting_attention_mask_from_dataset = skip_getting_attention_mask_from_dataset

        if seq_length is not None:
            kwargs["sequence_length"] = seq_length
        elif "sequence_length" not in kwargs:
            raise ValueError("Either `seq_length` or `sequence_length` must be provided.")

        dataloader_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in DataloaderConfig.__dataclass_fields__}
        MCoreGPTFIMDatasetConfig.__init__(self, *args, **kwargs)
        DataloaderConfig.__init__(self, **dataloader_kwargs)
```
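For reference, a minimal construction sketch of the class above; the field values are placeholders borrowed from the unit test further down, and the surrounding training configuration is omitted.

```python
# Minimal usage sketch: build the config, adjust fields if needed, then finalize to run
# the deferred post-init. Values here are placeholders, not recommended settings.
fim_dataset_cfg = GPTFIMDatasetConfig(
    random_seed=1234,
    seq_length=512,
    fim_rate=0.1,
    fim_no_prefix="test",
    fim_extra_tokens={"middle": "<middle>"},
    fim_split_sample="test sample",
    reset_position_ids=False,
    reset_attention_mask=False,
    eod_mask_loss=False,
)
fim_dataset_cfg.finalize()  # computed fields are only derived once finalize() is called
```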
GPTFIMDatasetConfig may skip deferred finalization/validation.
ConfigContainer.validate() only finalizes GPTDatasetConfig and only runs the seq-length check for (GPTDatasetConfig, FinetuningDatasetConfig). Because GPTFIMDatasetConfig doesn’t inherit GPTDatasetConfig, it can bypass deferred __post_init__ and seq-length validation. Consider explicitly including GPTFIMDatasetConfig in those checks (or make it extend GPTDatasetConfig if feasible).
```python
# in ConfigContainer.validate()
if isinstance(self.dataset, (GPTDatasetConfig, GPTFIMDatasetConfig)):
    self.dataset.finalize()
# ...
if isinstance(self.dataset, (GPTDatasetConfig, GPTFIMDatasetConfig, FinetuningDatasetConfig)):
    data_seq_length = self.dataset.seq_length
```
🧰 Tools
🪛 Ruff (0.14.14)
481-481: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/config.py` around lines 453-486: The dataset
config GPTFIMDatasetConfig can bypass deferred finalization and seq-length
validation because ConfigContainer.validate only finalizes GPTDatasetConfig;
update ConfigContainer.validate to explicitly handle GPTFIMDatasetConfig by
calling dataset.finalize() for instances of (GPTDatasetConfig,
GPTFIMDatasetConfig) and include GPTFIMDatasetConfig in the seq-length check
alongside GPTDatasetConfig and FinetuningDatasetConfig (i.e., ensure
data_seq_length reads from self.dataset.seq_length for GPTFIMDatasetConfig as
well), or alternatively make GPTFIMDatasetConfig inherit from GPTDatasetConfig
so existing checks apply.
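As a complementary illustration, a more duck-typed variant of the seq-length guard is sketched below; `self.model.seq_length` and the error wording are assumptions about the surrounding ConfigContainer, not code from the repository.

```python
# Sketch of an attribute-based variant of the seq-length check, so any dataset config that
# exposes seq_length (including GPTFIMDatasetConfig) is covered without growing the isinstance tuple.
# The model-side field name (self.model.seq_length) is assumed for illustration.
data_seq_length = getattr(self.dataset, "seq_length", None)
if data_seq_length is not None and data_seq_length != self.model.seq_length:
    raise ValueError(
        f"dataset.seq_length ({data_seq_length}) does not match model.seq_length ({self.model.seq_length})"
    )
```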
```python
class TestMockGPTDatasetConfig:
    """Tests desired behavior for GPTFIMDatasetConfig."""

    def test_initialization(self):
        config = GPTFIMDatasetConfig(
            random_seed=1234,
            seq_length=512,
            fim_rate=0.1,
            fim_no_prefix="test",
            fim_extra_tokens={"middle": "<middle>"},
            fim_split_sample="test sample",
            reset_position_ids=False,
            reset_attention_mask=False,
            eod_mask_loss=False,
        )
        config.finalize()

        # Should be an instance of both GPTDatasetConfig and GPTFIMDatasetConfig
        from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
        from megatron.training.datasets.fim_dataset import GPTFIMDatasetConfig as MCoreGPTFIMDatasetConfig

        assert isinstance(config, GPTFIMDatasetConfig)
        assert isinstance(config, GPTDatasetConfig)
        assert isinstance(config, MCoreGPTFIMDatasetConfig)
        assert isinstance(config, BlendedMegatronDatasetConfig)

        # Should have all the expected fields from parent class
        assert hasattr(config, "random_seed")
        assert hasattr(config, "seq_length")
        assert hasattr(config, "path_to_cache")

        # Verify all the expected fields were set properly
        assert config.fim_data
        assert config.fim_rate == 0.1
        assert config.fim_no_prefix == "test"
        assert config.fim_split_sample == "test sample"
        assert config.fim_extra_tokens["middle"] == "<middle>"


class TestMockGPTDatasetConfig:
```
Duplicate test class name hides GPTFIM tests.
There are two TestMockGPTDatasetConfig classes; the latter overwrites the former at import time, so the GPTFIM tests are not collected. Rename the new class to avoid masking.
✅ Suggested rename

```diff
-class TestMockGPTDatasetConfig:
+class TestGPTFIMDatasetConfig:
```
🧰 Tools
🪛 Ruff (0.14.14)
352-352: Redefinition of unused TestMockGPTDatasetConfig from line 313: TestMockGPTDatasetConfig redefined here
(F811)
🤖 Prompt for AI Agents
In `@tests/unit_tests/training/test_config.py` around lines 313-352: The file
defines two classes with the same name TestMockGPTDatasetConfig causing the
latter to overwrite the former and hide the GPTFIM tests; rename the second
TestMockGPTDatasetConfig class to a unique name (e.g.,
TestMockGPTDatasetConfigFallback or TestMockGPTDatasetConfigAlternate) so both
classes are distinct, update any internal references if present, and run the
tests to ensure the GPTFIM tests (the class that instantiates
GPTFIMDatasetConfig and asserts its mixins) are collected.
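To see why the duplicate name silently drops tests, here is a tiny standalone illustration (a hypothetical module, unrelated to the repository's code):

```python
# Python rebinds a class name to its latest definition, so the first class
# (and every test method inside it) becomes unreachable and is never collected by pytest.
class TestExample:
    def test_first(self):
        assert True


class TestExample:  # noqa: F811 - rebinds the name; the class above is now shadowed
    def test_second(self):
        assert True


# Only the second definition survives at module scope.
assert [m for m in dir(TestExample) if m.startswith("test_")] == ["test_second"]
```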
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information
Summary by CodeRabbit
New Features
Tests