Labels
component: training (Relates to the SageMaker Training Platform), type: bug
Description
Describe the bug
The `hyperparameter_ranges` argument of the `HyperparameterTuner` constructor is annotated as `Dict[str, ParameterRange]`. However, `HyperparameterTuner` works correctly when a `PipelineVariable` is used as a dictionary key (the hyperparameter name), so the correct type annotation would be `Dict[Union[str, PipelineVariable], ParameterRange]`.
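A minimal, self-contained sketch of the proposed widening follows. `PipelineVariable`, `ParameterString`, `ParameterRange` and `CategoricalParameter` here are simplified stand-ins (hypothetical, not the SDK classes) so the snippet runs without SageMaker installed; `HyperparameterRanges` and `describe_ranges` are hypothetical names used only for illustration.

```python
from typing import Dict, List, Union


class PipelineVariable:
    """Stand-in (hypothetical) for the SDK's PipelineVariable base class."""


class ParameterString(PipelineVariable):
    """Stand-in (hypothetical) for a pipeline parameter resolved at execution time."""

    def __init__(self, name: str, default_value: str) -> None:
        self.name = name
        self.default_value = default_value


class ParameterRange:
    """Stand-in (hypothetical) base class for hyperparameter ranges."""


class CategoricalParameter(ParameterRange):
    """Stand-in (hypothetical) categorical range."""

    def __init__(self, values: List[Union[int, float, str]]) -> None:
        self.values = values


# Proposed annotation: keys may be plain names or pipeline variables.
HyperparameterRanges = Dict[Union[str, PipelineVariable], ParameterRange]


def describe_ranges(hyperparameter_ranges: HyperparameterRanges) -> List[str]:
    """Return one label per key, showing that both key types are accepted."""
    labels: List[str] = []
    for key in hyperparameter_ranges:
        if isinstance(key, ParameterString):
            labels.append(f"pipeline:{key.name}")
        else:
            labels.append(f"static:{key}")
    return labels
```

With the widened alias, a dictionary mixing both key types type-checks when it is annotated explicitly (an unannotated literal would be inferred with an `object` key type and rejected, because `Dict` is invariant):

```python
ranges: HyperparameterRanges = {
    ParameterString("HParamName", default_value="hparam"): CategoricalParameter([1, 2]),
    "lr": CategoricalParameter([0.1, 0.01]),
}
describe_ranges(ranges)  # ["pipeline:HParamName", "static:lr"]
```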
To reproduce
The following code works properly:

```python
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import HyperparameterTuner, CategoricalParameter
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.steps import TuningStep
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession

if __name__ == "__main__":
    pipeline_session = PipelineSession()

    estimator = PyTorch(
        sagemaker_session=pipeline_session,
        instance_type='ml.m5.large',
        instance_count=1,
        framework_version="2.3",
        py_version="py311",
        source_dir='source',
        entry_point='main.py',
        metric_definitions=[
            {'Name': 'valid:loss', 'Regex': 'valid_loss=([0-9]+\\.?[0-9]*)'}
        ],
    )

    # Pipeline parameter used as the hyperparameter *name* (the dict key below)
    hparam_name = ParameterString("HParamName", default_value='hparam')

    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name='valid:loss',
        objective_type='Minimize',
        hyperparameter_ranges={
            hparam_name: CategoricalParameter([1, 2]),
        },
        max_jobs=2,
        max_parallel_jobs=1,
        base_tuning_job_name='test-tuning',
        strategy='Grid',
        metric_definitions=estimator.metric_definitions,
    )

    tuning_step = TuningStep(
        name="Tuning",
        step_args=tuner.fit(),
    )

    pipeline = Pipeline(
        name="TestTuningPipeline",
        parameters=[hparam_name],
        steps=[tuning_step],
        sagemaker_session=pipeline_session,
    )
    pipeline.upsert()

    execution = pipeline.start(
        execution_display_name="TuningTest",
    )
    print(execution)
```
However, the type annotation of `hyperparameter_ranges` rejects this usage. Running

```
mypy main.py --follow-untyped-imports
```

reports:

```
test.py:35: error: Dict entry 0 has incompatible type "ParameterString": "CategoricalParameter"; expected "str": "ParameterRange"  [dict-item]
```
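Until the annotation is widened, one possible call-site workaround (a sketch, not an SDK-endorsed fix) is to wrap the dictionary in `typing.cast`, which only informs the type checker and returns its argument unchanged at runtime; a `# type: ignore[dict-item]` comment on the offending line is another option. The classes and `make_tuner` below are hypothetical stand-ins that mirror the current, too-narrow annotation so the snippet runs without SageMaker installed.

```python
from typing import Dict, List, cast


class PipelineVariable:
    """Stand-in (hypothetical) for the SDK's PipelineVariable."""

    def __init__(self, name: str) -> None:
        self.name = name


class ParameterRange:
    """Stand-in (hypothetical) for the SDK's ParameterRange."""

    def __init__(self, values: List[int]) -> None:
        self.values = values


def make_tuner(hyperparameter_ranges: Dict[str, ParameterRange]) -> int:
    """Hypothetical stand-in with the current annotation; returns the range count."""
    return len(hyperparameter_ranges)


hparam_name = PipelineVariable("HParamName")
ranges = {hparam_name: ParameterRange([1, 2])}

# cast() silences mypy's dict-item error; at runtime `ranges` is passed through as-is.
n = make_tuner(cast(Dict[str, ParameterRange], ranges))
```

The cast hides the mismatch rather than fixing it, so it is best removed once the SDK annotation accepts `PipelineVariable` keys.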
Expected behavior
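The `HyperparameterTuner` constructor arguments should have correct type annotations; in particular, `hyperparameter_ranges` should accept `PipelineVariable` keys so that the code above passes mypy.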
Screenshots or logs
System information
A description of your system. Please provide:
- SageMaker Python SDK version: 2.239.3
- Framework name (eg. PyTorch) or algorithm (eg. KMeans): irrelevant
- Framework version: irrelevant
- Python version: 3.12
- CPU or GPU: irrelevant
- Custom Docker image (Y/N): irrelevant
Additional context