Skip to content

Commit 65ca841

Browse files
luigi617luigi
andauthored
Service sdk constraints (#1338)
Co-authored-by: luigi <[email protected]>
1 parent 7806f7d commit 65ca841

20 files changed

+1401
-211
lines changed

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import software.amazon.smithy.kotlin.codegen.core.withBlock
1010
import software.amazon.smithy.kotlin.codegen.core.withInlineBlock
1111
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1212
import software.amazon.smithy.kotlin.codegen.model.getTrait
13+
import software.amazon.smithy.kotlin.codegen.service.contraints.ConstraintGenerator
14+
import software.amazon.smithy.kotlin.codegen.service.contraints.ConstraintUtilsGenerator
1315
import software.amazon.smithy.model.shapes.OperationShape
1416
import software.amazon.smithy.model.traits.AuthTrait
1517
import software.amazon.smithy.model.traits.HttpBearerAuthTrait
@@ -93,7 +95,7 @@ internal class KtorStubGenerator(
9395
}
9496

9597
private fun renderLogging() {
96-
delegator.useFileWriter("Logging.kt", "${ctx.settings.pkg.name}.utils") { writer ->
98+
delegator.useFileWriter("Logging.kt", "$pkgName.utils") { writer ->
9799

98100
writer.withBlock("internal fun #T.configureLogging() {", "}", RuntimeTypes.KtorServerCore.Application) {
99101
withBlock(
@@ -167,20 +169,20 @@ internal class KtorStubGenerator(
167169

168170
// Generates `Authentication.kt` with Authenticator interface + configureSecurity().
169171
override fun renderAuthModule() {
170-
delegator.useFileWriter("UserPrincipal.kt", "${ctx.settings.pkg.name}.auth") { writer ->
172+
delegator.useFileWriter("UserPrincipal.kt", "$pkgName.auth") { writer ->
171173
writer.withBlock("public data class UserPrincipal(", ")") {
172174
write("val user: String")
173175
}
174176
}
175177

176-
delegator.useFileWriter("Validation.kt", "${ctx.settings.pkg.name}.auth") { writer ->
178+
delegator.useFileWriter("Validation.kt", "$pkgName.auth") { writer ->
177179
writer.withBlock("public fun bearerValidation(token: String): UserPrincipal? {", "}") {
178180
write("// TODO: implement me")
179181
write("if (true) return UserPrincipal(#S) else return null", "Authenticated User")
180182
}
181183
}
182184

183-
delegator.useFileWriter("Authentication.kt", "${ctx.settings.pkg.name}.auth") { writer ->
185+
delegator.useFileWriter("Authentication.kt", "$pkgName.auth") { writer ->
184186
writer.withBlock("internal fun #T.configureAuthentication() {", "}", RuntimeTypes.KtorServerCore.Application) {
185187
write("")
186188
withBlock(
@@ -201,20 +203,23 @@ internal class KtorStubGenerator(
201203

202204
// For every operation request structure, create a `validate()` function file.
203205
override fun renderConstraintValidators() {
206+
ConstraintUtilsGenerator(ctx, delegator).render()
207+
operations.forEach { operation -> ConstraintGenerator(ctx, operation, delegator).render() }
204208
}
205209

206210
// Writes `Routing.kt` that maps Smithy operations → Ktor routes.
207211
override fun renderRouting() {
208212
val contentType = ContentType.fromServiceShape(serviceShape)
209213

210-
delegator.useFileWriter("Routing.kt", ctx.settings.pkg.name) { writer ->
214+
delegator.useFileWriter("Routing.kt", pkgName) { writer ->
211215

212216
operations.forEach { shape ->
213-
writer.addImport("${ctx.settings.pkg.name}.serde", "${shape.id.name}OperationDeserializer")
214-
writer.addImport("${ctx.settings.pkg.name}.serde", "${shape.id.name}OperationSerializer")
215-
writer.addImport("${ctx.settings.pkg.name}.model", "${shape.id.name}Request")
216-
writer.addImport("${ctx.settings.pkg.name}.model", "${shape.id.name}Response")
217-
writer.addImport("${ctx.settings.pkg.name}.operations", "handle${shape.id.name}Request")
217+
writer.addImport("$pkgName.serde", "${shape.id.name}OperationDeserializer")
218+
writer.addImport("$pkgName.serde", "${shape.id.name}OperationSerializer")
219+
writer.addImport("$pkgName.constraints", "check${shape.id.name}RequestConstraint")
220+
writer.addImport("$pkgName.model", "${shape.id.name}Request")
221+
writer.addImport("$pkgName.model", "${shape.id.name}Response")
222+
writer.addImport("$pkgName.operations", "handle${shape.id.name}Request")
218223
}
219224

220225
writer.withBlock("internal fun #T.configureRouting(): Unit {", "}", RuntimeTypes.KtorServerCore.Application) {
@@ -271,6 +276,11 @@ internal class KtorStubGenerator(
271276
"Malformed CBOR input",
272277
)
273278
}
279+
write(
280+
"try { check${shape.id.name}RequestConstraint(requestObj) } catch (ex: Exception) { throw #T(ex?.message ?: #S, ex) }",
281+
RuntimeTypes.KtorServerCore.BadRequestException,
282+
"Error while validating constraints",
283+
)
274284
write("val responseObj = handle${shape.id.name}Request(requestObj)")
275285
write("val serializer = ${shape.id.name}OperationSerializer()")
276286
withBlock(
@@ -356,7 +366,7 @@ internal class KtorStubGenerator(
356366
}
357367

358368
private fun renderErrorHandler() {
359-
delegator.useFileWriter("ErrorHandler.kt", "${ctx.settings.pkg.name}.plugins") { writer ->
369+
delegator.useFileWriter("ErrorHandler.kt", "$pkgName.plugins") { writer ->
360370
writer.write("@#T", RuntimeTypes.KotlinxCborSerde.Serializable)
361371
.write("private data class ErrorPayload(val code: Int, val message: String)")
362372
.write("")
@@ -386,7 +396,7 @@ internal class KtorStubGenerator(
386396
write("val acceptsCbor = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestAcceptItems, "application/cbor")
387397
write("val acceptsJson = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestAcceptItems, "application/json")
388398
write("")
389-
write("val log = #T.getLogger(#S)", RuntimeTypes.KtorLoggingSlf4j.LoggerFactory, ctx.settings.pkg.name)
399+
write("val log = #T.getLogger(#S)", RuntimeTypes.KtorLoggingSlf4j.LoggerFactory, pkgName)
390400
write("log.info(#S)", "Route Error Message: \${envelope.msg}")
391401
write("")
392402
withBlock("when {", "}") {
@@ -427,6 +437,16 @@ internal class KtorStubGenerator(
427437
write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )")
428438
}
429439
write("")
440+
withBlock("status(#T.NotFound) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) {
441+
write("val message = #S", "Resource not found")
442+
write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )")
443+
}
444+
write("")
445+
withBlock("status(#T.MethodNotAllowed) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) {
446+
write("val message = #S", "Method not allowed for this resource")
447+
write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )")
448+
}
449+
write("")
430450
withBlock("#T<Throwable> { call, cause ->", "}", RuntimeTypes.KtorServerStatusPage.exception) {
431451
withBlock("val status = when (cause) {", "}") {
432452
write(
@@ -456,7 +476,7 @@ internal class KtorStubGenerator(
456476
}
457477

458478
private fun renderContentTypeGuard() {
459-
delegator.useFileWriter("ContentTypeGuard.kt", "${ctx.settings.pkg.name}.plugins") { writer ->
479+
delegator.useFileWriter("ContentTypeGuard.kt", "$pkgName.plugins") { writer ->
460480

461481
writer.withBlock("private fun #T.hasBody(): Boolean {", "}", RuntimeTypes.KtorServerRouting.requestApplicationRequest) {
462482
write(
@@ -569,9 +589,9 @@ internal class KtorStubGenerator(
569589
override fun renderPerOperationHandlers() {
570590
operations.forEach { shape ->
571591
val name = shape.id.name
572-
delegator.useFileWriter("${name}Operation.kt", "${ctx.settings.pkg.name}.operations") { writer ->
573-
writer.addImport("${ctx.settings.pkg.name}.model", "${shape.id.name}Request")
574-
writer.addImport("${ctx.settings.pkg.name}.model", "${shape.id.name}Response")
592+
delegator.useFileWriter("${name}Operation.kt", "$pkgName.operations") { writer ->
593+
writer.addImport("$pkgName.model", "${shape.id.name}Request")
594+
writer.addImport("$pkgName.model", "${shape.id.name}Response")
575595

576596
writer.withBlock("public fun handle${name}Request(req: ${name}Request): ${name}Response {", "}") {
577597
write("// TODO: implement me")

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,14 @@ class ServiceTypes(val pkgName: String) {
6767
name = "AcceptTypeGuard"
6868
namespace = "$pkgName.plugins"
6969
}
70+
71+
val sizeOf = buildSymbol {
72+
name = "sizeOf"
73+
namespace = "$pkgName.constraints"
74+
}
75+
76+
val hasAllUniqueElements = buildSymbol {
77+
name = "hasAllUniqueElements"
78+
namespace = "$pkgName.constraints"
79+
}
7080
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package software.amazon.smithy.kotlin.codegen.service.contraints
2+
3+
internal abstract class AbstractConstraintTraitGenerator {
4+
abstract fun render()
5+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package software.amazon.smithy.kotlin.codegen.service.contraints
2+
3+
import software.amazon.smithy.kotlin.codegen.core.GenerationContext
4+
import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator
5+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
6+
import software.amazon.smithy.kotlin.codegen.core.withBlock
7+
import software.amazon.smithy.model.shapes.MemberShape
8+
import software.amazon.smithy.model.shapes.OperationShape
9+
import software.amazon.smithy.model.shapes.StructureShape
10+
import software.amazon.smithy.model.traits.RequiredTrait
11+
import kotlin.collections.iterator
12+
13+
internal class ConstraintGenerator(
14+
val ctx: GenerationContext,
15+
val operation: OperationShape,
16+
val delegator: KotlinDelegator,
17+
) {
18+
val inputShape = ctx.model.expectShape(operation.input.get()) as StructureShape
19+
val inputMembers = inputShape.allMembers
20+
21+
val opName = operation.id.name
22+
val pkgName = ctx.settings.pkg.name
23+
24+
fun render() {
25+
renderRequestConstraintsValidation()
26+
}
27+
private fun generateConstraintValidations(prefix: String, memberShape: MemberShape, writer: KotlinWriter) {
28+
val targetShape = ctx.model.expectShape(memberShape.target)
29+
30+
val memberName = memberShape.memberName
31+
val memberAndTargetTraits = memberShape.allTraits + targetShape.allTraits
32+
33+
for (memberTrait in memberAndTargetTraits.values) {
34+
val traitGenerator = getTraitGeneratorFromTrait(prefix, memberName, memberTrait, pkgName, writer)
35+
if (memberTrait !is RequiredTrait) {
36+
writer.write("if ($prefix$memberName == null) { return }")
37+
}
38+
traitGenerator?.render()
39+
}
40+
41+
for (member in targetShape.allMembers) {
42+
val newMemberPrefix = "${targetShape.id.name}".replaceFirstChar { it.lowercase() }
43+
writer.withBlock("if ($prefix$memberName != null) {", "}") {
44+
withBlock("for ($newMemberPrefix${member.key} in $prefix$memberName) {", "}") {
45+
call { generateConstraintValidations(newMemberPrefix, member.value, writer) }
46+
}
47+
}
48+
}
49+
}
50+
51+
private fun renderRequestConstraintsValidation() {
52+
delegator.useFileWriter("${opName}RequestConstraints.kt", "$pkgName.constraints") { writer ->
53+
writer.addImport("$pkgName.model", "${operation.id.name}Request")
54+
55+
writer.withBlock("public fun check${opName}RequestConstraint(data: ${opName}Request) {", "}") {
56+
for (memberShape in inputMembers.values) {
57+
generateConstraintValidations("data.", memberShape, writer)
58+
}
59+
}
60+
}
61+
}
62+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package software.amazon.smithy.kotlin.codegen.service.contraints
2+
3+
import software.amazon.smithy.kotlin.codegen.core.GenerationContext
4+
import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator
5+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
6+
import software.amazon.smithy.kotlin.codegen.core.withBlock
7+
8+
internal class ConstraintUtilsGenerator(
9+
val ctx: GenerationContext,
10+
val delegator: KotlinDelegator,
11+
) {
12+
val pkgName = ctx.settings.pkg.name
13+
14+
fun render() {
15+
delegator.useFileWriter("utils.kt", "$pkgName.constraints") { writer ->
16+
renderLengthTraitUtils(writer)
17+
18+
writer.write("")
19+
renderUniqueItemsTraitUtils(writer)
20+
}
21+
}
22+
23+
private fun renderLengthTraitUtils(writer: KotlinWriter) {
24+
writer.withBlock("internal fun sizeOf(value: Any?): Long = when (value) {", "}") {
25+
write("is Collection<*> -> value.size.toLong()")
26+
write("is Array<*> -> value.size.toLong()")
27+
write("is Map<*, *> -> value.size.toLong()")
28+
write("is String -> value.codePointCount(0, value.length).toLong()")
29+
write("is ByteArray -> value.size.toLong()")
30+
withBlock("else -> {", "}") {
31+
write("val typeName = value?.javaClass?.simpleName ?: #S", "null")
32+
write("throw IllegalArgumentException( #S )", "sizeOf does not support \${typeName} type")
33+
}
34+
}
35+
}
36+
37+
private fun renderUniqueItemsTraitUtils(writer: KotlinWriter) {
38+
writer.withBlock("internal fun hasAllUniqueElements(elements: List<Any?>): Boolean {", "}") {
39+
withBlock("class Wrapped(private val v: Any?) {", "}") {
40+
withBlock("override fun equals(other: Any?): Boolean {", "}") {
41+
write("if (other !is Wrapped) return false")
42+
write("if (v?.javaClass != other.v?.javaClass) return false")
43+
withBlock("return when (v) {", "}") {
44+
write("null -> true")
45+
write("is String,")
46+
write("is Boolean,")
47+
write("is java.time.Instant,")
48+
write("is Number -> v == other.v")
49+
write("is ByteArray -> v.contentEquals(other.v as ByteArray)")
50+
withBlock("is List<*> -> {", "}") {
51+
write("val o = other.v as List<*>")
52+
write("v.size == o.size && v.indices.all { i -> Wrapped(v[i]) == Wrapped(o[i]) }")
53+
}
54+
withBlock("is Map<*, *> -> {", "}") {
55+
write("val o = other.v as Map<*, *>")
56+
write("v.size == o.size && v.all { (k, value) -> o.containsKey(k) && Wrapped(value) == Wrapped(o[k]) }")
57+
}
58+
write("else -> v == other.v")
59+
}
60+
}
61+
withBlock("override fun hashCode(): Int = when (v) {", "}") {
62+
write("null -> 0")
63+
write("is ByteArray -> v.contentHashCode()")
64+
write("is List<*> -> v.fold(1) { acc, e -> 31 * acc + Wrapped(e).hashCode() }")
65+
write("is Map<*, *> -> v.entries.fold(1) { acc, (k, e) -> 31 * acc + Wrapped(k).hashCode() xor Wrapped(e).hashCode() }")
66+
write("else -> v.hashCode()")
67+
}
68+
}
69+
write("")
70+
write("val seen = HashSet<Wrapped>(elements.size)")
71+
write("for (e in elements) if (!seen.add(Wrapped(e))) return false")
72+
write("return true")
73+
}
74+
}
75+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package software.amazon.smithy.kotlin.codegen.service.contraints
2+
3+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
4+
import software.amazon.smithy.kotlin.codegen.service.ServiceTypes
5+
import software.amazon.smithy.model.traits.LengthTrait
6+
7+
internal class LengthConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: LengthTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() {
8+
override fun render() {
9+
val min = trait.min.orElse(null)
10+
val max = trait.max.orElse(null)
11+
val member = "$memberPrefix$memberName"
12+
13+
if (max != null && min != null) {
14+
writer.write("require(#T($member) in $min..$max) { #S }", ServiceTypes(pkgName).sizeOf, "The size of `$memberName` must be between $min and $max (inclusive)")
15+
} else if (max != null) {
16+
writer.write("require(#T($member) <= $max) { #S }", ServiceTypes(pkgName).sizeOf, "The size of `$memberName` must be less than or equal to $max")
17+
} else {
18+
writer.write("require(#T($member) >= $min) { #S }", ServiceTypes(pkgName).sizeOf, "The size of `$memberName` must be greater than or equal to $min")
19+
}
20+
}
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package software.amazon.smithy.kotlin.codegen.service.contraints
2+
3+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
4+
import software.amazon.smithy.model.traits.PatternTrait
5+
6+
internal class PatternConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: PatternTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() {
7+
override fun render() {
8+
val member = "$memberPrefix$memberName"
9+
10+
writer.write("require(Regex(#S).containsMatchIn($member)) { #S }", trait.pattern.toString(), "Value `\${$member}` does not match required pattern: `${trait.pattern}`")
11+
}
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package software.amazon.smithy.kotlin.codegen.service.contraints
2+
3+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
4+
import software.amazon.smithy.model.traits.RangeTrait
5+
6+
internal class RangeConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: RangeTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() {
7+
override fun render() {
8+
val min = trait.min.orElse(null)
9+
val max = trait.max.orElse(null)
10+
val member = "$memberPrefix$memberName"
11+
12+
if (max != null && min != null) {
13+
writer.write("require($member in $min..$max) { #S }", "`$memberName` must be between $min and $max (inclusive)")
14+
} else if (max != null) {
15+
writer.write("require($member <= $max) { #S }", "`$memberName` must be less than or equal to $max")
16+
} else {
17+
writer.write("require($member >= $min) { #S }", "`$memberName` must be greater than or equal to $min")
18+
}
19+
}
20+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package software.amazon.smithy.kotlin.codegen.service.contraints
2+
3+
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
4+
import software.amazon.smithy.model.traits.RequiredTrait
5+
6+
internal class RequiredConstraintGenerator(val memberPrefix: String, val memberName: String, val trait: RequiredTrait, val pkgName: String, val writer: KotlinWriter) : AbstractConstraintTraitGenerator() {
7+
override fun render() {
8+
val member = "$memberPrefix$memberName"
9+
writer.write("require($member != null) { #S }", "`$memberName` must be provided")
10+
}
11+
}

0 commit comments

Comments
 (0)