Skip to content

Commit 4c425c2

Browse files
authored
Add host header when unset to crt requests (#427)
1 parent 88b8113 commit 4c425c2

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

packages/smithy-http/src/smithy_http/aio/crt.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,14 @@ async def _marshal_request(
292292
"""Create :py:class:`awscrt.http.HttpRequest` from
293293
:py:class:`smithy_http.aio.HTTPRequest`"""
294294
headers_list = []
295+
if "Host" not in request.fields:
296+
request.fields.set_field(
297+
Field(name="Host", values=[request.destination.host])
298+
)
299+
295300
for fld in request.fields.entries.values():
296-
if fld.kind != FieldPosition.HEADER:
301+
# TODO: Use literal values for "header"/"trailer".
302+
if fld.kind.value != FieldPosition.HEADER.value:
297303
continue
298304
for val in fld.values:
299305
headers_list.append((fld.name, val))

packages/smithy-http/src/smithy_http/interfaces/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def __len__(self) -> int:
9898
"""Get total number of Field entries."""
9999
...
100100

101+
def __contains__(self, key: str) -> bool:
102+
"""Allows in/not in checks on Field entries."""
103+
...
104+
101105
def get_by_type(self, kind: FieldPosition) -> list[Field]:
102106
"""Helper function for retrieving specific types of fields.
103107

packages/smithy-http/tests/unit/aio/test_crt.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
from copy import deepcopy
4+
from io import BytesIO
45

56
import pytest
67

8+
from smithy_core import URI
9+
from smithy_http import Fields
10+
from smithy_http.aio import HTTPRequest
711
from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream
812

913

@@ -12,6 +16,22 @@ def test_deepcopy_client() -> None:
1216
deepcopy(client)
1317

1418

19+
async def test_client_marshal_request() -> None:
20+
client = AWSCRTHTTPClient()
21+
request = HTTPRequest(
22+
method="GET",
23+
destination=URI(
24+
host="example.com", path="/path", query="key1=value1&key2=value2"
25+
),
26+
body=BytesIO(),
27+
fields=Fields(),
28+
)
29+
crt_request = await client._marshal_request(request) # type: ignore
30+
assert crt_request.headers.get("host") == "example.com" # type: ignore
31+
assert crt_request.method == "GET" # type: ignore
32+
assert crt_request.path == "/path?key1=value1&key2=value2" # type: ignore
33+
34+
1535
def test_stream_write() -> None:
1636
stream = BufferableByteStream()
1737
stream.write(b"foo")

0 commit comments

Comments
 (0)