diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index 3b7659016e..b6523e14dd 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -136,6 +136,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", "deepseek_pretrain.py"), + "gpt_oss": ("custom_model", "custom_pretrain.py"), } for key in model_type_to_script: diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 633317927b..208239e368 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -99,6 +99,7 @@ def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", "deepseek_pretrain.py"), + "gpt_oss": ("custom_model", "custom_pretrain.py"), } if "model" not in recipe: diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index a58b1f641e..17cfda55b0 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -237,6 +237,11 @@ def test_get_args_from_recipe_with_nova_and_role(mock_get_args_from_nova_recipe, "script": "deepseek_pretrain.py", "model_base_name": "deepseek", }, + { + "model_type": "gpt_oss", + "script": "custom_pretrain.py", + "model_base_name": "custom_model", + }, ], ) def test_get_trainining_recipe_gpu_model_name_and_script(test_case): diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 34d3c6784b..8352f3090b 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -1087,6 +1087,14 @@ def test_training_recipe_for_trainium(sagemaker_session): }, }, }, + { + "script": "custom_pretrain.py", + "recipe": { + "model": { + "model_type": "gpt_oss", + }, + }, + }, ], ) @patch("shutil.copyfile")