From 5c1d94bdfea5fc7cbbe9e244ab56fc255cba169e Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Tue, 17 Feb 2026 10:57:25 -0500 Subject: [PATCH 1/6] Add AWS JSON 1.0 and 1.1 traits --- .../src/smithy_aws_core/traits.py | 80 ++++++++++++++----- .../smithy-aws-core/tests/unit/test_traits.py | 18 ++++- 2 files changed, 77 insertions(+), 21 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/traits.py b/packages/smithy-aws-core/src/smithy_aws_core/traits.py index 3902a55ff..2ab263ce0 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/traits.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/traits.py @@ -15,6 +15,33 @@ from smithy_core.traits import DynamicTrait, Trait +def _parse_http_protocol_values( + value: DocumentValue | DynamicTrait | None, +) -> tuple[tuple[str, ...], tuple[str, ...]]: + document_value = value or {} + assert isinstance(document_value, Mapping) + + http_versions_raw = document_value.get("http", ["http/1.1"]) + assert isinstance(http_versions_raw, Sequence) + http_versions_list: list[str] = [] + for entry in http_versions_raw: + assert isinstance(entry, str) + http_versions_list.append(entry) + http_versions = tuple(http_versions_list) + + event_stream_http_versions_raw = document_value.get("eventStreamHttp") + if not event_stream_http_versions_raw: + return http_versions, http_versions + + assert isinstance(event_stream_http_versions_raw, Sequence) + event_stream_http_versions_list: list[str] = [] + for entry in event_stream_http_versions_raw: + assert isinstance(entry, str) + event_stream_http_versions_list.append(entry) + + return http_versions, tuple(event_stream_http_versions_list) + + @dataclass(init=False, frozen=True) class RestJson1Trait(Trait, id=ShapeID("aws.protocols#restJson1")): http: Sequence[str] = field( @@ -26,24 +53,41 @@ class RestJson1Trait(Trait, id=ShapeID("aws.protocols#restJson1")): def __init__(self, value: DocumentValue | DynamicTrait = None): super().__init__(value) - document_value = value or {} - assert isinstance(document_value, Mapping) - - http_versions = document_value.get("http", ["http/1.1"]) - assert isinstance(http_versions, Sequence) - for val in http_versions: - assert isinstance(val, str) - object.__setattr__(self, "http", tuple(http_versions)) - event_stream_http_versions = document_value.get("eventStreamHttp") - if not event_stream_http_versions: - object.__setattr__(self, "event_stream_http", self.http) - else: - assert isinstance(event_stream_http_versions, Sequence) - for val in event_stream_http_versions: - assert isinstance(val, str) - object.__setattr__( - self, "event_stream_http", tuple(event_stream_http_versions) - ) + http, event_stream_http = _parse_http_protocol_values(value) + object.__setattr__(self, "http", http) + object.__setattr__(self, "event_stream_http", event_stream_http) + + +@dataclass(init=False, frozen=True) +class AwsJson1_0Trait(Trait, id=ShapeID("aws.protocols#awsJson1_0")): + http: Sequence[str] = field( + repr=False, hash=False, compare=False, default_factory=tuple + ) + event_stream_http: Sequence[str] = field( + repr=False, hash=False, compare=False, default_factory=tuple + ) + + def __init__(self, value: DocumentValue | DynamicTrait = None): + super().__init__(value) + http, event_stream_http = _parse_http_protocol_values(value) + object.__setattr__(self, "http", http) + object.__setattr__(self, "event_stream_http", event_stream_http) + + +@dataclass(init=False, frozen=True) +class AwsJson1_1Trait(Trait, id=ShapeID("aws.protocols#awsJson1_1")): + http: Sequence[str] = field( + repr=False, hash=False, compare=False, default_factory=tuple + ) + event_stream_http: Sequence[str] = field( + repr=False, hash=False, compare=False, default_factory=tuple + ) + + def __init__(self, value: DocumentValue | DynamicTrait = None): + super().__init__(value) + http, event_stream_http = _parse_http_protocol_values(value) + object.__setattr__(self, "http", http) + object.__setattr__(self, "event_stream_http", event_stream_http) @dataclass(init=False, frozen=True) diff --git a/packages/smithy-aws-core/tests/unit/test_traits.py b/packages/smithy-aws-core/tests/unit/test_traits.py index d4f04ebf1..5a41db40f 100644 --- a/packages/smithy-aws-core/tests/unit/test_traits.py +++ b/packages/smithy-aws-core/tests/unit/test_traits.py @@ -1,9 +1,21 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from smithy_aws_core.traits import RestJson1Trait +import pytest +from smithy_aws_core.traits import ( + AwsJson1_0Trait, + AwsJson1_1Trait, + RestJson1Trait, +) -def test_allows_empty_restjson1_value() -> None: - trait = RestJson1Trait(None) +@pytest.mark.parametrize( + "trait_type", + [RestJson1Trait, AwsJson1_0Trait, AwsJson1_1Trait], +) +def test_allows_empty_protocol_trait_value( + trait_type: type[RestJson1Trait] | type[AwsJson1_0Trait] | type[AwsJson1_1Trait], +) -> None: + trait = trait_type(None) assert trait.http == ("http/1.1",) + assert trait.event_stream_http == ("http/1.1",) From d6704dac654ef380b202e8b7829266b6715c08ed Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 26 Feb 2026 11:45:36 -0500 Subject: [PATCH 2/6] smithy-json: support quoted non-finite numeric serde --- .../smithy-json/src/smithy_json/_private/deserializers.py | 6 +++++- .../smithy-json/src/smithy_json/_private/serializers.py | 2 +- packages/smithy-json/tests/unit/__init__.py | 6 ++++++ packages/smithy-json/tests/unit/test_deserializers.py | 7 +++++++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/packages/smithy-json/src/smithy_json/_private/deserializers.py b/packages/smithy-json/src/smithy_json/_private/deserializers.py index bbbb16927..961429b31 100644 --- a/packages/smithy-json/src/smithy_json/_private/deserializers.py +++ b/packages/smithy-json/src/smithy_json/_private/deserializers.py @@ -141,8 +141,12 @@ def read_big_decimal(self, schema: Schema) -> Decimal: match event.value: case Decimal(): return event.value - case int() | float(): + case int(): + return Decimal(event.value) + case float(): return Decimal.from_float(event.value) + case "Infinity" | "-Infinity" | "NaN": + return Decimal(event.value) case _: raise JSONTokenError("number", event) diff --git a/packages/smithy-json/src/smithy_json/_private/serializers.py b/packages/smithy-json/src/smithy_json/_private/serializers.py index c1cd3df70..3f877d52c 100644 --- a/packages/smithy-json/src/smithy_json/_private/serializers.py +++ b/packages/smithy-json/src/smithy_json/_private/serializers.py @@ -288,7 +288,7 @@ def write_float(self, value: float | Decimal) -> None: def _write_non_numeric_float(self, value: float | Decimal) -> bool: if value != value: - self._sink.write(b"NaN") + self._sink.write(b'"NaN"') return True if value == _INF: diff --git a/packages/smithy-json/tests/unit/__init__.py b/packages/smithy-json/tests/unit/__init__.py index 294dc5457..6ce4ceea8 100644 --- a/packages/smithy-json/tests/unit/__init__.py +++ b/packages/smithy-json/tests/unit/__init__.py @@ -346,7 +346,13 @@ def _read_optional_map(k: str, d: ShapeDeserializer): (True, b"true"), (1, b"1"), (1.1, b"1.1"), + (float("nan"), b'"NaN"'), + (float("inf"), b'"Infinity"'), + (float("-inf"), b'"-Infinity"'), (Decimal("1.1"), b"1.1"), + (Decimal("NaN"), b'"NaN"'), + (Decimal("Infinity"), b'"Infinity"'), + (Decimal("-Infinity"), b'"-Infinity"'), (b"foo", b'"Zm9v"'), ("foo", b'"foo"'), (datetime(2024, 5, 15, tzinfo=UTC), b'"2024-05-15T00:00:00Z"'), diff --git a/packages/smithy-json/tests/unit/test_deserializers.py b/packages/smithy-json/tests/unit/test_deserializers.py index 309a7feae..00a6ed2cd 100644 --- a/packages/smithy-json/tests/unit/test_deserializers.py +++ b/packages/smithy-json/tests/unit/test_deserializers.py @@ -1,5 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import math from datetime import datetime from decimal import Decimal from typing import Any @@ -89,6 +90,12 @@ def _read_optional_map(k: str, d: ShapeDeserializer): actual_value = actual.as_value() expected_value = expected.as_value() assert actual_value == expected_value + elif isinstance(expected, float) and math.isnan(expected): + assert isinstance(actual, float) + assert math.isnan(actual) + elif isinstance(expected, Decimal) and expected.is_nan(): + assert isinstance(actual, Decimal) + assert actual.is_nan() else: assert actual == expected From 2eb202327efba29edbb57fe13e3cf11b0dad5eb6 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 26 Feb 2026 11:45:50 -0500 Subject: [PATCH 3/6] codegen: improve generated protocol test assertions --- .../codegen/HttpProtocolTestGenerator.java | 51 +++++++++++++++++-- .../RestJsonProtocolGenerator.java | 6 --- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java index 0a06f15bf..189276680 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpProtocolTestGenerator.java @@ -16,6 +16,7 @@ import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.Stream; +import software.amazon.smithy.aws.traits.auth.SigV4Trait; import software.amazon.smithy.codegen.core.CodegenException; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.model.Model; @@ -188,12 +189,14 @@ private void generateRequestTest(OperationShape operation, HttpRequestTestCase t endpoint_uri="https://$L/$L", transport = $T(), retry_strategy=SimpleRetryStrategy(max_attempts=1), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), host, path, - REQUEST_TEST_ASYNC_HTTP_CLIENT_SYMBOL); + REQUEST_TEST_ASYNC_HTTP_CLIENT_SYMBOL, + (Runnable) this::writeSigV4TestConfig); })); // Generate the input using the expected shape and params @@ -437,13 +440,15 @@ private void generateResponseTest(OperationShape operation, HttpResponseTestCase headers=$J, body=b$S, ), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), RESPONSE_TEST_ASYNC_HTTP_CLIENT_SYMBOL, testCase.getCode(), CodegenUtils.toTuples(testCase.getHeaders()), - testCase.getBody().filter(body -> !body.isEmpty()).orElse("")); + testCase.getBody().filter(body -> !body.isEmpty()).orElse(""), + (Runnable) this::writeSigV4TestConfig); })); // Create an empty input object to pass var inputShape = model.expectShape(operation.getInputShape(), StructureShape.class); @@ -490,13 +495,15 @@ private void generateErrorResponseTest( headers=$J, body=b$S, ), + ${C|} ) """, CodegenUtils.getConfigSymbol(context.settings()), RESPONSE_TEST_ASYNC_HTTP_CLIENT_SYMBOL, testCase.getCode(), CodegenUtils.toTuples(testCase.getHeaders()), - testCase.getBody().orElse("")); + testCase.getBody().orElse(""), + (Runnable) this::writeSigV4TestConfig); })); // Create an empty input object to pass var inputShape = model.expectShape(operation.getInputShape(), StructureShape.class); @@ -531,7 +538,7 @@ private void assertResponseEqual(HttpMessageTestCase testCase, Shape operationOr .findAny(); if (streamBinding.isEmpty()) { - writer.write("assert actual == expected\n"); + writer.write("_assert_modeled_value(actual, expected)\n"); return; } @@ -556,7 +563,7 @@ assert isinstance(actual.$1L, AsyncByteStream) compareMediaBlob(testCase, writer); continue; } - writer.write("assert actual.$1L == expected.$1L\n", memberName); + writer.write("_assert_modeled_value(actual.$1L, expected.$1L)\n", memberName); } } @@ -607,10 +614,26 @@ private void writeClientBlock( }); } + private void writeSigV4TestConfig() { + if (!service.hasTrait(SigV4Trait.class)) { + return; + } + writer.addImport("smithy_aws_core.identity", "StaticCredentialsResolver"); + writer.write(""" + region="us-east-1", + aws_access_key_id="test-access-key-id", + aws_secret_access_key="test-secret-access-key", + aws_credentials_identity_resolver=StaticCredentialsResolver(), + """); + } + private void writeUtilStubs(Symbol serviceSymbol) { LOGGER.fine(String.format("Writing utility stubs for %s : %s", serviceSymbol.getName(), protocol.getName())); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); + writer.addStdlibImport("dataclasses", "fields"); + writer.addStdlibImport("dataclasses", "is_dataclass"); + writer.addStdlibImport("math", "isnan"); writer.addImports("smithy_http.interfaces", Set.of( "HTTPRequestConfiguration", @@ -621,6 +644,24 @@ private void writeUtilStubs(Symbol serviceSymbol) { writer.addImport("smithy_core.aio.utils", "async_list"); writer.write(""" + def _assert_modeled_value(actual: object, expected: object) -> None: + if isinstance(expected, float) and isnan(expected): + assert isinstance(actual, float) + assert isnan(actual) + return + + if is_dataclass(expected): + assert is_dataclass(actual) + for field in fields(expected): + _assert_modeled_value( + getattr(actual, field.name), + getattr(expected, field.name), + ) + return + + assert actual == expected + + class $1L($2T): ""\"A test error that subclasses the service-error for protocol tests.""\" diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java index 09f49690b..f8cf5f118 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java @@ -31,12 +31,6 @@ public class RestJsonProtocolGenerator implements ProtocolGenerator { private static final Set TESTS_TO_SKIP = Set.of( - // These two tests essentially try to assert nan == nan, - // which is never true. We should update the generator to - // make specific assertions for these. - "RestJsonSupportsNaNFloatHeaderOutputs", - "RestJsonSupportsNaNFloatInputs", - // This requires support of idempotency autofill "RestJsonQueryIdempotencyTokenAutoFill", From ad43c77858171cb261e7b35269c3a5cb7a0f7d62 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 26 Feb 2026 11:47:18 -0500 Subject: [PATCH 4/6] smithy-http: preserve host prefixes when applying endpoints --- packages/smithy-http/src/smithy_http/aio/protocols.py | 9 ++++++++- packages/smithy-http/tests/unit/aio/test_protocols.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/packages/smithy-http/src/smithy_http/aio/protocols.py b/packages/smithy-http/src/smithy_http/aio/protocols.py index af32cee16..968ffaacb 100644 --- a/packages/smithy-http/src/smithy_http/aio/protocols.py +++ b/packages/smithy-http/src/smithy_http/aio/protocols.py @@ -53,11 +53,18 @@ def set_service_endpoint( if uri.query and previous.query: query = f"{uri.query}&{previous.query}" + has_host_prefix = bool(previous.host) and previous.host != "." + host = uri.host + if has_host_prefix and uri.host and previous.host.endswith("."): + host = f"{previous.host}{uri.host}" + elif has_host_prefix: + host = previous.host + request.destination = _URI( scheme=uri.scheme, username=uri.username or previous.username, password=uri.password or previous.password, - host=uri.host, + host=host, port=uri.port or previous.port, path=path, query=query, diff --git a/packages/smithy-http/tests/unit/aio/test_protocols.py b/packages/smithy-http/tests/unit/aio/test_protocols.py index 4ae18ce67..2fc9d36ee 100644 --- a/packages/smithy-http/tests/unit/aio/test_protocols.py +++ b/packages/smithy-http/tests/unit/aio/test_protocols.py @@ -120,6 +120,16 @@ def deserialize_response( URI(host="com.example"), URI(host="com.example", fragment="header"), ), + ( + URI(host="foo."), + URI(host="com.example"), + URI(host="foo.com.example"), + ), + ( + URI(host="."), + URI(host="com.example"), + URI(host="com.example"), + ), ], ) def test_http_protocol_joins_uris( From 024382658d3eee36d23203fe9100a0ed78cd03b6 Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 26 Feb 2026 11:48:29 -0500 Subject: [PATCH 5/6] smithy-aws-core: add awsJson client protocol support --- .../src/smithy_aws_core/aio/protocols.py | 529 ++++++++++++++++-- .../src/smithy_aws_core/utils.py | 17 + .../tests/unit/aio/test_protocols.py | 328 ++++++++++- .../smithy-aws-core/tests/unit/test_utils.py | 28 +- 4 files changed, 837 insertions(+), 65 deletions(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index 709651a4a..8494c7c9e 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -1,31 +1,58 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Callable +from collections.abc import AsyncIterable, Callable from inspect import iscoroutinefunction +from string import Formatter from typing import TYPE_CHECKING, Any, Final +from urllib.parse import quote as urlquote -from smithy_core.aio.interfaces import AsyncWriter +from smithy_core import URI as _URI +from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter +from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob from smithy_core.aio.interfaces.auth import AuthScheme from smithy_core.aio.interfaces.eventstream import EventPublisher, EventReceiver -from smithy_core.aio.types import AsyncBytesReader +from smithy_core.aio.types import AsyncBytesProvider, AsyncBytesReader from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.documents import TypeRegistry from smithy_core.exceptions import ( + CallError, DiscriminatorError, + ExpectationNotMetError, MissingDependencyError, + ModeledError, UnsupportedStreamError, ) -from smithy_core.interfaces import TypedProperties +from smithy_core.interfaces import ( + BytesReader, + SeekableBytesReader, + TypedProperties, + URI, + is_streaming_blob, +) +from smithy_core.interfaces import StreamingBlob as SyncStreamingBlob +from smithy_core.prelude import DOCUMENT from smithy_core.schemas import APIOperation, Schema from smithy_core.serializers import SerializeableShape from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import EndpointTrait, HTTPTrait from smithy_core.types import TimestampFormat +from smithy_http import tuples_to_fields +from smithy_http.aio import HTTPRequest as _HTTPRequest from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse -from smithy_http.aio.protocols import HttpBindingClientProtocol +from smithy_http.aio.protocols import HttpBindingClientProtocol, HttpClientProtocol from smithy_json import JSONCodec, JSONDocument -from ..traits import RestJson1Trait -from ..utils import parse_document_discriminator, parse_error_code +from ..traits import ( + AwsJson1_0Trait, + AwsJson1_1Trait, + RestJson1Trait, +) +from ..utils import ( + parse_document_discriminator, + parse_error_code, + parse_header_error_code, +) try: from smithy_aws_event_stream.aio import ( @@ -38,6 +65,8 @@ except ImportError: _HAS_EVENT_STREAM = False # type: ignore +_AWS_ERROR_HEADER_KEY: Final = "x-amzn-errortype" + if TYPE_CHECKING: from smithy_aws_event_stream.aio import ( AWSEventPublisher, @@ -55,8 +84,17 @@ def _assert_event_stream_capable() -> None: ) +def _first_field_value(response: HTTPResponse, field_name: str) -> str | None: + if field_name not in response.fields: + return None + values = response.fields[field_name].values + return values[0] if values else None + + class AWSErrorIdentifier(HTTPErrorIdentifier): - _HEADER_KEY: Final = "x-amzn-errortype" + _error_code_parser: Callable[[str, str | None], ShapeID | None] = staticmethod( + parse_error_code + ) def identify( self, @@ -64,14 +102,14 @@ def identify( operation: APIOperation[Any, Any], response: HTTPResponse, ) -> ShapeID | None: - if self._HEADER_KEY not in response.fields: + code = _first_field_value(response=response, field_name=_AWS_ERROR_HEADER_KEY) + if code is None: return None + return self._error_code_parser(code, operation.schema.id.namespace) - error_field = response.fields[self._HEADER_KEY] - code = error_field.values[0] if len(error_field.values) > 0 else None - if code is not None: - return parse_error_code(code, operation.schema.id.namespace) - return None + +class AWSJSONErrorIdentifier(AWSErrorIdentifier): + _error_code_parser = staticmethod(parse_header_error_code) class AWSJSONDocument(JSONDocument): @@ -87,39 +125,61 @@ def discriminator(self) -> ShapeID: return parsed -class RestJsonClientProtocol(HttpBindingClientProtocol): - """An implementation of the aws.protocols#restJson1 protocol.""" - - _id: Final = RestJson1Trait.id - _contentType: Final = "application/json" - _error_identifier: Final = AWSErrorIdentifier() +class AWSJSON11Document(JSONDocument): + @property + def discriminator(self) -> ShapeID: + if self.shape_type is ShapeType.STRUCTURE: + return self._schema.id - def __init__(self, service_schema: Schema) -> None: - """Initialize a RestJsonClientProtocol. + if self.shape_type is ShapeType.MAP: + map_document = self.as_map() + code = map_document.get("__type") + if code is None: + code = map_document.get("code") + if code is not None and code.shape_type is ShapeType.STRING: + parsed = parse_header_error_code( + code.as_string(), self._settings.default_namespace + ) + if parsed is not None: + return parsed - :param service: The schema for the service to interact with. - """ - self._codec: Final = JSONCodec( - document_class=AWSJSONDocument, - default_namespace=service_schema.id.namespace, - default_timestamp_format=TimestampFormat.EPOCH_SECONDS, + raise DiscriminatorError( + f"Unable to parse discriminator for {self.shape_type} document." ) - @property - def id(self) -> ShapeID: - return self._id +class _EventStreamClientProtocolMixin: @property def payload_codec(self) -> Codec: - return self._codec + raise NotImplementedError - @property - def content_type(self) -> str: - return self._contentType + def _resolve_event_signing_config( + self, + *, + auth_scheme: AuthScheme[Any, Any, Any, Any] | None, + request: HTTPRequest, + context: TypedProperties, + ) -> "SigningConfig | None": + if auth_scheme is None: + return None + event_signer = auth_scheme.event_signer(request=request) + if event_signer is None: + return None + return SigningConfig( + signer=event_signer, + signing_properties=auth_scheme.signer_properties(context=context), + identity_resolver=auth_scheme.identity_resolver(context=context), + identity_properties=auth_scheme.identity_properties(context=context), + ) - @property - def error_identifier(self) -> HTTPErrorIdentifier: - return self._error_identifier + def _request_async_writer(self, request: HTTPRequest) -> AsyncWriter: + body = request.body + if not isinstance(body, AsyncWriter) or not iscoroutinefunction(body.write): + raise UnsupportedStreamError( + "Input streams require an async write function, but none was present " + "on the serialized HTTP request." + ) + return body def create_event_publisher[ OperationInput: SerializeableShape, @@ -135,32 +195,14 @@ def create_event_publisher[ auth_scheme: AuthScheme[Any, Any, Any, Any] | None = None, ) -> EventPublisher[Event]: _assert_event_stream_capable() - signing_config: SigningConfig | None = None - if auth_scheme is not None: - event_signer = auth_scheme.event_signer(request=request) - if event_signer is not None: - signing_config = SigningConfig( - signer=event_signer, - signing_properties=auth_scheme.signer_properties(context=context), - identity_resolver=auth_scheme.identity_resolver(context=context), - identity_properties=auth_scheme.identity_properties( - context=context - ), - ) - - # The HTTP body must be an async writeable. The HTTP serializers are responsible - # for ensuring this. - body = request.body - if not isinstance(body, AsyncWriter) or not iscoroutinefunction(body.write): - raise UnsupportedStreamError( - "Input streams require an async write function, but none was present " - "on the serialized HTTP request." - ) - return AWSEventPublisher[Event]( payload_codec=self.payload_codec, - async_writer=body, - signing_config=signing_config, + async_writer=self._request_async_writer(request), + signing_config=self._resolve_event_signing_config( + auth_scheme=auth_scheme, + request=request, + context=context, + ), ) def create_event_receiver[ @@ -183,3 +225,366 @@ def create_event_receiver[ source=AsyncBytesReader(response.body), deserializer=event_deserializer, ) + + +class _AWSJSONClientProtocol(_EventStreamClientProtocolMixin, HttpClientProtocol): + _error_identifier: Final = AWSJSONErrorIdentifier() + _http_trait: Final = HTTPTrait({"method": "POST", "uri": "/"}) + + _id: ShapeID + _content_type: str + _document_class: type[JSONDocument] = AWSJSONDocument + + def __init__(self, service_schema: Schema) -> None: + self._service_name = service_schema.id.name + self._codec: Final = JSONCodec( + document_class=self._document_class, + default_namespace=service_schema.id.namespace, + default_timestamp_format=TimestampFormat.EPOCH_SECONDS, + use_json_name=False, + ) + + @property + def id(self) -> ShapeID: + return self._id + + @property + def payload_codec(self) -> Codec: + return self._codec + + @property + def content_type(self) -> str: + return self._content_type + + @property + def error_identifier(self) -> HTTPErrorIdentifier: + return self._error_identifier + + def serialize_request[ + OperationInput: SerializeableShape, + OperationOutput: DeserializeableShape, + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + input: OperationInput, + endpoint: URI, + context: TypedProperties, + ) -> HTTPRequest: + payload = self.payload_codec.serialize(shape=input) + input_stream_member = operation.input_stream_member + has_input_event_stream = ( + isinstance(input_stream_member, Schema) + and input_stream_member.shape_type is ShapeType.UNION + ) + + field_tuples: list[tuple[str, str]] = [ + ("x-amz-target", f"{self._service_name}.{operation.schema.id.name}"), + ] + if has_input_event_stream: + field_tuples.append(("content-type", "application/vnd.amazon.eventstream")) + body: AsyncBytesReader | AsyncBytesProvider = AsyncBytesProvider() + else: + field_tuples.extend( + [ + ("content-type", self.content_type), + ("content-length", str(len(payload))), + ] + ) + body = AsyncBytesReader(payload) + + fields = tuples_to_fields(field_tuples) + host = self._resolve_host_prefix(operation=operation, payload=payload) + return _HTTPRequest( + destination=_URI( + host=host, + path=self._http_trait.path.pattern, + query=self._http_trait.query, + ), + body=body, + method=self._http_trait.method, + fields=fields, + ) + + async def deserialize_response[ + OperationInput: SerializeableShape, + OperationOutput: DeserializeableShape, + ]( + self, + *, + operation: APIOperation[OperationInput, OperationOutput], + request: HTTPRequest, + response: HTTPResponse, + error_registry: TypeRegistry, + context: TypedProperties, + ) -> OperationOutput: + if not self._is_success(operation, context, response): + raise await self._create_error( + operation=operation, + request=request, + response=response, + response_body=await self._buffer_async_body(response.body), + error_registry=error_registry, + context=context, + ) + + if operation.output_stream_member is not None: + # Stream members are consumed via create_event_receiver(). + return self.payload_codec.deserialize(source=b"{}", shape=operation.output) + + body = response.body + if not is_streaming_blob(body): + body = await self._buffer_async_body(body) + if not is_streaming_blob(body): + raise UnsupportedStreamError( + "Unable to read async stream. This stream must be buffered prior " + "to deserializing." + ) + + source = self._coerce_json_source(response=response, body=body) + return self.payload_codec.deserialize(source=source, shape=operation.output) + + async def _buffer_async_body(self, stream: AsyncStreamingBlob) -> SyncStreamingBlob: + match stream: + case AsyncByteStream(): + if not iscoroutinefunction(stream.read): + return stream # type: ignore + return await stream.read() + case AsyncIterable(): + chunks: list[bytes] = [] + async for chunk in stream: + chunks.append(chunk) + return b"".join(chunks) + case _: + return stream + + def _is_success( + self, + operation: APIOperation[Any, Any], + context: TypedProperties, + response: HTTPResponse, + ) -> bool: + return 200 <= response.status < 300 + + def _resolve_host_prefix( + self, + *, + operation: APIOperation[Any, Any], + payload: bytes, + ) -> str: + endpoint_trait = operation.schema.get_trait(EndpointTrait) + if endpoint_trait is None: + return "" + + host_prefix = endpoint_trait.host_prefix + labels = self._host_prefix_labels(host_prefix) + if not labels: + return host_prefix + + deserializer = self.payload_codec.create_deserializer(source=payload) + document = deserializer.read_document(schema=DOCUMENT) + if document.shape_type is not ShapeType.MAP: + raise ExpectationNotMetError( + f"Expected input document to be a map for host labels, got {document.shape_type}" + ) + + values: dict[str, str] = {} + map_document = document.as_map() + for label in labels: + value = map_document.get(label) + if value is None or value.shape_type is not ShapeType.STRING: + raise ExpectationNotMetError( + f"Expected host label member '{label}' to be a string in input payload" + ) + values[label] = urlquote(value.as_string(), safe=".") + + return host_prefix.format(**values) + + def _host_prefix_labels(self, host_prefix: str) -> set[str]: + labels: set[str] = set() + for _, field_name, _, _ in Formatter().parse(host_prefix): + if field_name: + labels.add(field_name) + return labels + + def _coerce_json_source( + self, + *, + response: HTTPResponse, + body: SyncStreamingBlob, + ) -> bytes | BytesReader: + if self._is_empty_body(response=response, body=body): + return b"{}" + if isinstance(body, bytearray): + return bytes(body) + return body + + def _is_empty_body( + self, *, response: HTTPResponse, body: SyncStreamingBlob + ) -> bool: + if "content-length" in response.fields: + return int(response.fields["content-length"].as_string()) == 0 + if isinstance(body, bytes | bytearray): + return len(body) == 0 + if ( + seek := getattr(body, "seek", None) + ) is not None and not iscoroutinefunction(seek): + position = None + if ( + tell := getattr(body, "tell", None) + ) is not None and not iscoroutinefunction(tell): + position = tell() + content_length = seek(0, 2) + if position is not None: + seek(position, 0) + else: + seek(0, 0) + return content_length == 0 + return False + + async def _create_error( + self, + operation: APIOperation[Any, Any], + request: HTTPRequest, + response: HTTPResponse, + response_body: SyncStreamingBlob, + error_registry: TypeRegistry, + context: TypedProperties, + ) -> CallError: + error_id = self.error_identifier.identify( + operation=operation, response=response + ) + + if error_id is not None and error_id not in error_registry: + raw_code = self._raw_header_error_code(response) + if raw_code is not None: + legacy_error_id = parse_error_code( + raw_code, operation.schema.id.namespace + ) + if legacy_error_id in error_registry: + error_id = legacy_error_id + + if ( + (error_id is None or error_id not in error_registry) + and self._matches_content_type(response) + and not self._is_empty_body(response=response, body=response_body) + ): + if isinstance(response_body, bytearray): + response_body = bytes(response_body) + deserializer = self.payload_codec.create_deserializer(source=response_body) + document = deserializer.read_document(schema=DOCUMENT) + + body_error_id: ShapeID | None = None + try: + body_error_id = document.discriminator + except DiscriminatorError: + body_error_id = None + + if body_error_id in error_registry: + error_id = body_error_id + if isinstance(response_body, SeekableBytesReader): + response_body.seek(0) + else: + legacy_error_id = parse_document_discriminator( + document, operation.schema.id.namespace + ) + if legacy_error_id in error_registry: + error_id = legacy_error_id + if isinstance(response_body, SeekableBytesReader): + response_body.seek(0) + + if error_id is not None and error_id in error_registry: + error_shape = error_registry.get(error_id) + + # make sure the error shape is derived from modeled exception + if not issubclass(error_shape, ModeledError): + raise ExpectationNotMetError( + f"Modeled errors must be derived from 'ModeledError', " + f"but got {error_shape}" + ) + + source = self._coerce_json_source(response=response, body=response_body) + deserializer = self.payload_codec.create_deserializer(source=source) + return error_shape.deserialize(deserializer) + + message = ( + f"Unknown error for operation {operation.schema.id} " + f"- status: {response.status}" + ) + if error_id is not None: + message += f" - id: {error_id}" + if response.reason is not None: + message += f" - reason: {response.reason}" + + is_timeout = response.status == 408 + is_throttle = response.status == 429 + fault = "client" if response.status < 500 else "server" + + return CallError( + message=message, + fault=fault, + is_throttling_error=is_throttle, + is_timeout_error=is_timeout, + is_retry_safe=is_throttle or is_timeout or None, + ) + + def _raw_header_error_code(self, response: HTTPResponse) -> str | None: + return _first_field_value(response=response, field_name=_AWS_ERROR_HEADER_KEY) + + def _matches_content_type(self, response: HTTPResponse) -> bool: + if "content-type" not in response.fields: + return False + actual = response.fields["content-type"].as_string() + return actual.split(";", 1)[0].strip().lower() == self.content_type.lower() + + +class AwsJson10ClientProtocol(_AWSJSONClientProtocol): + """An implementation of the aws.protocols#awsJson1_0 protocol.""" + + _id: ShapeID = AwsJson1_0Trait.id + _content_type: str = "application/x-amz-json-1.0" + + +class AwsJson11ClientProtocol(_AWSJSONClientProtocol): + """An implementation of the aws.protocols#awsJson1_1 protocol.""" + + _id: ShapeID = AwsJson1_1Trait.id + _content_type: str = "application/x-amz-json-1.1" + _document_class: type[JSONDocument] = AWSJSON11Document + + +class RestJsonClientProtocol( + _EventStreamClientProtocolMixin, HttpBindingClientProtocol +): + """An implementation of the aws.protocols#restJson1 protocol.""" + + _id: Final = RestJson1Trait.id + _content_type: Final = "application/json" + _error_identifier: Final = AWSErrorIdentifier() + + def __init__(self, service_schema: Schema) -> None: + """Initialize a RestJsonClientProtocol. + + :param service: The schema for the service to interact with. + """ + self._codec: Final = JSONCodec( + document_class=AWSJSONDocument, + default_namespace=service_schema.id.namespace, + default_timestamp_format=TimestampFormat.EPOCH_SECONDS, + ) + + @property + def id(self) -> ShapeID: + return self._id + + @property + def payload_codec(self) -> Codec: + return self._codec + + @property + def content_type(self) -> str: + return self._content_type + + @property + def error_identifier(self) -> HTTPErrorIdentifier: + return self._error_identifier diff --git a/packages/smithy-aws-core/src/smithy_aws_core/utils.py b/packages/smithy-aws-core/src/smithy_aws_core/utils.py index 940160e05..bbaa89049 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/utils.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/utils.py @@ -30,3 +30,20 @@ def parse_error_code(code: str, default_namespace: str | None) -> ShapeID | None return None return ShapeID.from_parts(name=code, namespace=default_namespace) + + +def parse_header_error_code(code: str, default_namespace: str | None) -> ShapeID | None: + if not code: + return None + + code = code.split(":")[0] + if "#" in code: + _, _, name = code.partition("#") + if name and default_namespace: + return ShapeID.from_parts(name=name, namespace=default_namespace) + return ShapeID(code) + + if not code or not default_namespace: + return None + + return ShapeID.from_parts(name=code, namespace=default_namespace) diff --git a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py index 7b767a080..bcf066743 100644 --- a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py +++ b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py @@ -1,13 +1,35 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Any, cast from unittest.mock import Mock import pytest -from smithy_aws_core.aio.protocols import AWSErrorIdentifier, AWSJSONDocument -from smithy_core.exceptions import DiscriminatorError +from smithy_aws_core.aio.protocols import ( + AWSErrorIdentifier, + AwsJson11ClientProtocol, + AWSJSONDocument, + RestJsonClientProtocol, +) +from smithy_core import URI +from smithy_core.aio.interfaces import AsyncWriter +from smithy_core.documents import TypeRegistry +from smithy_core.exceptions import CallError, DiscriminatorError, ModeledError +from smithy_core.prelude import STRING from smithy_core.schemas import APIOperation, Schema +from smithy_core.serializers import ShapeSerializer from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import ( + DynamicTrait, + EndpointTrait, + HostLabelTrait, + HTTPHeaderTrait, + HTTPTrait, + StreamingTrait, + Trait, +) +from smithy_core.types import TypedProperties from smithy_http import Fields, tuples_to_fields from smithy_http.aio import HTTPResponse from smithy_json import JSONSettings @@ -25,6 +47,7 @@ "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate", "com.test#FooError", ), + ("com.other#FooError", "com.other#FooError"), ("", None), (":", None), (None, None), @@ -97,3 +120,304 @@ def test_aws_json_document_discriminator( else: discriminator = AWSJSONDocument(document, settings=settings).discriminator assert discriminator == expected + + +_EMPTY_INPUT_SCHEMA = Schema.collection( + id=ShapeID("com.test#EmptyInput"), +) +_HEADER_AND_LABEL_INPUT_SCHEMA = Schema.collection( + id=ShapeID("com.test#HeaderAndLabelInput"), + members={ + "headerMember": {"target": STRING, "traits": [HTTPHeaderTrait("x-test")]}, + "label": {"target": STRING, "traits": [HostLabelTrait()]}, + }, +) +_EVENT_STREAM_MEMBER_SCHEMA = Schema( + id=ShapeID("com.test#InputEvents"), + shape_type=ShapeType.UNION, + traits=[StreamingTrait()], +) + + +@dataclass +class _EmptyInput: + def serialize(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(_EMPTY_INPUT_SCHEMA, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + pass + + +@dataclass +class _HeaderAndLabelInput: + header_member: str | None = None + label: str | None = None + + def serialize(self, serializer: ShapeSerializer) -> None: + serializer.write_struct(_HEADER_AND_LABEL_INPUT_SCHEMA, self) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + if self.header_member is not None: + serializer.write_string( + _HEADER_AND_LABEL_INPUT_SCHEMA.members["headerMember"], + self.header_member, + ) + if self.label is not None: + serializer.write_string( + _HEADER_AND_LABEL_INPUT_SCHEMA.members["label"], + self.label, + ) + + +def _operation_schema(name: str, *, endpoint: str | None = None) -> Schema: + traits: list[Trait | DynamicTrait] = [] + if endpoint is not None: + traits.append(EndpointTrait({"hostPrefix": endpoint})) + return Schema( + id=ShapeID(f"com.test#{name}"), + shape_type=ShapeType.OPERATION, + traits=traits, + ) + + +def _http_operation_schema(name: str) -> Schema: + return Schema( + id=ShapeID(f"com.test#{name}"), + shape_type=ShapeType.OPERATION, + traits=[HTTPTrait({"method": "POST", "uri": "/"})], + ) + + +def _mock_operation(schema: Schema) -> APIOperation[Any, Any]: + operation = Mock(spec=APIOperation) + operation.schema = schema + return cast("APIOperation[Any, Any]", operation) + + +@pytest.mark.asyncio +async def test_aws_json11_serializes_base_request_shape() -> None: + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + request = protocol.serialize_request( + operation=_mock_operation(_operation_schema("EmptyOperation")), + input=_EmptyInput(), + endpoint=URI(host="example.com"), + context=TypedProperties(), + ) + + assert request.method == "POST" + assert request.destination.path == "/" + assert request.fields["content-type"].as_string() == "application/x-amz-json-1.1" + assert request.fields["x-amz-target"].as_string() == "JsonService.EmptyOperation" + assert request.fields["content-length"].as_string() == "2" + assert await request.consume_body_async() == b"{}" + + +@pytest.mark.asyncio +async def test_aws_json11_serializes_input_event_stream_request_with_writable_body() -> ( + None +): + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + operation = _mock_operation(_operation_schema("StreamingOperation")) + cast(Any, operation).input_stream_member = _EVENT_STREAM_MEMBER_SCHEMA + + request = protocol.serialize_request( + operation=operation, + input=_EmptyInput(), + endpoint=URI(host="example.com"), + context=TypedProperties(), + ) + + assert ( + request.fields["content-type"].as_string() + == "application/vnd.amazon.eventstream" + ) + assert "content-length" not in request.fields + assert isinstance(request.body, AsyncWriter) + + +@pytest.mark.asyncio +async def test_aws_json_ignores_http_bindings_but_applies_host_labels() -> None: + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + request = protocol.serialize_request( + operation=_mock_operation( + _operation_schema("EndpointOperation", endpoint="foo.{label}.") + ), + input=_HeaderAndLabelInput(header_member="payload", label="bar"), + endpoint=URI(host="example.com"), + context=TypedProperties(), + ) + + assert request.destination.host == "foo.bar." + assert "x-test" not in request.fields + assert ( + await request.consume_body_async() + == b'{"headerMember":"payload","label":"bar"}' + ) + + +def test_aws_json_matches_content_type_with_parameters() -> None: + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + response = HTTPResponse( + status=500, + fields=tuples_to_fields( + [("content-type", "application/x-amz-json-1.1; charset=utf-8")] + ), + ) + assert getattr(protocol, "_matches_content_type")(response) + + +@pytest.mark.asyncio +async def test_aws_json11_unknown_json_error_returns_call_error() -> None: + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + operation = _mock_operation(_operation_schema("FailingOperation")) + response = HTTPResponse( + status=400, + reason="Bad Request", + fields=tuples_to_fields([("content-type", "application/x-amz-json-1.1")]), + body=b'{"message":"no discriminator"}', + ) + + error = await getattr(protocol, "_create_error")( + operation=operation, + request=Mock(), + response=response, + response_body=response.body, + error_registry=TypeRegistry({}), + context=TypedProperties(), + ) + + assert isinstance(error, CallError) + assert "reason: Bad Request" in error.message + + +@pytest.mark.asyncio +async def test_aws_json11_empty_json_error_body_returns_call_error() -> None: + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + operation = _mock_operation(_operation_schema("FailingOperation")) + response = HTTPResponse( + status=400, + reason="Bad Request", + fields=tuples_to_fields( + [ + ("content-type", "application/x-amz-json-1.1"), + ("content-length", "0"), + ] + ), + body=b"", + ) + + error = await getattr(protocol, "_create_error")( + operation=operation, + request=Mock(), + response=response, + response_body=response.body, + error_registry=TypeRegistry({}), + context=TypedProperties(), + ) + + assert isinstance(error, CallError) + assert "reason: Bad Request" in error.message + + +class _OtherNamespaceModeledError(ModeledError): + @classmethod + def deserialize(cls, deserializer: Any) -> "_OtherNamespaceModeledError": + return cls("other namespace") + + +@pytest.mark.asyncio +async def test_aws_json11_resolves_modeled_error_when_header_is_sanitized() -> None: + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + operation = _mock_operation(_operation_schema("FailingOperation")) + response = HTTPResponse( + status=400, + reason="Bad Request", + fields=tuples_to_fields( + [ + ("x-amzn-errortype", "com.other#OtherNsError"), + ("content-type", "application/x-amz-json-1.1"), + ] + ), + body=b'{"__type":"com.other#OtherNsError"}', + ) + + error = await getattr(protocol, "_create_error")( + operation=operation, + request=Mock(), + response=response, + response_body=response.body, + error_registry=TypeRegistry( + {ShapeID("com.other#OtherNsError"): _OtherNamespaceModeledError} + ), + context=TypedProperties(), + ) + + assert isinstance(error, _OtherNamespaceModeledError) + + +@pytest.mark.asyncio +async def test_aws_json11_resolves_modeled_error_from_header_only_shapeid() -> None: + protocol = AwsJson11ClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + operation = _mock_operation(_operation_schema("FailingOperation")) + response = HTTPResponse( + status=400, + reason="Bad Request", + fields=tuples_to_fields([("x-amzn-errortype", "com.other#OtherNsError")]), + body=b"", + ) + + error = await getattr(protocol, "_create_error")( + operation=operation, + request=Mock(), + response=response, + response_body=response.body, + error_registry=TypeRegistry( + {ShapeID("com.other#OtherNsError"): _OtherNamespaceModeledError} + ), + context=TypedProperties(), + ) + + assert isinstance(error, _OtherNamespaceModeledError) + + +@pytest.mark.asyncio +async def test_rest_json_resolves_modeled_error_from_header_only_shapeid() -> None: + protocol = RestJsonClientProtocol( + Schema(id=ShapeID("com.test#JsonService"), shape_type=ShapeType.SERVICE) + ) + operation = _mock_operation(_http_operation_schema("FailingOperation")) + response = HTTPResponse( + status=400, + reason="Bad Request", + fields=tuples_to_fields([("x-amzn-errortype", "com.other#OtherNsError")]), + body=b"", + ) + + error = await getattr(protocol, "_create_error")( + operation=operation, + request=Mock(), + response=response, + response_body=response.body, + error_registry=TypeRegistry( + {ShapeID("com.other#OtherNsError"): _OtherNamespaceModeledError} + ), + context=TypedProperties(), + ) + + assert isinstance(error, _OtherNamespaceModeledError) diff --git a/packages/smithy-aws-core/tests/unit/test_utils.py b/packages/smithy-aws-core/tests/unit/test_utils.py index 6927a2fce..952e378b1 100644 --- a/packages/smithy-aws-core/tests/unit/test_utils.py +++ b/packages/smithy-aws-core/tests/unit/test_utils.py @@ -2,7 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from smithy_aws_core.utils import parse_document_discriminator, parse_error_code +from smithy_aws_core.utils import ( + parse_document_discriminator, + parse_error_code, + parse_header_error_code, +) from smithy_core.documents import Document from smithy_core.shapes import ShapeID @@ -64,6 +68,11 @@ def test_aws_json_document_discriminator( "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate", "com.test#FooError", ), + ("com.other#FooError", "com.other#FooError"), + ( + "com.other#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate", + "com.other#FooError", + ), ("", None), (":", None), ], @@ -76,3 +85,20 @@ def test_parse_error_code(code: str, expected: ShapeID | None) -> None: def test_parse_error_code_without_default_namespace() -> None: actual = parse_error_code("FooError", None) assert actual is None + + +@pytest.mark.parametrize( + "code, expected", + [ + ("FooError", "com.test#FooError"), + ( + "com.other#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate", + "com.test#FooError", + ), + ("", None), + (":", None), + ], +) +def test_parse_header_error_code(code: str, expected: ShapeID | None) -> None: + actual = parse_header_error_code(code, "com.test") + assert actual == expected From 5a4659c6fd15cd2dba0260962e3f844c6ca18c7f Mon Sep 17 00:00:00 2001 From: jonathan343 Date: Thu, 26 Feb 2026 11:48:39 -0500 Subject: [PATCH 6/6] codegen: add awsJson protocol generators and tests --- Makefile | 10 ++- codegen/aws/core/build.gradle.kts | 1 + .../codegen/AwsJson10ProtocolGenerator.java | 76 +++++++++++++++++++ .../codegen/AwsJson11ProtocolGenerator.java | 63 +++++++++++++++ .../aws/codegen/AwsProtocolsIntegration.java | 2 +- codegen/protocol-test/build.gradle.kts | 1 + codegen/protocol-test/smithy-build.json | 44 +++++++++++ 7 files changed, 192 insertions(+), 5 deletions(-) create mode 100644 codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson10ProtocolGenerator.java create mode 100644 codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson11ProtocolGenerator.java diff --git a/Makefile b/Makefile index 5e931e1c5..ec2452598 100644 --- a/Makefile +++ b/Makefile @@ -14,10 +14,12 @@ build-java: ## Builds the Java code generation packages. cd codegen && ./gradlew clean build -test-protocols: ## Generates and runs the restJson1 protocol tests. - cd codegen && ./gradlew :protocol-test:build - uv pip install codegen/protocol-test/build/smithyprojections/protocol-test/rest-json-1/python-client-codegen - uv run pytest codegen/protocol-test/build/smithyprojections/protocol-test/rest-json-1/python-client-codegen +test-protocols: ## Generates and runs protocol tests for all supported protocols. + cd codegen && ./gradlew :protocol-test:clean :protocol-test:build + @set -e; for projection_dir in codegen/protocol-test/build/smithyprojections/protocol-test/*/python-client-codegen; do \ + uv pip install "$$projection_dir"; \ + uv run pytest "$$projection_dir"; \ + done lint-py: ## Runs linters and formatters on the python packages. diff --git a/codegen/aws/core/build.gradle.kts b/codegen/aws/core/build.gradle.kts index 3a81c5190..49d506d8d 100644 --- a/codegen/aws/core/build.gradle.kts +++ b/codegen/aws/core/build.gradle.kts @@ -12,4 +12,5 @@ extra["moduleName"] = "software.amazon.smithy.python.aws.codegen" dependencies { implementation(project(":core")) implementation(libs.smithy.aws.traits) + implementation(libs.smithy.protocol.test.traits) } diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson10ProtocolGenerator.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson10ProtocolGenerator.java new file mode 100644 index 000000000..dd366a193 --- /dev/null +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson10ProtocolGenerator.java @@ -0,0 +1,76 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.aws.codegen; + +import java.util.Set; +import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait; +import software.amazon.smithy.model.node.ArrayNode; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.python.codegen.ApplicationProtocol; +import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.HttpProtocolTestGenerator; +import software.amazon.smithy.python.codegen.SymbolProperties; +import software.amazon.smithy.python.codegen.generators.ProtocolGenerator; +import software.amazon.smithy.python.codegen.writer.PythonWriter; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class AwsJson10ProtocolGenerator implements ProtocolGenerator { + private static final Set TESTS_TO_SKIP = Set.of( + // TODO: support the request compression trait + // https://smithy.io/2.0/spec/behavior-traits.html#smithy-api-requestcompression-trait + "SDKAppliedContentEncoding_awsJson1_0", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_0", + + // TODO: Fix for both REST-JSON and JSON-RPC + "AwsJson10ClientPopulatesDefaultValuesInInput", + "AwsJson10ClientSkipsTopLevelDefaultValuesInInput", + "AwsJson10ClientUsesExplicitlyProvidedMemberValuesOverDefaults", + "AwsJson10ClientPopulatesDefaultsValuesWhenMissingInResponse", + "AwsJson10ClientIgnoresNonTopLevelDefaultsOnMembersWithClientOptional", + "AwsJson10ClientIgnoresDefaultValuesIfMemberValuesArePresentInResponse", + + // TODO: support client error-correction behavior when the server + // omits required values in modeled error responses. + "AwsJson10ClientErrorCorrectsWhenServerFailsToSerializeRequiredValues", + "AwsJson10ClientErrorCorrectsWithDefaultValuesWhenServerFailsToSerializeRequiredValues"); + + @Override + public ShapeId getProtocol() { + return AwsJson1_0Trait.ID; + } + + @Override + public ApplicationProtocol getApplicationProtocol(GenerationContext context) { + var service = context.settings().service(context.model()); + var trait = service.expectTrait(AwsJson1_0Trait.class); + var config = ObjectNode.builder() + .withMember("http", ArrayNode.fromStrings(trait.getHttp())) + .withMember("eventStreamHttp", ArrayNode.fromStrings(trait.getEventStreamHttp())) + .build(); + return ApplicationProtocol.createDefaultHttpApplicationProtocol(config); + } + + @Override + public void initializeProtocol(GenerationContext context, PythonWriter writer) { + writer.addDependency(AwsPythonDependency.SMITHY_AWS_CORE.withOptionalDependencies("json")); + writer.addImport("smithy_aws_core.aio.protocols", "AwsJson10ClientProtocol"); + var serviceSymbol = context.symbolProvider().toSymbol(context.settings().service(context.model())); + var serviceSchema = serviceSymbol.expectProperty(SymbolProperties.SCHEMA); + writer.write("AwsJson10ClientProtocol($T)", serviceSchema); + } + + @Override + public void generateProtocolTests(GenerationContext context) { + context.writerDelegator().useFileWriter("./tests/test_protocol.py", "tests.test_protocol", writer -> { + new HttpProtocolTestGenerator( + context, + getProtocol(), + writer, + (shape, testCase) -> TESTS_TO_SKIP.contains(testCase.getId())).run(); + }); + } +} diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson11ProtocolGenerator.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson11ProtocolGenerator.java new file mode 100644 index 000000000..704e22eb6 --- /dev/null +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsJson11ProtocolGenerator.java @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.python.aws.codegen; + +import java.util.Set; +import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait; +import software.amazon.smithy.model.node.ArrayNode; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.python.codegen.ApplicationProtocol; +import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.HttpProtocolTestGenerator; +import software.amazon.smithy.python.codegen.SymbolProperties; +import software.amazon.smithy.python.codegen.generators.ProtocolGenerator; +import software.amazon.smithy.python.codegen.writer.PythonWriter; +import software.amazon.smithy.utils.SmithyInternalApi; + +@SmithyInternalApi +public final class AwsJson11ProtocolGenerator implements ProtocolGenerator { + private static final Set TESTS_TO_SKIP = Set.of( + // TODO: support the request compression trait + // https://smithy.io/2.0/spec/behavior-traits.html#smithy-api-requestcompression-trait + "SDKAppliedContentEncoding_awsJson1_1", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsJson1_1"); + + @Override + public ShapeId getProtocol() { + return AwsJson1_1Trait.ID; + } + + @Override + public ApplicationProtocol getApplicationProtocol(GenerationContext context) { + var service = context.settings().service(context.model()); + var trait = service.expectTrait(AwsJson1_1Trait.class); + var config = ObjectNode.builder() + .withMember("http", ArrayNode.fromStrings(trait.getHttp())) + .withMember("eventStreamHttp", ArrayNode.fromStrings(trait.getEventStreamHttp())) + .build(); + return ApplicationProtocol.createDefaultHttpApplicationProtocol(config); + } + + @Override + public void initializeProtocol(GenerationContext context, PythonWriter writer) { + writer.addDependency(AwsPythonDependency.SMITHY_AWS_CORE.withOptionalDependencies("json")); + writer.addImport("smithy_aws_core.aio.protocols", "AwsJson11ClientProtocol"); + var serviceSymbol = context.symbolProvider().toSymbol(context.settings().service(context.model())); + var serviceSchema = serviceSymbol.expectProperty(SymbolProperties.SCHEMA); + writer.write("AwsJson11ClientProtocol($T)", serviceSchema); + } + + @Override + public void generateProtocolTests(GenerationContext context) { + context.writerDelegator().useFileWriter("./tests/test_protocol.py", "tests.test_protocol", writer -> { + new HttpProtocolTestGenerator( + context, + getProtocol(), + writer, + (shape, testCase) -> TESTS_TO_SKIP.contains(testCase.getId())).run(); + }); + } +} diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java index 63601dd5d..57f050a69 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsProtocolsIntegration.java @@ -16,6 +16,6 @@ public class AwsProtocolsIntegration implements PythonIntegration { @Override public List getProtocolGenerators() { - return List.of(); + return List.of(new AwsJson10ProtocolGenerator(), new AwsJson11ProtocolGenerator()); } } diff --git a/codegen/protocol-test/build.gradle.kts b/codegen/protocol-test/build.gradle.kts index 5c470b9c4..cddc35e75 100644 --- a/codegen/protocol-test/build.gradle.kts +++ b/codegen/protocol-test/build.gradle.kts @@ -30,5 +30,6 @@ repositories { dependencies { implementation(project(":core")) + implementation(project(":aws:core")) implementation(libs.smithy.aws.protocol.tests) } diff --git a/codegen/protocol-test/smithy-build.json b/codegen/protocol-test/smithy-build.json index cbaccad98..69460fc80 100644 --- a/codegen/protocol-test/smithy-build.json +++ b/codegen/protocol-test/smithy-build.json @@ -22,6 +22,50 @@ "moduleVersion": "0.0.1" } } + }, + "aws-json-1-0": { + "transforms": [ + { + "name": "includeServices", + "args": { + "services": [ + "aws.protocoltests.json10#JsonRpc10" + ] + } + }, + { + "name": "removeUnusedShapes" + } + ], + "plugins": { + "python-client-codegen": { + "service": "aws.protocoltests.json10#JsonRpc10", + "module": "awsjson10", + "moduleVersion": "0.0.1" + } + } + }, + "aws-json-1-1": { + "transforms": [ + { + "name": "includeServices", + "args": { + "services": [ + "aws.protocoltests.json#JsonProtocol" + ] + } + }, + { + "name": "removeUnusedShapes" + } + ], + "plugins": { + "python-client-codegen": { + "service": "aws.protocoltests.json#JsonProtocol", + "module": "awsjson11", + "moduleVersion": "0.0.1" + } + } } } }