Skip to content

Commit 7ef18b1

Browse files
authored
GPT OSS Hotfix (#5263)
* changes for gpt_oss jobs support
* added unit tests
* fixing unit test
1 parent eb13102 commit 7ef18b1

File tree

4 files changed

+15
-0
lines changed

4 files changed

+15
-0
lines changed

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
136136
"mistral": ("mistral", "mistral_pretrain.py"),
137137
"mixtral": ("mixtral", "mixtral_pretrain.py"),
138138
"deepseek": ("deepseek", "deepseek_pretrain.py"),
139+
"gpt_oss": ("custom_model", "custom_pretrain.py"),
139140
}
140141

141142
for key in model_type_to_script:

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
9999
"mistral": ("mistral", "mistral_pretrain.py"),
100100
"mixtral": ("mixtral", "mixtral_pretrain.py"),
101101
"deepseek": ("deepseek", "deepseek_pretrain.py"),
102+
"gpt_oss": ("custom_model", "custom_pretrain.py"),
102103
}
103104

104105
if "model" not in recipe:

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ def test_get_args_from_recipe_with_nova_and_role(mock_get_args_from_nova_recipe,
237237
"script": "deepseek_pretrain.py",
238238
"model_base_name": "deepseek",
239239
},
240+
{
241+
"model_type": "gpt_oss",
242+
"script": "custom_pretrain.py",
243+
"model_base_name": "custom_model",
244+
},
240245
],
241246
)
242247
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):

tests/unit/test_pytorch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,14 @@ def test_training_recipe_for_trainium(sagemaker_session):
10871087
},
10881088
},
10891089
},
1090+
{
1091+
"script": "custom_pretrain.py",
1092+
"recipe": {
1093+
"model": {
1094+
"model_type": "gpt_oss",
1095+
},
1096+
},
1097+
},
10901098
],
10911099
)
10921100
@patch("shutil.copyfile")

0 commit comments

Comments
 (0)