Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def predict(
inference_id=None,
custom_attributes=None,
component_name: Optional[str] = None,
target_container_hostname=None,
):
"""Return the inference from the specified endpoint.

Expand Down Expand Up @@ -188,6 +189,9 @@ def predict(
function (Default: None).
component_name (str): Optional. Name of the Amazon SageMaker inference component
corresponding the predictor.
target_container_hostname (str): Optional. If the endpoint hosts multiple containers
and is configured to use direct invocation, this parameter specifies the host name
of the container to invoke. (Default: None).

Returns:
object: Inference for the given input. If a deserializer was specified when creating
Expand All @@ -203,6 +207,7 @@ def predict(
target_variant=target_variant,
inference_id=inference_id,
custom_attributes=custom_attributes,
target_container_hostname=target_container_hostname,
Copy link
Author

@kwnath kwnath Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this doesn't work when running something like:

response = runtime_sm_client.invoke_endpoint(
    EndpointName="endpoint",
    ContentType="application/json",
    TargetModel="target_model",
    TargetContainerHostname="primary",
    Body="...",
)

Returns:
ValidationError: An error occurred (ValidationError) when calling the InvokeEndpoint operation: Request x endpoint does not support TargetContainerHostname

)

inference_component_name = component_name or self._get_component_name()
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,29 @@ def test_predict_call_with_inference_id():
assert result == RETURN_VALUE


def test_predict_call_with_target_container_hostname():
sagemaker_session = empty_sagemaker_session()
predictor = Predictor(ENDPOINT, sagemaker_session)

data = "untouched"
result = predictor.predict(data, target_container_hostname="test_target_container_hostname")

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called

expected_request_args = {
"Accept": DEFAULT_ACCEPT,
"Body": data,
"ContentType": DEFAULT_CONTENT_TYPE,
"EndpointName": ENDPOINT,
"TargetContainerHostname": "test_target_container_hostname",
}

_, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
assert kwargs == expected_request_args

assert result == RETURN_VALUE


def test_multi_model_predict_call():
sagemaker_session = empty_sagemaker_session()
predictor = Predictor(ENDPOINT, sagemaker_session)
Expand Down