diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt index 853ed38a55..935eba31ab 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt @@ -7,6 +7,7 @@ package software.amazon.smithy.kotlin.codegen.aws.protocols import software.amazon.smithy.aws.traits.protocols.AwsQueryErrorTrait import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait +import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AbstractQueryFormUrlSerializerGenerator import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AwsHttpBindingProtocolGenerator import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator @@ -86,6 +87,14 @@ private class AwsQuerySerializerGenerator( members: List, writer: KotlinWriter, ): FormUrlSerdeDescriptorGenerator = AwsQuerySerdeFormUrlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members) + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + TODO("Used for service-codegen. Not yet implemented") + } } private class AwsQueryXmlParserGenerator( diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt index b84f22f442..38e74d010d 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt @@ -98,6 +98,14 @@ private class Ec2QuerySerializerGenerator( members: List, writer: KotlinWriter, ): FormUrlSerdeDescriptorGenerator = Ec2QuerySerdeFormUrlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members) + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + TODO("Used for service-codegen. Not yet implemented") + } } private class Ec2QueryParserGenerator( diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt index dcf2f0a814..db6591738f 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RestJson1.kt @@ -40,9 +40,11 @@ class RestJson1 : JsonHttpBindingProtocolGenerator() { writer: KotlinWriter, ) { super.renderSerializeHttpBody(ctx, op, writer) + if (ctx.settings.build.generateServiceProject) return val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) - if (!resolver.hasHttpBody(op)) return + + if (!resolver.hasHttpRequestBody(op)) return // restjson1 has some different semantics and expectations around empty structures bound via @httpPayload trait // * empty structures get serialized to `{}` diff --git a/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt index 19a1963844..3fbf2cfadc 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/core/AwsHttpBindingProtocolGeneratorTest.kt @@ -123,6 +123,14 @@ class AwsHttpBindingProtocolGeneratorTest { ): Symbol { error("Unneeded for test") } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + error("Unneeded for test") + } } override val protocol: ShapeId diff --git a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt index 164e7cb2e6..c8fae63ca8 100644 --- a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt +++ b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt @@ -214,6 +214,14 @@ class MockHttpProtocolGenerator(model: Model) : HttpBindingProtocolGenerator() { val symbol = ctx.symbolProvider.toSymbol(shape) name = "serialize" + StringUtils.capitalize(symbol.name) + "Payload" } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + error("Unneeded for test") + } } override fun operationErrorHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol = buildSymbol { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt index 5283bfa228..41caec378b 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt @@ -143,6 +143,7 @@ data class KotlinDependency( // Ktor server dependencies // FIXME: version numbers should not be hardcoded, they should be setting dynamically based on the Gradle library versions val KTOR_SERVER_CORE = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server", "io.ktor", "ktor-server-core", KTOR_VERSION) + val KTOR_SERVER_UTILS = KotlinDependency(GradleConfiguration.Implementation, "io.ktor", "io.ktor", "ktor-server-core", KTOR_VERSION) val KTOR_SERVER_NETTY = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.netty", "io.ktor", "ktor-server-netty", KTOR_VERSION) val KTOR_SERVER_CIO = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.cio", "io.ktor", "ktor-server-cio", KTOR_VERSION) val KTOR_SERVER_JETTY_JAKARTA = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.jetty.jakarta", "io.ktor", "ktor-server-jetty-jakarta", KTOR_VERSION) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index 6ed0d40286..c7dbb93819 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -513,6 +513,10 @@ object RuntimeTypes { val BadRequestException = symbol("BadRequestException", "plugins") } + object KtorServerUtils : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_UTILS) { + val AttributeKey = symbol("AttributeKey", "util") + } + object KtorServerRouting : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_CORE) { val routing = symbol("routing", "routing") val route = symbol("route", "routing") @@ -558,6 +562,9 @@ object RuntimeTypes { val HttpHeaders = symbol("HttpHeaders") val Cbor = symbol("Cbor", "ContentType.Application") val Json = symbol("Json", "ContentType.Application") + val Any = symbol("Any", "ContentType.Application") + val OctetStream = symbol("OctetStream", "ContentType.Application") + val PlainText = symbol("Plain", "ContentType.Text") } object KtorServerLogging : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_LOGGING) { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt index 73453fb104..4dae6e5c7e 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt @@ -120,6 +120,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { httpOperations.forEach { operation -> generateOperationSerializer(ctx, operation) } + + if (ctx.settings.build.generateServiceProject) { + val modeledErrors = httpOperations.flatMap { it.errors }.map { ctx.model.expectShape(it) as StructureShape }.toSet() + modeledErrors.forEach { generateExceptionSerializer(ctx, it) } + } } private fun generateDeserializers(ctx: ProtocolGenerator.GenerationContext) { @@ -131,13 +136,15 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } // generate HttpDeserialize for exception types - val modeledErrors = httpOperations.flatMap { it.errors }.map { ctx.model.expectShape(it) as StructureShape }.toSet() - modeledErrors.forEach { generateExceptionDeserializer(ctx, it) } + if (!ctx.settings.build.generateServiceProject) { + val modeledErrors = httpOperations.flatMap { it.errors }.map { ctx.model.expectShape(it) as StructureShape }.toSet() + modeledErrors.forEach { generateExceptionDeserializer(ctx, it) } + } } override fun generateProtocolClient(ctx: ProtocolGenerator.GenerationContext) { if (ctx.settings.build.generateServiceProject) { - require(protocolName == "smithyRpcv2cbor") { "service project accepts only Cbor protocol" } + require(protocolName in listOf("smithyRpcv2cbor", "awsRestjson1")) { "service project accepts only Cbor or JSON protocol" } } if (!ctx.settings.build.generateServiceProject) { val symbol = ctx.symbolProvider.toSymbol(ctx.service) @@ -185,17 +192,20 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val serdeMeta = HttpSerdeMeta(op.isInputEventStream(ctx.model)) ctx.delegator.useSymbolWriter(serializerSymbol) { writer -> - // FIXME: this works only for Cbor protocol now if (ctx.settings.build.generateServiceProject) { + val serializerResultSymbol = getHttpSerializerResultSymbol(protocolName) + val defaultResponse = getHttpSerializerDefaultResponse(protocolName) + writer .openBlock("internal class #T {", serializerSymbol) .call { writer.openBlock( - "public fun serialize(context: #T, input: #T): ByteArray {", + "public fun serialize(context: #T, input: #T): #T {", RuntimeTypes.Core.ExecutionContext, serializationSymbol, + serializerResultSymbol, ) - .write("var response: Any") + .write("var response: #T = $defaultResponse", serializerResultSymbol) .call { renderSerializeHttpBody(ctx, op, writer) } @@ -236,11 +246,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { ) { val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) val httpTrait = resolver.httpTrait(op) - val bindings = if (ctx.settings.build.generateServiceProject) { - resolver.responseBindings(op) - } else { - resolver.requestBindings(op) - } + val bindings = resolver.requestBindings(op) writer .addImport(RuntimeTypes.Core.ExecutionContext) @@ -290,6 +296,42 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } } + /** + * Generate HttpSerialize for a modeled error (exception) + */ + private fun generateExceptionSerializer(ctx: ProtocolGenerator.GenerationContext, shape: StructureShape) { + val serializationSymbol = ctx.symbolProvider.toSymbol(shape) + + val serializerSymbol = buildSymbol { + val deserializerName = "${serializationSymbol.name}Serializer" + definitionFile = "$deserializerName.kt" + name = deserializerName + namespace = ctx.settings.pkg.serde + reference(serializationSymbol, SymbolReference.ContextOption.DECLARE) + } + + ctx.delegator.useSymbolWriter(serializerSymbol) { writer -> + val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) + val bindings = resolver.responseBindings(shape) + val serializerResultSymbol = getHttpSerializerResultSymbol(protocolName) + val defaultResponse = getHttpSerializerDefaultResponse(protocolName) + writer.withBlock("internal class #T {", "}", serializerSymbol) { + writer.openBlock( + "public fun serialize(context: #T, input: #T): #T {", + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + serializerResultSymbol, + ) + .write("var response: #T = $defaultResponse", serializerResultSymbol) + .call { + renderExceptionSerializeBody(ctx, serializationSymbol, bindings, writer) + } + .write("return response") + .closeBlock("}") + } + } + } + /** * Calls the operation body serializer function and binds the results to `builder.body`. * By default if no members are bound to the body this function renders nothing. @@ -297,7 +339,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { */ protected open fun renderSerializeHttpBody(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) - if (!resolver.hasHttpBody(op)) return + if (ctx.settings.build.generateServiceProject) { + if (!resolver.hasHttpResponseBody(op)) return + } else { + if (!resolver.hasHttpRequestBody(op)) return + } // payload member(s) val bindings = if (ctx.settings.build.generateServiceProject) { @@ -315,9 +361,35 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val sdg = structuredDataSerializer(ctx) val opBodySerializerFn = sdg.operationSerializer(ctx, op, documentMembers) writer.write("val payload = #T(context, input)", opBodySerializerFn) - writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = payload.decodeToString()") + } else { + writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + } + } + if (!ctx.settings.build.generateServiceProject) { + renderContentTypeHeader(ctx, op, writer, resolver) } - renderContentTypeHeader(ctx, op, writer, resolver) + } + + /** + * Calls the operation body serializer function and binds the results to `builder.body`. + * By default if no members are bound to the body this function renders nothing. + * If there is a payload to render it should be bound to `builder.body` when this function returns + */ + protected open fun renderExceptionSerializeBody( + ctx: ProtocolGenerator.GenerationContext, + deserializationSymbol: Symbol, + bindings: List, + writer: KotlinWriter, + ) { + val documentMembers = bindings.filterDocumentBoundMembers() + // Unbound document members that should be serialized into the document format for the protocol. + // delegate to the generate operation body serializer function + val sdg = structuredDataSerializer(ctx) + val exceptionBodySerializerFn = sdg.errorSerializer(ctx, deserializationSymbol.shape as StructureShape, documentMembers) + writer.write("val payload = #T(context, input)", exceptionBodySerializerFn) + writer.write("response = payload") } protected open fun renderContentTypeHeader( @@ -535,7 +607,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { if (isBinaryStream) { writer.write("builder.body = input.#L.#T()", memberName, RuntimeTypes.Http.toHttpBody) } else { - writer.write("builder.body = #T.fromBytes(input.#L)", RuntimeTypes.Http.HttpBody, memberName) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L.decodeToString()", memberName) + } else { + writer.write("builder.body = #T.fromBytes(input.#L)", RuntimeTypes.Http.HttpBody, memberName) + } } } @@ -545,29 +621,46 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } else { memberName } - writer.write("builder.body = #T.fromBytes(input.#L.#T())", RuntimeTypes.Http.HttpBody, contents, KotlinTypes.Text.encodeToByteArray) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L", contents) + } else { + writer.write("builder.body = #T.fromBytes(input.#L.#T())", RuntimeTypes.Http.HttpBody, contents, KotlinTypes.Text.encodeToByteArray) + } } ShapeType.ENUM -> - writer.write( - "builder.body = #T.fromBytes(input.#L.value.#T())", - RuntimeTypes.Http.HttpBody, - memberName, - KotlinTypes.Text.encodeToByteArray, - ) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L.value.toString()", memberName) + } else { + writer.write( + "builder.body = #T.fromBytes(input.#L.value.#T())", + RuntimeTypes.Http.HttpBody, + memberName, + KotlinTypes.Text.encodeToByteArray, + ) + } + ShapeType.INT_ENUM -> - writer.write( - "builder.body = #T.fromBytes(input.#L.value.toString().#T())", - RuntimeTypes.Http.HttpBody, - memberName, - KotlinTypes.Text.encodeToByteArray, - ) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = input.#L.value.toString()", memberName) + } else { + writer.write( + "builder.body = #T.fromBytes(input.#L.value.toString().#T())", + RuntimeTypes.Http.HttpBody, + memberName, + KotlinTypes.Text.encodeToByteArray, + ) + } ShapeType.STRUCTURE, ShapeType.UNION, ShapeType.DOCUMENT -> { val sdg = structuredDataSerializer(ctx) val payloadSerializerFn = sdg.payloadSerializer(ctx, binding.member) writer.write("val payload = #T(input.#L)", payloadSerializerFn, memberName) - writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = payload.decodeToString()") + } else { + writer.write("builder.body = #T.fromBytes(payload)", RuntimeTypes.Http.HttpBody) + } } else -> @@ -689,25 +782,13 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } else { resolver.responseBindings(shape) } - - when (ctx.settings.build.generateServiceProject) { - true -> - writer - .addImport(exceptionDeserializerSymbols) - .write("") - .openBlock("internal class #T {", deserializerSymbol) - .write("") - .call { renderServiceHttpDeserialize(ctx, deserializationSymbol, bindings, serdeMeta, null, writer) } - .closeBlock("}") - false -> - writer - .addImport(exceptionDeserializerSymbols) - .write("") - .openBlock("internal class #T: #T.NonStreaming<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserializer, deserializationSymbol) - .write("") - .call { renderHttpDeserialize(ctx, deserializationSymbol, bindings, serdeMeta, null, writer) } - .closeBlock("}") - } + writer + .addImport(exceptionDeserializerSymbols) + .write("") + .openBlock("internal class #T: #T.NonStreaming<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserializer, deserializationSymbol) + .write("") + .call { renderHttpDeserialize(ctx, deserializationSymbol, bindings, serdeMeta, null, writer) } + .closeBlock("}") } } @@ -796,13 +877,15 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { ) { writer .openBlock( - "public fun deserialize(context: #T, payload: #T?): #T {", + "public fun deserialize(context: #T, call: #T, payload: #T?): #T {", RuntimeTypes.Core.ExecutionContext, + RuntimeTypes.KtorServerCore.ApplicationCallClass, KotlinTypes.ByteArray, deserializationSymbol, ) - writer.write("val builder = #T.Builder()", deserializationSymbol) + writer.write("val request = call.request") + .write("val builder = #T.Builder()", deserializationSymbol) .write("") .call { // headers @@ -925,13 +1008,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } else { "" } -// FIXME: Service supports only CBOR now. Modify this part to support service for other data format. -// val message = if (ctx.settings.build.generateServiceProject) { -// "request" -// } else { -// "response" -// } - val message = "response" + val message = if (ctx.settings.build.generateServiceProject) { + "request" + } else { + "response" + } when (memberTarget) { is NumberShape -> { if (memberTarget is IntEnumShape) { @@ -1077,13 +1158,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } else { "" } -// FIXME: Service supports only CBOR now. Modify this part to support service for other data format. -// val message = if (ctx.settings.build.generateServiceProject) { -// "request" -// } else { -// "response" -// } - val message = "response" + + val message = if (ctx.settings.build.generateServiceProject) { + "request" + } else { + "response" + } writer.write("val $keyCollName = $message.headers.names()$filter") writer.openBlock("if ($keyCollName.isNotEmpty()) {") .write("val map = mutableMapOf()", targetValueSymbol) @@ -1264,3 +1344,15 @@ private fun httpDeserializerInfo(ctx: ProtocolGenerator.GenerationContext, op: O return HttpSerdeMeta(isStreaming) } + +private fun getHttpSerializerResultSymbol(protocolName: String) = when (protocolName) { + "smithyRpcv2cbor" -> KotlinTypes.ByteArray + "awsRestjson1" -> KotlinTypes.String + else -> error("service project accepts only Cbor or JSON protocol") +} + +private fun getHttpSerializerDefaultResponse(protocolName: String) = when (protocolName) { + "smithyRpcv2cbor" -> "ByteArray(0)" + "awsRestjson1" -> "\"\"" + else -> error("service project accepts only Cbor or JSON protocol") +} diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt index dc4deb49d2..7e2345f836 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingResolver.kt @@ -90,11 +90,19 @@ interface HttpBindingResolver { /** * @return true if the operation contains request data bound to the PAYLOAD or DOCUMENT locations */ -fun HttpBindingResolver.hasHttpBody(operationShape: OperationShape): Boolean = +fun HttpBindingResolver.hasHttpRequestBody(operationShape: OperationShape): Boolean = requestBindings(operationShape).any { it.location == HttpBinding.Location.PAYLOAD || it.location == HttpBinding.Location.DOCUMENT } +/** + * @return true if the operation contains request data bound to the PAYLOAD or DOCUMENT locations + */ +fun HttpBindingResolver.hasHttpResponseBody(operationShape: OperationShape): Boolean = + responseBindings(operationShape).any { + it.location == HttpBinding.Location.PAYLOAD || it.location == HttpBinding.Location.DOCUMENT + } + /** * Protocol content type mappings */ diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt index ddb70b68dd..a4467e10e3 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt @@ -138,4 +138,24 @@ class CborSerializerGenerator( } } } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + val symbol = ctx.symbolProvider.toSymbol(errorShape) + + return symbol.errorSerializer(ctx.settings) { writer -> + addNestedDocumentSerializers(ctx, errorShape, writer) + val fnName = symbol.errorSerializerName() + writer.openBlock("private fun #L(context: #T, input: #T): ByteArray {", fnName, RuntimeTypes.Core.ExecutionContext, symbol) + .write("val serializer = #T()", RuntimeTypes.Serde.SerdeCbor.CborSerializer) + .call { + renderSerializerBody(ctx, errorShape, members, writer) + } + .write("return serializer.toByteArray()") + .closeBlock("}") + } + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt index 4970e55f91..a98bb369e6 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonParserGenerator.kt @@ -36,7 +36,12 @@ open class JsonParserGenerator( ) override fun operationDeserializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List): Symbol { - val outputSymbol = op.output.get().let { ctx.symbolProvider.toSymbol(ctx.model.expectShape(it)) } + val deserializationTarget = if (ctx.settings.build.generateServiceProject) { + op.input + } else { + op.output + } + val outputSymbol = deserializationTarget.get().let { ctx.symbolProvider.toSymbol(ctx.model.expectShape(it)) } return op.bodyDeserializer(ctx.settings) { writer -> addNestedDocumentDeserializers(ctx, op, writer) val fnName = op.bodyDeserializerName() @@ -74,8 +79,13 @@ open class JsonParserGenerator( documentMembers: List, writer: KotlinWriter, ) { + val deserializationTarget = if (ctx.settings.build.generateServiceProject) { + op.input + } else { + op.output + } writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeJson.JsonDeserializer) - val shape = ctx.model.expectShape(op.output.get()) + val shape = ctx.model.expectShape(deserializationTarget.get()) renderDeserializerBody(ctx, shape, documentMembers, writer) } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt index 6366fe8e04..62a78064b3 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/JsonSerializerGenerator.kt @@ -26,8 +26,13 @@ open class JsonSerializerGenerator( open val defaultTimestampFormat: TimestampFormatTrait.Format = TimestampFormatTrait.Format.EPOCH_SECONDS override fun operationSerializer(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, members: List): Symbol { - val input = ctx.model.expectShape(op.input.get()) - val symbol = ctx.symbolProvider.toSymbol(input) + val serializationTarget = if (ctx.settings.build.generateServiceProject) { + op.output + } else { + op.input + } + val shape = ctx.model.expectShape(serializationTarget.get()) + val symbol = ctx.symbolProvider.toSymbol(shape) return op.bodySerializer(ctx.settings) { writer -> addNestedDocumentSerializers(ctx, op, writer) @@ -61,7 +66,12 @@ open class JsonSerializerGenerator( documentMembers: List, writer: KotlinWriter, ) { - val shape = ctx.model.expectShape(op.input.get()) + val serializationTarget = if (ctx.settings.build.generateServiceProject) { + op.output + } else { + op.input + } + val shape = ctx.model.expectShape(serializationTarget.get()) writer.write("val serializer = #T()", RuntimeTypes.Serde.SerdeJson.JsonSerializer) renderSerializerBody(ctx, shape, documentMembers, writer) writer.write("return serializer.toByteArray()") @@ -118,4 +128,24 @@ open class JsonSerializerGenerator( } } } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + val symbol = ctx.symbolProvider.toSymbol(errorShape) + + return symbol.errorSerializer(ctx.settings) { writer -> + addNestedDocumentSerializers(ctx, errorShape, writer) + val fnName = symbol.errorSerializerName() + writer.openBlock("private fun #L(context: #T, input: #T): String {", fnName, RuntimeTypes.Core.ExecutionContext, symbol) + .write("val serializer = #T()", RuntimeTypes.Serde.SerdeJson.JsonSerializer) + .call { + renderSerializerBody(ctx, errorShape, members, writer) + } + .write("return serializer.toByteArray().decodeToString()") + .closeBlock("}") + } + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt index 071d8ba364..2083b83524 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/SerdeExt.kt @@ -149,6 +149,26 @@ fun Symbol.errorDeserializer(settings: KotlinSettings, block: SymbolRenderer): S renderBy = block } +/** + * Get the serializer name for an error shape + */ +fun Symbol.errorSerializerName(): String = "serialize" + StringUtils.capitalize(this.name) + "Error" + +/** + * Get the function responsible for serializing members bound to the payload of an error shape as [Symbol] and + * register [block] * which will be invoked to actually render the function (signature and implementation) + */ +fun Symbol.errorSerializer(settings: KotlinSettings, block: SymbolRenderer): Symbol = buildSymbol { + name = errorSerializerName() + namespace = settings.pkg.serde + val symbol = this@errorSerializer + // place it in the same file as the exception deserializer, e.g. for HTTP protocols this will be in + // same file as HttpDeserialize + definitionFile = "${symbol.name}Serializer.kt" + reference(symbol, SymbolReference.ContextOption.DECLARE) + renderBy = block +} + /** * Get the function responsible for deserializing the specific shape as a standalone payload */ diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt index c3e49c128d..7d08a03b86 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/StructuredDataSerializerGenerator.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerato import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape /** * Responsible for rendering serialization of structured data (e.g. json, yaml, xml). @@ -56,4 +57,27 @@ interface StructuredDataSerializerGenerator { shape: Shape, members: Collection? = null, ): Symbol + + /** + * Render function responsible for serializing members bound to the payload for the given error shape. + * + * Because only a subset of fields of an operation error may be bound to the payload a builder is given + * as an argument. + * + * ``` + * fun serializeFooError(builder: FooError.Builder, payload: ByteArray) { + * ... + * } + * ``` + * + * Implementations are expected to instantiate an appropriate serializer for the protocol and serialize + * the error shape from the payload using the builder passed in. + * + * @param ctx the protocol generator context + * @param errorShape the error shape to render deserialize for + * @param members the members of the error shape that are bound to the payload. Not all members are + * bound to the document, some may be bound to e.g. headers, status code, etc + * @return the generated symbol which should be a function matching the signature expected for the protocol + */ + fun errorSerializer(ctx: ProtocolGenerator.GenerationContext, errorShape: StructureShape, members: List): Symbol } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt index 1f77136e15..0ff77e2e85 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/XmlSerializerGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingConte import software.amazon.smithy.model.shapes.MemberShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.Shape +import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.model.traits.XmlAttributeTrait @@ -141,4 +142,12 @@ open class XmlSerializerGenerator( } } } + + override fun errorSerializer( + ctx: ProtocolGenerator.GenerationContext, + errorShape: StructureShape, + members: List, + ): Symbol { + TODO("Used for service-codegen. Not yet implemented") + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt index 826f770455..51cedb1108 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt @@ -10,12 +10,26 @@ import software.amazon.smithy.kotlin.codegen.core.withBlock import software.amazon.smithy.kotlin.codegen.core.withInlineBlock import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.getTrait +import software.amazon.smithy.kotlin.codegen.service.MediaType.ANY +import software.amazon.smithy.kotlin.codegen.service.MediaType.JSON +import software.amazon.smithy.kotlin.codegen.service.MediaType.OCTET_STREAM +import software.amazon.smithy.kotlin.codegen.service.MediaType.PLAIN_TEXT import software.amazon.smithy.kotlin.codegen.service.contraints.ConstraintGenerator import software.amazon.smithy.kotlin.codegen.service.contraints.ConstraintUtilsGenerator +import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.model.shapes.ShapeId import software.amazon.smithy.model.traits.AuthTrait import software.amazon.smithy.model.traits.HttpBearerAuthTrait +import software.amazon.smithy.model.traits.HttpErrorTrait +import software.amazon.smithy.model.traits.HttpHeaderTrait +import software.amazon.smithy.model.traits.HttpLabelTrait +import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait +import software.amazon.smithy.model.traits.HttpQueryParamsTrait +import software.amazon.smithy.model.traits.HttpQueryTrait import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.model.traits.MediaTypeTrait +import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.utils.AbstractCodeWriter class LoggingWriter(parent: LoggingWriter? = null) : AbstractCodeWriter() { @@ -110,7 +124,6 @@ internal class KtorStubGenerator( write("#T.WARN -> #T.WARN", ServiceTypes(pkgName).logLevel, RuntimeTypes.KtorLoggingSlf4j.Level) write("#T.ERROR -> #T.ERROR", ServiceTypes(pkgName).logLevel, RuntimeTypes.KtorLoggingSlf4j.Level) write("#T.OFF -> null", ServiceTypes(pkgName).logLevel) - write("else -> #T.INFO", RuntimeTypes.KtorLoggingSlf4j.Level) } write("") write("val logbackLevel = slf4jLevel?.let { #T.valueOf(it.name) } ?: #T.OFF", RuntimeTypes.KtorLoggingLogback.Level, RuntimeTypes.KtorLoggingLogback.Level) @@ -209,8 +222,6 @@ internal class KtorStubGenerator( // Writes `Routing.kt` that maps Smithy operations → Ktor routes. override fun renderRouting() { - val contentType = ContentType.fromServiceShape(serviceShape) - delegator.useFileWriter("Routing.kt", pkgName) { writer -> operations.forEach { shape -> @@ -220,6 +231,9 @@ internal class KtorStubGenerator( writer.addImport("$pkgName.model", "${shape.id.name}Request") writer.addImport("$pkgName.model", "${shape.id.name}Response") writer.addImport("$pkgName.operations", "handle${shape.id.name}Request") + shape.errors.forEach { error -> + writer.addImport("$pkgName.serde", "${error.name}Serializer") + } } writer.withBlock("internal fun #T.configureRouting(): Unit {", "}", RuntimeTypes.KtorServerCore.Application) { @@ -243,15 +257,27 @@ internal class KtorStubGenerator( "OPTIONS" -> RuntimeTypes.KtorServerRouting.options else -> error("Unsupported http trait ${httpTrait.method}") } - + val contentType = MediaType.fromServiceShape(ctx, serviceShape, shape.input.get()) val contentTypeGuard = when (contentType) { - ContentType.CBOR -> "cbor()" - ContentType.JSON -> "json()" + MediaType.CBOR -> "cbor()" + JSON -> "json()" + PLAIN_TEXT -> "text()" + OCTET_STREAM -> "binary()" + ANY -> "any()" + } + + val acceptType = MediaType.fromServiceShape(ctx, serviceShape, shape.output.get()) + val acceptTypeGuard = when (acceptType) { + MediaType.CBOR -> "cbor()" + JSON -> "json()" + PLAIN_TEXT -> "text()" + OCTET_STREAM -> "binary()" + ANY -> "any()" } withBlock("#T (#S) {", "}", RuntimeTypes.KtorServerRouting.route, uri) { write("#T(#T) { $contentTypeGuard }", RuntimeTypes.KtorServerCore.install, ServiceTypes(pkgName).contentTypeGuard) - write("#T(#T) { $contentTypeGuard }", RuntimeTypes.KtorServerCore.install, ServiceTypes(pkgName).acceptTypeGuard) + write("#T(#T) { $acceptTypeGuard }", RuntimeTypes.KtorServerCore.install, ServiceTypes(pkgName).acceptTypeGuard) withBlock( "#W", "}", @@ -266,16 +292,23 @@ internal class KtorStubGenerator( ) write("val deserializer = ${shape.id.name}OperationDeserializer()") withBlock( - "val requestObj = try { deserializer.deserialize(#T(), request) } catch (ex: Exception) {", + "var requestObj = try { deserializer.deserialize(#T(), call, request) } catch (ex: Exception) {", "}", RuntimeTypes.Core.ExecutionContext, ) { write( - "throw #T(#S, ex)", + "throw #T(ex?.message ?: #S, ex)", RuntimeTypes.KtorServerCore.BadRequestException, "Malformed CBOR input", ) } + if (ctx.model.expectShape(shape.input.get()).allMembers.isNotEmpty()) { + withBlock("requestObj = requestObj.copy {", "}") { + call { readHttpLabel(shape, writer) } + call { readHttpQuery(shape, writer) } + } + } + write( "try { check${shape.id.name}RequestConstraint(requestObj) } catch (ex: Exception) { throw #T(ex?.message ?: #S, ex) }", RuntimeTypes.KtorServerCore.BadRequestException, @@ -289,15 +322,51 @@ internal class KtorStubGenerator( RuntimeTypes.Core.ExecutionContext, ) { write( - "throw #T(#S, ex)", + "throw #T(ex?.message ?: #S, ex)", RuntimeTypes.KtorServerCore.BadRequestException, "Malformed CBOR output", ) } - .call { renderResponseCall(writer, contentType, successCode) } + call { readResponseHttpHeader("responseObj", shape.output.get(), writer) } + call { readResponseHttpPrefixHeader("responseObj", shape.output.get(), writer) } + call { renderResponseCall("response", writer, acceptType, successCode.toString(), shape.output.get()) } } withBlock(" catch (t: Throwable) {", "}") { - write("throw t") + writeInline("val errorObj: Any? = ") + withBlock("when (t) {", "}") { + shape.errors.forEach { errorShapeId -> + val errorShape = ctx.model.expectShape(errorShapeId) + val errorSymbol = ctx.symbolProvider.toSymbol(errorShape) + write("is #T -> t as #T", errorSymbol, errorSymbol) + } + write("else -> null") + } + write("") + write("") + writeInline("val errorResponse: Pair? = ") + withBlock("when (errorObj) {", "}") { + shape.errors.forEach { errorShapeId -> + val errorShape = ctx.model.expectShape(errorShapeId) + val errorSymbol = ctx.symbolProvider.toSymbol(errorShape) + write("is #T -> Pair(${errorShapeId.name}Serializer().serialize(#T(), errorObj), ${errorShape.getTrait()?.code})", errorSymbol, RuntimeTypes.Core.ExecutionContext) + } + write("else -> null") + } + write("if (errorResponse == null) throw t") + + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) + withBlock("when (errorObj) {", "}") { + shape.errors.forEach { errorShapeId -> + val errorShape = ctx.model.expectShape(errorShapeId) + val errorSymbol = ctx.symbolProvider.toSymbol(errorShape) + withBlock("is #T -> {", "}", errorSymbol) { + readResponseHttpHeader("errorObj", errorShapeId, writer) + readResponseHttpPrefixHeader("errorObj", errorShapeId, writer) + } + } + write("else -> null") + } + call { renderResponseCall("errorResponse.first", writer, acceptType, "\"\${errorResponse.second}\"", shape.output.get()) } } } } @@ -308,6 +377,99 @@ internal class KtorStubGenerator( } } + private fun readHttpLabel(shape: OperationShape, writer: KotlinWriter) { + val inputShape = ctx.model.expectShape(shape.input.get()) + inputShape.allMembers + .filter { member -> member.value.hasTrait(HttpLabelTrait.ID) } + .forEach { member -> + val memberName = member.key + val memberShape = member.value + + val httpLabelVariableName = "call.parameters[\"$memberName\"]?" + val targetShape = ctx.model.expectShape(memberShape.target) + writer.writeInline("$memberName = ") + .call { + renderCastingPrimitiveFromShapeType( + httpLabelVariableName, + targetShape.type, + writer, + memberShape.getTrait() ?: inputShape.getTrait(), + "Unsupported type ${memberShape.type} for httpLabel", + ) + } + } + } + + private fun readHttpQuery(shape: OperationShape, writer: KotlinWriter) { + val inputShape = ctx.model.expectShape(shape.input.get()) + val httpQueryKeys = mutableSetOf() + inputShape.allMembers + .filter { member -> member.value.hasTrait(HttpQueryTrait.ID) } + .forEach { member -> + val memberName = member.key + val memberShape = member.value + val httpQueryTrait = memberShape.getTrait()!! + val httpQueryVariableName = "call.request.queryParameters[\"${httpQueryTrait.value}\"]?" + val targetShape = ctx.model.expectShape(memberShape.target) + httpQueryKeys.add(httpQueryTrait.value) + writer.writeInline("$memberName = ") + .call { + when { + targetShape.isListShape -> { + val listMemberShape = targetShape.allMembers.values.first() + val listMemberTargetShapeId = ctx.model.expectShape(listMemberShape.target) + val httpQueryListVariableName = "(call.request.queryParameters.getAll(\"${httpQueryTrait.value}\") " + + "?: call.request.queryParameters.getAll(\"${httpQueryTrait.value}[]\") " + + "?: emptyList())" + writer.withBlock("$httpQueryListVariableName.mapNotNull{", "}") { + renderCastingPrimitiveFromShapeType( + "it?", + listMemberTargetShapeId.type, + writer, + listMemberShape.getTrait() ?: targetShape.getTrait(), + "Unsupported type ${memberShape.type} for list in httpLabel", + ) + } + } + else -> renderCastingPrimitiveFromShapeType( + httpQueryVariableName, + targetShape.type, + writer, + memberShape.getTrait() ?: inputShape.getTrait(), + "Unsupported type ${memberShape.type} for httpQuery", + ) + } + } + } + val httpQueryParamsMember = inputShape.allMembers.values.firstOrNull { it.hasTrait(HttpQueryParamsTrait.ID) } + httpQueryParamsMember?.apply { + val httpQueryParamsMemberName = httpQueryParamsMember.memberName + val httpQueryParamsMapShape = ctx.model.expectShape(httpQueryParamsMember.target) as MapShape + val httpQueryParamsMapValueTypeShape = ctx.model.expectShape(httpQueryParamsMapShape.value.target) + println(httpQueryParamsMapShape) + val httpQueryKeysLiteral = httpQueryKeys.joinToString(", ") { "\"$it\"" } + writer.withInlineBlock("$httpQueryParamsMemberName = call.request.queryParameters.entries().filter { (key, _) ->", "}") { + write("key !in setOf($httpQueryKeysLiteral)") + } + .withBlock(".associate { (key, values) ->", "}") { + if (httpQueryParamsMapValueTypeShape.isListShape) { + write("key to values!!") + } else { + write("key to values.first()") + } + } + .withBlock(".mapValues { (_, value) ->", "}") { + renderCastingPrimitiveFromShapeType( + "value", + httpQueryParamsMapValueTypeShape.type, + writer, + httpQueryParamsMapValueTypeShape.getTrait() ?: httpQueryParamsMapShape.getTrait(), + "Unsupported type ${httpQueryParamsMapValueTypeShape.type} for httpQuery", + ) + } + } + } + private fun renderRoutingAuth(w: KotlinWriter, shape: OperationShape) { val hasServiceHttpBearerAuthTrait = serviceShape.hasTrait(HttpBearerAuthTrait.ID) val authTrait = shape.getTrait() @@ -324,38 +486,107 @@ internal class KtorStubGenerator( } } + private fun readResponseHttpHeader(dataName: String, shapeId: ShapeId, writer: KotlinWriter) { + val shape = ctx.model.expectShape(shapeId) + shape.allMembers + .filter { member -> member.value.hasTrait(HttpHeaderTrait.ID) } + .forEach { member -> + val headerName = member.value.getTrait()!!.value + val memberName = member.key + writer.write("call.response.headers.append(#S, $dataName.$memberName.toString())", headerName) + } + } + + private fun readResponseHttpPrefixHeader(dataName: String, shapeId: ShapeId, writer: KotlinWriter) { + val shape = ctx.model.expectShape(shapeId) + shape.allMembers + .filter { member -> member.value.hasTrait(HttpPrefixHeadersTrait.ID) } + .forEach { member -> + val prefixHeaderName = member.value.getTrait()!!.value + val memberName = member.key + writer.withBlock("for ((suffixHeader, headerValue) in $dataName?.$memberName ?: mapOf()) {", "}") { + writer.write("call.response.headers.append(#S, headerValue.toString())", "$prefixHeaderName\${suffixHeader}") + } + } + } + private fun renderResponseCall( + responseName: String, w: KotlinWriter, - contentType: ContentType, - successCode: Int, + acceptType: MediaType, + successCode: String, + outputShapeId: ShapeId, ) { - when (contentType) { - ContentType.CBOR -> w.withBlock( + when (acceptType) { + MediaType.CBOR -> w.withBlock( "#T.#T(", ")", RuntimeTypes.KtorServerCore.applicationCall, RuntimeTypes.KtorServerRouting.responseRespondBytes, ) { - write("bytes = response,") + write("bytes = $responseName as ByteArray,") write("contentType = #T,", RuntimeTypes.KtorServerHttp.Cbor) write( - "status = #T.fromValue($successCode),", + "status = #T.fromValue($successCode.toInt()),", RuntimeTypes.KtorServerHttp.HttpStatusCode, ) } - ContentType.JSON -> w.withBlock( + OCTET_STREAM -> w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseRespondBytes, + ) { + write("bytes = $responseName as ByteArray,") + write("contentType = #T,", RuntimeTypes.KtorServerHttp.OctetStream) + write( + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + JSON -> w.withBlock( "#T.#T(", ")", RuntimeTypes.KtorServerCore.applicationCall, RuntimeTypes.KtorServerRouting.responseResponseText, ) { - write("text = response,") + write("text = $responseName as String,") write("contentType = #T,", RuntimeTypes.KtorServerHttp.Json) write( - "status = #T.fromValue($successCode),", + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + PLAIN_TEXT -> w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseResponseText, + ) { + write("text = $responseName as String,") + write("contentType = #T,", RuntimeTypes.KtorServerHttp.PlainText) + write( + "status = #T.fromValue($successCode.toInt()),", RuntimeTypes.KtorServerHttp.HttpStatusCode, ) } + ANY -> { + val outputShape = ctx.model.expectShape(outputShapeId) + val mediaTraits = outputShape.allMembers.values.firstNotNullOf { it.getTrait() } + w.withBlock( + "#T.#T(", + ")", + RuntimeTypes.KtorServerCore.applicationCall, + RuntimeTypes.KtorServerRouting.responseRespondBytes, + ) { + write("bytes = $responseName as ByteArray,") + write("contentType = #S,", mediaTraits.value) + write( + "status = #T.fromValue($successCode.toInt()),", + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) + } + } } } @@ -367,6 +598,8 @@ internal class KtorStubGenerator( private fun renderErrorHandler() { delegator.useFileWriter("ErrorHandler.kt", "$pkgName.plugins") { writer -> + writer.write("internal val ResponseHandledKey = #T(#S)", RuntimeTypes.KtorServerUtils.AttributeKey, "ResponseHandled") + .write("") writer.write("@#T", RuntimeTypes.KotlinxCborSerde.Serializable) .write("private data class ErrorPayload(val code: Int, val message: String)") .write("") @@ -432,17 +665,23 @@ internal class KtorStubGenerator( RuntimeTypes.KtorServerStatusPage.StatusPages, ) { withBlock("status(#T.Unauthorized) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) { + write("if (call.attributes.getOrNull(#T) == true) { return@status }", ServiceTypes(pkgName).responseHandledKey) + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) write("val missing = call.request.headers[#S].isNullOrBlank()", "Authorization") write("val message = if (missing) #S else #S", "Missing bearer token", "Invalid or expired bearer token") write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )") } write("") withBlock("status(#T.NotFound) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) { + write("if (call.attributes.getOrNull(#T) == true) { return@status }", ServiceTypes(pkgName).responseHandledKey) + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) write("val message = #S", "Resource not found") write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )") } write("") withBlock("status(#T.MethodNotAllowed) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) { + write("if (call.attributes.getOrNull(#T) == true) { return@status }", ServiceTypes(pkgName).responseHandledKey) + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) write("val message = #S", "Method not allowed for this resource") write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )") } @@ -468,6 +707,7 @@ internal class KtorStubGenerator( write("") write("val envelope = if (cause is ErrorEnvelope) cause else ErrorEnvelope(status.value, cause.message ?: #S)", "Unexpected error") + write("call.attributes.put(#T, true)", ServiceTypes(pkgName).responseHandledKey) write("call.respondEnvelope( envelope, status )") } } @@ -488,6 +728,10 @@ internal class KtorStubGenerator( writer.withBlock("public class ContentTypeGuardConfig {", "}") { write("public var allow: List<#T> = emptyList()", RuntimeTypes.KtorServerHttp.ContentType) write("") + withBlock("public fun any(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Any) + } + write("") withBlock("public fun json(): Unit {", "}") { write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Json) } @@ -495,6 +739,14 @@ internal class KtorStubGenerator( withBlock("public fun cbor(): Unit {", "}") { write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Cbor) } + write("") + withBlock("public fun text(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.PlainText) + } + write("") + withBlock("public fun binary(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.OctetStream) + } } .write("") @@ -545,6 +797,10 @@ internal class KtorStubGenerator( writer.withBlock("public class AcceptTypeGuardConfig {", "}") { write("public var allow: List<#T> = emptyList()", RuntimeTypes.KtorServerHttp.ContentType) write("") + withBlock("public fun any(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Any) + } + write("") withBlock("public fun json(): Unit {", "}") { write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Json) } @@ -552,6 +808,14 @@ internal class KtorStubGenerator( withBlock("public fun cbor(): Unit {", "}") { write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Cbor) } + write("") + withBlock("public fun text(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.PlainText) + } + write("") + withBlock("public fun binary(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.OctetStream) + } } .write("") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubConfigurations.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubConfigurations.kt index 067c3f0d14..c6ee377cb1 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubConfigurations.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubConfigurations.kt @@ -5,25 +5,53 @@ import software.amazon.smithy.build.FileManifest import software.amazon.smithy.kotlin.codegen.core.GenerationContext import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.model.shapes.ShapeId +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.model.traits.HttpPayloadTrait +import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.protocol.traits.Rpcv2CborTrait -enum class ContentType(val value: String) { +enum class MediaType(val value: String) { CBOR("CBOR"), JSON("JSON"), + PLAIN_TEXT("PLAIN_TEXT"), + OCTET_STREAM("OctetStream"), + ANY("ANY"), ; override fun toString(): String = value companion object { - fun fromValue(value: String): ContentType = ContentType + fun fromValue(value: String): MediaType = MediaType .entries .firstOrNull { it.name.equals(value.uppercase(), ignoreCase = true) } - ?: throw IllegalArgumentException("$value is not a validContentType value, expected one of ${ContentType.entries}") + ?: throw IllegalArgumentException("$value is not a validContentType value, expected one of ${MediaType.entries}") - fun fromServiceShape(shape: ServiceShape): ContentType = when { - shape.hasTrait(Rpcv2CborTrait.ID) -> CBOR - shape.hasTrait(RestJson1Trait.ID) -> JSON - else -> throw IllegalArgumentException("service shape does not a valid protocol") + fun fromServiceShape(ctx: GenerationContext, shape: ServiceShape, targetShapeId: ShapeId): MediaType { + return when { + shape.hasTrait(Rpcv2CborTrait.ID) -> CBOR + shape.hasTrait(RestJson1Trait.ID) -> { + val targetShape = ctx.model.expectShape(targetShapeId) + for (memberShape in targetShape.allMembers.values) { + if (!memberShape.hasTrait(HttpPayloadTrait.ID)) continue + val memberType = ctx.model.expectShape(memberShape.target).type + when (memberType) { + ShapeType.STRING -> return PLAIN_TEXT + ShapeType.BLOB -> return OCTET_STREAM + ShapeType.DOCUMENT, + ShapeType.STRUCTURE, + ShapeType.UNION, + -> return JSON + else -> { + if (memberShape.hasTrait(MediaTypeTrait.ID)) return ANY + } + } + } + return JSON + } + + else -> throw IllegalArgumentException("Cannot find supported MediaType for the service") + } } } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt index 69728b1d83..74385533cf 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt @@ -77,4 +77,9 @@ class ServiceTypes(val pkgName: String) { name = "hasAllUniqueElements" namespace = "$pkgName.constraints" } + + val responseHandledKey = buildSymbol { + name = "ResponseHandledKey" + namespace = "$pkgName.plugins" + } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/contraints/ConstraintGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/contraints/ConstraintGenerator.kt index 0291c8ed6a..95ff6e66a8 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/contraints/ConstraintGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/contraints/ConstraintGenerator.kt @@ -29,20 +29,31 @@ internal class ConstraintGenerator( val memberName = memberShape.memberName val memberAndTargetTraits = memberShape.allTraits + targetShape.allTraits - + when { + targetShape.isListShape -> + for (member in targetShape.allMembers) { + val newMemberPrefix = "${targetShape.id.name}".replaceFirstChar { it.lowercase() } + writer.withBlock("if ($prefix$memberName != null) {", "}") { + withBlock("for ($newMemberPrefix${member.key} in $prefix$memberName ?: listOf()) {", "}") { + call { generateConstraintValidations(newMemberPrefix, member.value, writer) } + } + } + } + targetShape.isStructureShape -> + for (member in targetShape.allMembers) { + val newMemberPrefix = "$prefix$memberName?." + generateConstraintValidations(newMemberPrefix, member.value, writer) + } + } for (memberTrait in memberAndTargetTraits.values) { val traitGenerator = getTraitGeneratorFromTrait(prefix, memberName, memberTrait, pkgName, writer) - if (memberTrait !is RequiredTrait) { - writer.write("if ($prefix$memberName == null) { return }") - } - traitGenerator?.render() - } - - for (member in targetShape.allMembers) { - val newMemberPrefix = "${targetShape.id.name}".replaceFirstChar { it.lowercase() } - writer.withBlock("if ($prefix$memberName != null) {", "}") { - withBlock("for ($newMemberPrefix${member.key} in $prefix$memberName) {", "}") { - call { generateConstraintValidations(newMemberPrefix, member.value, writer) } + traitGenerator?.apply { + if (memberTrait !is RequiredTrait) { + writer.withBlock("if ($prefix$memberName != null) {", "}") { + render() + } + } else { + render() } } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/utils.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/utils.kt new file mode 100644 index 0000000000..dd6001dfcb --- /dev/null +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/utils.kt @@ -0,0 +1,39 @@ +package software.amazon.smithy.kotlin.codegen.service + +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.model.traits.TimestampFormatTrait + +fun renderCastingPrimitiveFromShapeType( + variable: String, + type: ShapeType, + writer: KotlinWriter, + timestampFormatTrait: TimestampFormatTrait? = null, + errorMessage: String? = null, +) { + when (type) { + ShapeType.BLOB -> writer.write("$variable.toByteArray()") + ShapeType.STRING -> writer.write("$variable.toString()") + ShapeType.BYTE -> writer.write("$variable.toByte()") + ShapeType.INTEGER -> writer.write("$variable.toInt()") + ShapeType.SHORT -> writer.write("$variable.toShort()") + ShapeType.LONG -> writer.write("$variable.toLong()") + ShapeType.FLOAT -> writer.write("$variable.toFloat()") + ShapeType.DOUBLE -> writer.write("$variable.toDouble()") + ShapeType.BIG_DECIMAL -> writer.write("$variable.toBigDecimal()") + ShapeType.BIG_INTEGER -> writer.write("$variable.toBigInteger()") + ShapeType.BOOLEAN -> writer.write("$variable.toBoolean()") + ShapeType.TIMESTAMP -> + when (timestampFormatTrait?.format) { + TimestampFormatTrait.Format.EPOCH_SECONDS -> + writer.write("$variable.let{ #T.fromEpochSeconds(it) }", RuntimeTypes.Core.Instant) + TimestampFormatTrait.Format.DATE_TIME -> + writer.write("$variable.let{ #T.fromIso8601(it) }", RuntimeTypes.Core.Instant) + TimestampFormatTrait.Format.HTTP_DATE -> + writer.write("$variable.let{ #T.fromRfc5322(it) }", RuntimeTypes.Core.Instant) + else -> writer.write("$variable.let{ #T.fromEpochSeconds(it) }", RuntimeTypes.Core.Instant) + } + else -> throw IllegalStateException(errorMessage ?: "Unable to render casting primitive for $type") + } +} diff --git a/tests/codegen/service-codegen-tests/model/service-generator-test.smithy b/tests/codegen/service-codegen-tests/model/service-cbor-test.smithy similarity index 65% rename from tests/codegen/service-codegen-tests/model/service-generator-test.smithy rename to tests/codegen/service-codegen-tests/model/service-cbor-test.smithy index e7109f354b..1aa9eebc28 100644 --- a/tests/codegen/service-codegen-tests/model/service-generator-test.smithy +++ b/tests/codegen/service-codegen-tests/model/service-cbor-test.smithy @@ -1,17 +1,18 @@ $version: "2.0" -namespace com.test +namespace com.cbor use smithy.protocols#rpcv2Cbor @rpcv2Cbor @httpBearerAuth -service ServiceGeneratorTest { +service CborServiceTest { version: "1.0.0" operations: [ - PostTest, - AuthTest, - ErrorTest, + PostTest + AuthTest + ErrorTest + HttpErrorTest ] } @@ -64,4 +65,25 @@ structure ErrorTestInput { @output structure ErrorTestOutput { output1: String +} + + +@http(method: "POST", uri: "/http-error", code: 200) +operation HttpErrorTest { + input: HttpErrorTestInput + output: HttpErrorTestOutput + errors: [HttpError] +} + +@input +structure HttpErrorTestInput {} + +@output +structure HttpErrorTestOutput {} + +@error("client") +@httpError(456) +structure HttpError { + msg: String + num: Integer } \ No newline at end of file diff --git a/tests/codegen/service-codegen-tests/model/service-constraints-test.smithy b/tests/codegen/service-codegen-tests/model/service-constraints-test.smithy index 3f8bef6c19..4d5e042dab 100644 --- a/tests/codegen/service-codegen-tests/model/service-constraints-test.smithy +++ b/tests/codegen/service-codegen-tests/model/service-constraints-test.smithy @@ -1,6 +1,6 @@ $version: "2.0" -namespace com.test +namespace com.constraints use smithy.protocols#rpcv2Cbor diff --git a/tests/codegen/service-codegen-tests/model/service-json-test.smithy b/tests/codegen/service-codegen-tests/model/service-json-test.smithy new file mode 100644 index 0000000000..52f078c451 --- /dev/null +++ b/tests/codegen/service-codegen-tests/model/service-json-test.smithy @@ -0,0 +1,199 @@ +$version: "2.0" + +namespace com.json + +use aws.protocols#restJson1 + +@restJson1 +@httpBearerAuth +service JsonServiceTest { + version: "1.0.0" + operations: [ + HttpHeaderTest + HttpLabelTest + HttpQueryTest + HttpStringPayloadTest + HttpStructurePayloadTest + TimestampTest + JsonNameTest + HttpErrorTest + ] +} + +@http(method: "POST", uri: "/http-header", code: 201) +operation HttpHeaderTest { + input: HttpHeaderTestInput + output: HttpHeaderTestOutput +} + +@input +structure HttpHeaderTestInput { + @httpHeader("X-Request-Header") + header: String + + @httpPrefixHeaders("X-Request-Headers-") + headers: MapOfStrings +} + +@output +structure HttpHeaderTestOutput { + @httpHeader("X-Response-Header") + header: String + + @httpPrefixHeaders("X-Response-Headers-") + headers: MapOfStrings +} + + +@http(method: "GET", uri: "/http-label/{foo}", code: 200) +operation HttpLabelTest { + input: HttpLabelTestInput + output: HttpLabelTestOutput +} + +@input +structure HttpLabelTestInput { + @required + @httpLabel + foo: String +} + +@output +structure HttpLabelTestOutput { + output: String +} + +@http(method: "DELETE", uri: "/http-query", code: 200) +operation HttpQueryTest { + input: HttpQueryTestInput + output: HttpQueryTestOutput +} + +@input +structure HttpQueryTestInput { + @httpQuery("query") + query: Integer + + @httpQueryParams + params: MapOfStrings +} + +@output +structure HttpQueryTestOutput { + output: String +} + + +@http(method: "POST", uri: "/http-payload/string", code: 201) +operation HttpStringPayloadTest { + input: HttpStringPayloadTestInput + output: HttpStringPayloadTestOutput +} + +@input +structure HttpStringPayloadTestInput { + @httpPayload + content: String +} + +@output +structure HttpStringPayloadTestOutput { + @httpPayload + content: String +} + +@http(method: "POST", uri: "/http-payload/structure", code: 201) +operation HttpStructurePayloadTest { + input: HttpStructurePayloadTestInput + output: HttpStructurePayloadTestOutput +} + +@input +structure HttpStructurePayloadTestInput { + @httpPayload + content: HttpStructurePayloadTestStructure +} + +@output +structure HttpStructurePayloadTestOutput { + @httpPayload + content: HttpStructurePayloadTestStructure +} + + +@http(method: "POST", uri: "/timestamp", code: 201) +operation TimestampTest { + input: TimestampTestInput + output: TimestampTestOutput +} + +@input +structure TimestampTestInput { + default: Timestamp + @timestampFormat("date-time") + dateTime: Timestamp + @timestampFormat("http-date") + httpDate: Timestamp + @timestampFormat("epoch-seconds") + epochSeconds: Timestamp +} + +@output +structure TimestampTestOutput { + default: Timestamp + @timestampFormat("date-time") + dateTime: Timestamp + @timestampFormat("http-date") + httpDate: Timestamp + @timestampFormat("epoch-seconds") + epochSeconds: Timestamp +} + +@http(method: "POST", uri: "/json-name", code: 201) +operation JsonNameTest { + input: JsonNameTestInput + output: JsonNameTestOutput +} + +@input +structure JsonNameTestInput { + @jsonName("requestName") + content: String +} + +@output +structure JsonNameTestOutput { + @jsonName("responseName") + content: String +} + +@http(method: "POST", uri: "/http-error", code: 200) +operation HttpErrorTest { + input: HttpErrorTestInput + output: HttpErrorTestOutput + errors: [HttpError] +} + +@input +structure HttpErrorTestInput {} + +@output +structure HttpErrorTestOutput {} + +@error("client") +@httpError(456) +structure HttpError { + msg: String + num: Integer +} + +structure HttpStructurePayloadTestStructure { + content1: String + content2: Integer + content3: Float +} + +map MapOfStrings { + key: String + value: String +} \ No newline at end of file diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/CborServiceTestGenerator.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/CborServiceTestGenerator.kt new file mode 100644 index 0000000000..94688a3c18 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/CborServiceTestGenerator.kt @@ -0,0 +1,128 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun generateCborServiceTest() { + val modelPath: Path = Paths.get("model", "service-cbor-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "CborServiceTest" + val packageName = "com.cbor" + val outputDirName = "service-cbor-test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", outputDirName).also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val postTestOperation = """ + package $packageName.operations + + import $packageName.model.PostTestRequest + import $packageName.model.PostTestResponse + + public fun handlePostTestRequest(req: PostTestRequest): PostTestResponse { + val response = PostTestResponse.Builder() + val input1 = req.input1 ?: "" + val input2 = req.input2 ?: 0 + response.output1 = input1 + " world!" + response.output2 = input2 + 1 + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/PostTestOperation.kt", postTestOperation) + + val errorTestOperation = """ + package $packageName.operations + + import $packageName.model.ErrorTestRequest + import $packageName.model.ErrorTestResponse + + public fun handleErrorTestRequest(req: ErrorTestRequest): ErrorTestResponse { + val variable: String? = null + val error = variable!!.length + return ErrorTestResponse.Builder().build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/ErrorTestOperation.kt", errorTestOperation) + + val httpErrorTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpErrorTestRequest + import $packageName.model.HttpErrorTestResponse + import $packageName.model.HttpError + + public fun handleHttpErrorTestRequest(req: HttpErrorTestRequest): HttpErrorTestResponse { + + val error = HttpError.Builder() + error.msg = "this is an error message" + error.num = 444 + throw error.build() + + return HttpErrorTestResponse.Builder().build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpErrorTestOperation.kt", httpErrorTestOperation) + + val bearerValidation = """ + package $packageName.auth + + public fun bearerValidation(token: String): UserPrincipal? { + if (token == "correctToken") return UserPrincipal("Authenticated User") else return null + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/auth/Validation.kt", bearerValidation) + + val settingGradleKts = """ + rootProject.name = "service-cbor-test" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/ConstraintsServiceTestGenerator.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/ConstraintsServiceTestGenerator.kt new file mode 100644 index 0000000000..c96dd83ca6 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/ConstraintsServiceTestGenerator.kt @@ -0,0 +1,78 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun generateServiceConstraintsTest() { + val modelPath: Path = Paths.get("model", "service-constraints-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "ServiceConstraintsTest" + val packageName = "com.constraints" + val outputDirName = "service-constraints-test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", outputDirName).also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val bearerValidation = """ + package $packageName.auth + + public fun bearerValidation(token: String): UserPrincipal? { + if (token == "correctToken") return UserPrincipal("Authenticated User") else return null + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/auth/Validation.kt", bearerValidation) + + val settingGradleKts = """ + rootProject.name = "service-constraints-test" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt index ac7f2f1e58..6ffbbea384 100644 --- a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt @@ -1,179 +1,7 @@ package com.test -import software.amazon.smithy.build.FileManifest -import software.amazon.smithy.build.PluginContext -import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin -import software.amazon.smithy.model.loader.ModelAssembler -import software.amazon.smithy.model.node.Node -import software.amazon.smithy.model.node.ObjectNode -import java.nio.file.Files -import java.nio.file.Path -import java.nio.file.Paths - internal fun main() { - generateServiceGeneratorTest() + generateJsonServiceTest() + generateCborServiceTest() generateServiceConstraintsTest() } - -internal fun generateServiceGeneratorTest() { - val modelPath: Path = Paths.get("model", "service-generator-test.smithy") - val defaultModel = ModelAssembler() - .discoverModels() - .addImport(modelPath) - .assemble() - .unwrap() - val serviceName = "ServiceGeneratorTest" - val packageName = "com.test" - - val packagePath = packageName.replace('.', '/') - - val settings: ObjectNode = ObjectNode.builder() - .withMember("service", Node.from("$packageName#$serviceName")) - .withMember( - "package", - ObjectNode.builder() - .withMember("name", Node.from(packageName)) - .withMember("version", Node.from("1.0.0")) - .build(), - ) - .withMember( - "build", - ObjectNode.builder() - .withMember("rootProject", true) - .withMember("generateServiceProject", true) - .withMember( - "optInAnnotations", - Node.arrayNode( - Node.from("aws.smithy.kotlin.runtime.InternalApi"), - Node.from("kotlinx.serialization.ExperimentalSerializationApi"), - ), - ) - .build(), - ) - .withMember( - "serviceStub", - ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), - ) - .build() - val outputDir: Path = Paths.get("build", "service-generator-test").also { Files.createDirectories(it) } - val manifest: FileManifest = FileManifest.create(outputDir) - - val context: PluginContext = PluginContext.builder() - .model(defaultModel) - .fileManifest(manifest) - .settings(settings) - .build() - KotlinCodegenPlugin().execute(context) - - val postTestOperation = """ - package $packageName.operations - - import $packageName.model.PostTestRequest - import $packageName.model.PostTestResponse - - public fun handlePostTestRequest(req: PostTestRequest): PostTestResponse { - val response = PostTestResponse.Builder() - val input1 = req.input1 ?: "" - val input2 = req.input2 ?: 0 - response.output1 = input1 + " world!" - response.output2 = input2 + 1 - return response.build() - } - """.trimIndent() - manifest.writeFile("src/main/kotlin/$packagePath/operations/PostTestOperation.kt", postTestOperation) - - val errorTestOperation = """ - package $packageName.operations - - import $packageName.model.ErrorTestRequest - import $packageName.model.ErrorTestResponse - - public fun handleErrorTestRequest(req: ErrorTestRequest): ErrorTestResponse { - val variable: String? = null - val error = variable!!.length - return ErrorTestResponse.Builder().build() - } - """.trimIndent() - manifest.writeFile("src/main/kotlin/$packagePath/operations/ErrorTestOperation.kt", errorTestOperation) - - val bearerValidation = """ - package $packageName.auth - - public fun bearerValidation(token: String): UserPrincipal? { - if (token == "correctToken") return UserPrincipal("Authenticated User") else return null - } - """.trimIndent() - manifest.writeFile("src/main/kotlin/$packagePath/auth/Validation.kt", bearerValidation) - - val settingGradleKts = """ - rootProject.name = "service-generator-test" - includeBuild("../../../../../") - """.trimIndent() - manifest.writeFile("settings.gradle.kts", settingGradleKts) -} - -internal fun generateServiceConstraintsTest() { - val modelPath: Path = Paths.get("model", "service-constraints-test.smithy") - val defaultModel = ModelAssembler() - .discoverModels() - .addImport(modelPath) - .assemble() - .unwrap() - val serviceName = "ServiceConstraintsTest" - val packageName = "com.test" - - val packagePath = packageName.replace('.', '/') - - val settings: ObjectNode = ObjectNode.builder() - .withMember("service", Node.from("$packageName#$serviceName")) - .withMember( - "package", - ObjectNode.builder() - .withMember("name", Node.from(packageName)) - .withMember("version", Node.from("1.0.0")) - .build(), - ) - .withMember( - "build", - ObjectNode.builder() - .withMember("rootProject", true) - .withMember("generateServiceProject", true) - .withMember( - "optInAnnotations", - Node.arrayNode( - Node.from("aws.smithy.kotlin.runtime.InternalApi"), - Node.from("kotlinx.serialization.ExperimentalSerializationApi"), - ), - ) - .build(), - ) - .withMember( - "serviceStub", - ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), - ) - .build() - val outputDir: Path = Paths.get("build", "service-constraints-test").also { Files.createDirectories(it) } - val manifest: FileManifest = FileManifest.create(outputDir) - - val context: PluginContext = PluginContext.builder() - .model(defaultModel) - .fileManifest(manifest) - .settings(settings) - .build() - KotlinCodegenPlugin().execute(context) - - val bearerValidation = """ - package $packageName.auth - - public fun bearerValidation(token: String): UserPrincipal? { - if (token == "correctToken") return UserPrincipal("Authenticated User") else return null - } - """.trimIndent() - manifest.writeFile("src/main/kotlin/$packagePath/auth/Validation.kt", bearerValidation) - - val settingGradleKts = """ - rootProject.name = "service-constraints-test" - includeBuild("../../../../../") - """.trimIndent() - manifest.writeFile("settings.gradle.kts", settingGradleKts) -} diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/JsonServiceTestGenerator.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/JsonServiceTestGenerator.kt new file mode 100644 index 0000000000..b4cb9ec4dc --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/JsonServiceTestGenerator.kt @@ -0,0 +1,204 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun generateJsonServiceTest() { + val modelPath: Path = Paths.get("model", "service-json-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "JsonServiceTest" + val packageName = "com.json" + val outputDirName = "service-json-test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", outputDirName).also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val httpHeaderTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpHeaderTestRequest + import $packageName.model.HttpHeaderTestResponse + + public fun handleHttpHeaderTestRequest(req: HttpHeaderTestRequest): HttpHeaderTestResponse { + val response = HttpHeaderTestResponse.Builder() + response.header = req.headers?.get("hhh") + response.headers = mapOf("hhh" to (req.header ?: "")) + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpHeaderTestOperation.kt", httpHeaderTestOperation) + + val httpLabelTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpLabelTestRequest + import $packageName.model.HttpLabelTestResponse + + public fun handleHttpLabelTestRequest(req: HttpLabelTestRequest): HttpLabelTestResponse { + val response = HttpLabelTestResponse.Builder() + response.output = req.foo + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpLabelTestOperation.kt", httpLabelTestOperation) + + val httpQueryTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpQueryTestRequest + import $packageName.model.HttpQueryTestResponse + + public fun handleHttpQueryTestRequest(req: HttpQueryTestRequest): HttpQueryTestResponse { + val response = HttpQueryTestResponse.Builder() + response.output = req.query.toString() + (req.params?.get("qqq") ?: "") + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpQueryTestOperation.kt", httpQueryTestOperation) + + val httpStringPayloadTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpStringPayloadTestRequest + import $packageName.model.HttpStringPayloadTestResponse + + public fun handleHttpStringPayloadTestRequest(req: HttpStringPayloadTestRequest): HttpStringPayloadTestResponse { + val response = HttpStringPayloadTestResponse.Builder() + response.content = req.content + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpStringPayloadTestOperation.kt", httpStringPayloadTestOperation) + + val httpStructurePayloadTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpStructurePayloadTestRequest + import $packageName.model.HttpStructurePayloadTestResponse + import $packageName.model.HttpStructurePayloadTestStructure + + public fun handleHttpStructurePayloadTestRequest(req: HttpStructurePayloadTestRequest): HttpStructurePayloadTestResponse { + val response = HttpStructurePayloadTestResponse.Builder() + val content = HttpStructurePayloadTestStructure.Builder() + content.content1 = req.content?.content1 + content.content2 = req.content?.content2 + content.content3 = req.content?.content3 + response.content = content.build() + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpStructurePayloadTestOperation.kt", httpStructurePayloadTestOperation) + + val timestampTestOperation = """ + package $packageName.operations + + import $packageName.model.TimestampTestRequest + import $packageName.model.TimestampTestResponse + + public fun handleTimestampTestRequest(req: TimestampTestRequest): TimestampTestResponse { + val response = TimestampTestResponse.Builder() + response.default = req.default + response.dateTime = req.dateTime + response.httpDate = req.httpDate + response.epochSeconds = req.epochSeconds + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/TimestampTestOperation.kt", timestampTestOperation) + + val jsonNameTestOperation = """ + package $packageName.operations + + import $packageName.model.JsonNameTestRequest + import $packageName.model.JsonNameTestResponse + + public fun handleJsonNameTestRequest(req: JsonNameTestRequest): JsonNameTestResponse { + val response = JsonNameTestResponse.Builder() + response.content = req.content + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/JsonNameTestOperation.kt", jsonNameTestOperation) + + val httpErrorTestOperation = """ + package $packageName.operations + + import $packageName.model.HttpErrorTestRequest + import $packageName.model.HttpErrorTestResponse + import $packageName.model.HttpError + + public fun handleHttpErrorTestRequest(req: HttpErrorTestRequest): HttpErrorTestResponse { + + val error = HttpError.Builder() + error.msg = "this is an error message" + error.num = 444 + throw error.build() + + return HttpErrorTestResponse.Builder().build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/HttpErrorTestOperation.kt", httpErrorTestOperation) + + val bearerValidation = """ + package $packageName.auth + + public fun bearerValidation(token: String): UserPrincipal? { + if (token == "correctToken") return UserPrincipal("Authenticated User") else return null + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/auth/Validation.kt", bearerValidation) + + val settingGradleKts = """ + rootProject.name = "service-json-test" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceGeneratorTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/CborServiceTest.kt similarity index 93% rename from tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceGeneratorTest.kt rename to tests/codegen/service-codegen-tests/src/test/kotlin/com/test/CborServiceTest.kt index ecd65e5efe..629e23b480 100644 --- a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceGeneratorTest.kt +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/CborServiceTest.kt @@ -19,17 +19,17 @@ import kotlin.test.assertIs import kotlin.test.assertTrue @TestInstance(TestInstance.Lifecycle.PER_CLASS) -class ServiceGeneratorTest { +class CborServiceTest { val closeGracePeriodMillis: Long = 5_000L val closeTimeoutMillis: Long = 1_000L val requestBodyLimit: Long = 10L * 1024 * 1024 val port: Int = ServerSocket(0).use { it.localPort } - val portListnerTimeout = 10L + val portListnerTimeout = 180L val baseUrl = "http://localhost:$port" - val projectDir: Path = Paths.get("build/service-generator-test") + val projectDir: Path = Paths.get("build/service-cbor-test") private lateinit var proc: Process @@ -299,7 +299,7 @@ class ServiceGeneratorTest { response.body(), ) assertEquals(400, body.code) - assertEquals("Malformed CBOR input", body.message) + assertEquals("Unexpected EOF: expected 109 more bytes; consumed: 14", body.message) } @Test @@ -380,4 +380,28 @@ class ServiceGeneratorTest { assertEquals(413, body.code) assertEquals("Request is larger than the limit of 10485760 bytes", body.message) } + + @Test + fun `checks http error`() { + val cbor = Cbor { } + + val response = sendRequest( + "$baseUrl/http-error", + "POST", + null, + "application/cbor", + "application/cbor", + "correctToken", + ) + assertIs>(response) + + assertEquals(456, response.statusCode(), "Expected 456") + val body = cbor.decodeFromByteArray( + HttpError.serializer(), + response.body(), + ) + + assertEquals(444, body.num) + assertEquals("this is an error message", body.msg) + } } diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/JsonServiceTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/JsonServiceTest.kt new file mode 100644 index 0000000000..2c70f96bbd --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/JsonServiceTest.kt @@ -0,0 +1,229 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.json.Json +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import java.net.ServerSocket +import java.net.http.HttpResponse +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class JsonServiceTest { + val closeGracePeriodMillis: Long = 5_000L + val closeTimeoutMillis: Long = 1_000L + val requestBodyLimit: Long = 10L * 1024 * 1024 + val port: Int = ServerSocket(0).use { it.localPort } + + val portListnerTimeout = 60L + + val baseUrl = "http://localhost:$port" + + val projectDir: Path = Paths.get("build/service-json-test") + + private lateinit var proc: Process + + @BeforeAll + fun boot() { + proc = startService("netty", port, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit, projectDir) + val ready = waitForPort(port, portListnerTimeout) + assertTrue(ready, "Service did not start within $portListnerTimeout s") + } + + @AfterAll + fun shutdown() = cleanupService(proc) + + @Test + fun `checks http-header`() { + val response = sendRequest( + "$baseUrl/http-header", + "POST", + null, + "application/json", + "application/json", + "correctToken", + mapOf("X-Request-Header" to "header", "X-Request-Headers-hhh" to "headers"), + ) + assertIs>(response) + + assertEquals(201, response.statusCode(), "Expected 201") + + assertEquals("headers", response.headers().firstValue("X-Response-Header").get()) + assertEquals("header", response.headers().firstValue("X-Response-Headers-hhh").get()) + } + + @Test + fun `checks http-label`() { + val response = sendRequest( + "$baseUrl/http-label/labelValue", + "GET", + null, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + + assertEquals(200, response.statusCode(), "Expected 200") + val body = Json.decodeFromString( + HttpLabelTestOutputResponse.serializer(), + response.body(), + ) + assertEquals("labelValue", body.output) + } + + @Test + fun `checks http-query`() { + val response = sendRequest( + "$baseUrl/http-query?query=123&qqq=kotlin", + "DELETE", + null, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + + assertEquals(200, response.statusCode(), "Expected 200") + val body = Json.decodeFromString( + HttpQueryTestOutputResponse.serializer(), + response.body(), + ) + assertEquals("123kotlin", body.output) + } + + @Test + fun `checks http-payload string`() { + val response = sendRequest( + "$baseUrl/http-payload/string", + "POST", + "This is the entire content", + "text/plain", + "text/plain", + "correctToken", + ) + assertIs>(response) + + assertEquals(201, response.statusCode(), "Expected 201") + assertEquals("This is the entire content", response.body()) + } + + @Test + fun `checks http-payload structure`() { + val requestJson = Json.encodeToJsonElement( + HttpStructurePayloadTestStructure.serializer(), + HttpStructurePayloadTestStructure( + "content", + 123, + 456.toFloat(), + ), + ) + + val response = sendRequest( + "$baseUrl/http-payload/structure", + "POST", + requestJson, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + val body = Json.decodeFromString( + HttpStructurePayloadTestStructure.serializer(), + response.body(), + ) + assertEquals("content", body.content1) + assertEquals(123, body.content2) + assertEquals(456.toFloat(), body.content3) + } + + @Test + fun `checks timestamp`() { + val requestJson = Json.encodeToJsonElement( + TimestampTestRequestResponse.serializer(), + TimestampTestRequestResponse( + 1515531081.123, + "1985-04-12T23:20:50.520Z", + "Tue, 29 Apr 2014 18:30:38 GMT", + 1234567890.123, + ), + ) + + val response = sendRequest( + "$baseUrl/timestamp", + "POST", + requestJson, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + val body = Json.decodeFromString( + TimestampTestRequestResponse.serializer(), + response.body(), + ) + assertEquals(1515531081.123, body.default) + assertEquals("1985-04-12T23:20:50.520Z", body.dateTime) + assertEquals("Tue, 29 Apr 2014 18:30:38 GMT", body.httpDate) + assertEquals(1234567890.123, body.epochSeconds) + } + + @Test + fun `checks json name`() { + val requestJson = Json.encodeToJsonElement( + JsonNameTestRequest.serializer(), + JsonNameTestRequest("Hello Kotlin Team"), + ) + + val response = sendRequest( + "$baseUrl/json-name", + "POST", + requestJson, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + val body = Json.decodeFromString( + JsonNameTestResponse.serializer(), + response.body(), + ) + assertEquals("Hello Kotlin Team", body.responseName) + } + + @Test + fun `checks http error`() { + val response = sendRequest( + "$baseUrl/http-error", + "POST", + null, + "application/json", + "application/json", + "correctToken", + ) + assertIs>(response) + + assertEquals(456, response.statusCode(), "Expected 456") + val body = Json.decodeFromString( + HttpError.serializer(), + response.body(), + ) + + assertEquals(444, body.num) + assertEquals("this is an error message", body.msg) + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceConstraintsTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceConstraintsTest.kt index 8d4c17826c..763e9c7a48 100644 --- a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceConstraintsTest.kt +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceConstraintsTest.kt @@ -26,7 +26,7 @@ class ServiceConstraintsTest { val requestBodyLimit: Long = 10L * 1024 * 1024 val port: Int = ServerSocket(0).use { it.localPort } - val portListenerTimeout = 180L + val portListenerTimeout = 60L val baseUrl = "http://localhost:$port" diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceDataClasses.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceDataClasses.kt index dc289fb3be..6ea3bb88ac 100644 --- a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceDataClasses.kt +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceDataClasses.kt @@ -20,6 +20,9 @@ data class AuthTestRequest(val input1: String) @Serializable data class ErrorTestRequest(val input1: String) +@Serializable +data class HttpError(val msg: String, val num: Int) + @Serializable data class RequiredConstraintTestRequest(val requiredInput: String? = null, val notRequiredInput: String? = null) @@ -44,3 +47,30 @@ data class NestedUniqueItemsConstraintTestRequest(val nestedUniqueItemsListInput @Serializable data class DoubleNestedUniqueItemsConstraintTestRequest(val doubleNestedUniqueItemsListInput: List>>) + +@Serializable +data class HttpLabelTestOutputResponse(val output: String) + +@Serializable +data class HttpQueryTestOutputResponse(val output: String) + +@Serializable +data class HttpStructurePayloadTestStructure( + val content1: String, + val content2: Int, + val content3: Float, +) + +@Serializable +data class TimestampTestRequestResponse( + val default: Double, + val dateTime: String, + val httpDate: String, + val epochSeconds: Double, +) + +@Serializable +data class JsonNameTestRequest(val requestName: String) + +@Serializable +data class JsonNameTestResponse(val responseName: String) diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceEngineFactoryTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceEngineFactoryTest.kt index 4d6e8dc841..46d71ae756 100644 --- a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceEngineFactoryTest.kt +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceEngineFactoryTest.kt @@ -19,8 +19,8 @@ class ServiceEngineFactoryTest { val gracefulWindow = closeTimeoutMillis + closeGracePeriodMillis val requestBodyLimit: Long = 10L * 1024 * 1024 - val portListnerTimeout = 180L - val projectDir: Path = Paths.get("build/service-generator-test") + val portListnerTimeout = 60L + val projectDir: Path = Paths.get("build/service-cbor-test") @Test fun `checks service with netty engine`() { diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt index 548545b679..ffa3a5275b 100644 --- a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt @@ -12,10 +12,10 @@ import kotlin.test.Test import kotlin.test.assertTrue class ServiceFileTest { - val packageName = "com.test" + val packageName = "com.cbor" val packagePath = packageName.replace('.', '/') - val projectDir: Path = Paths.get("build/service-generator-test") + val projectDir: Path = Paths.get("build/service-cbor-test") @Test fun `generates service and all necessary files`() { diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/utils.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/utils.kt index 88e4a7cd0c..87954d03e1 100644 --- a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/utils.kt +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/utils.kt @@ -1,5 +1,6 @@ package com.test +import kotlinx.serialization.json.JsonElement import org.gradle.testkit.runner.GradleRunner import java.io.IOException import java.net.Socket @@ -100,6 +101,7 @@ internal fun sendRequest( contentType: String? = null, acceptType: String? = null, bearerToken: String? = null, + headers: Map = emptyMap(), ): HttpResponse<*> { val client = HttpClient.newHttpClient() @@ -107,18 +109,21 @@ internal fun sendRequest( null -> HttpRequest.BodyPublishers.noBody() is ByteArray -> HttpRequest.BodyPublishers.ofByteArray(data) is String -> HttpRequest.BodyPublishers.ofString(data) + is JsonElement -> HttpRequest.BodyPublishers.ofString(data.toString()) else -> throw IllegalArgumentException( "Unsupported body type: ${data::class.qualifiedName}", ) } - val request = HttpRequest.newBuilder() + val builder = HttpRequest.newBuilder() .uri(URI.create(url)) - .apply { - contentType?.let { header("Content-Type", it) } - acceptType?.let { header("Accept", it) } - bearerToken?.let { header("Authorization", "Bearer $it") } - } + + contentType?.let { builder.header("Content-Type", it) } + acceptType ?.let { builder.header("Accept", it) } + bearerToken?.let { builder.header("Authorization", "Bearer $it") } + headers.forEach { (name, value) -> builder.header(name, value) } + + val request = builder .method(method, bodyPublisher) .build()