diff --git a/build.gradle.kts b/build.gradle.kts index bdec55641a..8b7a33571e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -114,6 +114,7 @@ apiValidation { "nullability-tests", "paginator-tests", "waiter-tests", + "service-codegen-tests", "compile", "slf4j-1x-consumer", "slf4j-2x-consumer", diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt index 245623a1c3..ff09bc0d5b 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/RpcV2Cbor.kt @@ -109,6 +109,7 @@ class RpcV2Cbor : AwsHttpBindingProtocolGenerator() { resolver.requestBindings(op) } val httpPayload = bindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD } + if (httpPayload != null) { renderExplicitHttpPayloadSerializer(ctx, httpPayload, writer) } else { @@ -117,9 +118,15 @@ class RpcV2Cbor : AwsHttpBindingProtocolGenerator() { // delegate to the generate operation body serializer function val sdg = structuredDataSerializer(ctx) val opBodySerializerFn = sdg.operationSerializer(ctx, op, documentMembers) - writer.write("builder.body = #T(context, input)", opBodySerializerFn) + if (ctx.settings.build.generateServiceProject) { + writer.write("response = #T(context, input)", opBodySerializerFn) + } else { + writer.write("builder.body = #T(context, input)", opBodySerializerFn) + } + } + if (!ctx.settings.build.generateServiceProject) { + renderContentTypeHeader(ctx, op, writer, resolver) } - renderContentTypeHeader(ctx, op, writer, resolver) } /** diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt index f6cc25a550..5283bfa228 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinDependency.kt @@ -38,7 +38,7 @@ private fun getDefaultRuntimeVersion(): String { const val RUNTIME_GROUP: String = "aws.smithy.kotlin" val RUNTIME_VERSION: String = System.getProperty("smithy.kotlin.codegen.clientRuntimeVersion", getDefaultRuntimeVersion()) val KOTLIN_COMPILER_VERSION: String = System.getProperty("smithy.kotlin.codegen.kotlinCompilerVersion", "2.2.0") -val KTOR_VERSION: String = System.getProperty("smithy.kotlin.codegen.ktorVersion", "3.1.3") +val KTOR_VERSION: String = System.getProperty("smithy.kotlin.codegen.ktorVersion", "3.2.2") val SERIALIZATION_PLUGIN: String = System.getProperty("smithy.kotlin.codegen.SerializationPlugin", "2.0.20") val KOTLINX_VERSION: String = System.getProperty("smithy.kotlin.codegen.ktorKotlinxVersion", "1.9.0") val KTOR_LOGGING_BACKEND_VERSION: String = System.getProperty("smithy.kotlin.codegen.ktorLoggingBackendVersion", "1.4.14") @@ -144,8 +144,11 @@ data class KotlinDependency( // FIXME: version numbers should not be hardcoded, they should be setting dynamically based on the Gradle library versions val KTOR_SERVER_CORE = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server", "io.ktor", "ktor-server-core", KTOR_VERSION) val KTOR_SERVER_NETTY = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.netty", "io.ktor", "ktor-server-netty", KTOR_VERSION) + val KTOR_SERVER_CIO = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.cio", "io.ktor", "ktor-server-cio", KTOR_VERSION) + val KTOR_SERVER_JETTY_JAKARTA = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.jetty.jakarta", "io.ktor", "ktor-server-jetty-jakarta", KTOR_VERSION) val KTOR_SERVER_HTTP = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.http", "io.ktor", "ktor-http-jvm", KTOR_VERSION) val KTOR_SERVER_LOGGING = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins.calllogging", "io.ktor", "ktor-server-call-logging", KTOR_VERSION) + val KTOR_SERVER_BODY_LIMIT = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins", "io.ktor", "ktor-server-body-limit", KTOR_VERSION) val KTOR_LOGGING_SLF4J = KotlinDependency(GradleConfiguration.Implementation, "org.slf4j", "ch.qos.logback", "logback-classic", KTOR_LOGGING_BACKEND_VERSION) val KTOR_LOGGING_LOGBACK = KotlinDependency(GradleConfiguration.Implementation, "ch.qos.logback", "ch.qos.logback", "logback-classic", KTOR_LOGGING_BACKEND_VERSION) val KTOR_SERVER_STATUS_PAGE = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins.statuspages", "io.ktor", "ktor-server-status-pages-jvm", KTOR_VERSION) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt index b45a1c762f..6ed0d40286 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/RuntimeTypes.kt @@ -497,9 +497,12 @@ object RuntimeTypes { val embeddedServer = symbol("embeddedServer", "engine") val EmbeddedServerType = symbol("EmbeddedServer", "engine") val ApplicationEngineFactory = symbol("ApplicationEngineFactory", "engine") + val connector = symbol("connector", "engine") val Application = symbol("Application", "application") val ApplicationCallClass = symbol("ApplicationCall", "application") + val ApplicationStarting = symbol("ApplicationStarting", "application") + val ApplicationStarted = symbol("ApplicationStarted", "application") val ApplicationStopping = symbol("ApplicationStopping", "application") val ApplicationStopped = symbol("ApplicationStopped", "application") val ApplicationCreateRouteScopedPlugin = symbol("createRouteScopedPlugin", "application") @@ -527,20 +530,31 @@ object RuntimeTypes { val requestApplicationRequest = symbol("ApplicationRequest", "request") val requestContentLength = symbol("contentLength", "request") val requestContentType = symbol("contentType", "request") - val requestacceptItems = symbol("acceptItems", "request") + val requestAcceptItems = symbol("acceptItems", "request") - val responseText = symbol("respondText", "response") - val responseRespond = symbol("respond", "response") + val responseResponseText = symbol("respondText", "response") val responseRespondBytes = symbol("respondBytes", "response") } object KtorServerNetty : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_NETTY) { val Netty = symbol("Netty") + val Configuration = symbol("Configuration", "NettyApplicationEngine") + } + + object KtorServerCio : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_CIO) { + val CIO = symbol("CIO") + val Configuration = symbol("Configuration", "CIOApplicationEngine") + } + + object KtorServerJettyJakarta : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_JETTY_JAKARTA) { + val Jetty = symbol("Jetty") + val Configuration = symbol("Configuration", "JettyApplicationEngineBase") } object KtorServerHttp : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_HTTP) { val ContentType = symbol("ContentType") val HttpStatusCode = symbol("HttpStatusCode") + val parseAndSortHeader = symbol("parseAndSortHeader") val HttpHeaders = symbol("HttpHeaders") val Cbor = symbol("Cbor", "ContentType.Application") val Json = symbol("Json", "ContentType.Application") @@ -550,6 +564,11 @@ object RuntimeTypes { val CallLogging = symbol("CallLogging") } + object KtorServerBodyLimit : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_BODY_LIMIT) { + val RequestBodyLimit = symbol("RequestBodyLimit", "bodylimit") + val PayloadTooLargeException = symbol("PayloadTooLargeException") + } + object KtorLoggingSlf4j : RuntimeTypePackage(KotlinDependency.KTOR_LOGGING_SLF4J) { val Level = symbol("Level", "event") val LoggerFactory = symbol("LoggerFactory") diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt index 6d2384584e..5eef2a0ef7 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/lang/KotlinTypes.kt @@ -117,6 +117,7 @@ object KotlinTypes { val Duration = stdlibSymbol("Duration") val milliseconds = stdlibSymbol("milliseconds", "time.Duration.Companion") val minutes = stdlibSymbol("minutes", "time.Duration.Companion") + val seconds = stdlibSymbol("seconds", "time.Duration.Companion") } object Coroutines { diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt index f1d28bcb39..73453fb104 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/protocol/HttpBindingProtocolGenerator.kt @@ -136,6 +136,9 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { } override fun generateProtocolClient(ctx: ProtocolGenerator.GenerationContext) { + if (ctx.settings.build.generateServiceProject) { + require(protocolName == "smithyRpcv2cbor") { "service project accepts only Cbor protocol" } + } if (!ctx.settings.build.generateServiceProject) { val symbol = ctx.symbolProvider.toSymbol(ctx.service) ctx.delegator.useFileWriter("Default${symbol.name}.kt", ctx.settings.pkg.name) { writer -> @@ -182,27 +185,47 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { val serdeMeta = HttpSerdeMeta(op.isInputEventStream(ctx.model)) ctx.delegator.useSymbolWriter(serializerSymbol) { writer -> - writer - .addImport(operationSerializerSymbols) - .write("") - .openBlock("internal class #T: #T.#L<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerializer, serdeMeta.variantName, serializationSymbol) - .call { - val modifier = if (serdeMeta.isStreaming) "suspend " else "" - writer.openBlock( - "override #Lfun serialize(context: #T, input: #T): #T {", - modifier, - RuntimeTypes.Core.ExecutionContext, - serializationSymbol, - RuntimeTypes.Http.Request.HttpRequestBuilder, - ) - .write("val builder = #T()", RuntimeTypes.Http.Request.HttpRequestBuilder) - .call { - renderHttpSerialize(ctx, op, writer) - } - .write("return builder") - .closeBlock("}") - } - .closeBlock("}") + // FIXME: this works only for Cbor protocol now + if (ctx.settings.build.generateServiceProject) { + writer + .openBlock("internal class #T {", serializerSymbol) + .call { + writer.openBlock( + "public fun serialize(context: #T, input: #T): ByteArray {", + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + ) + .write("var response: Any") + .call { + renderSerializeHttpBody(ctx, op, writer) + } + .write("return response") + .closeBlock("}") + } + .closeBlock("}") + } else { + writer + .addImport(operationSerializerSymbols) + .write("") + .openBlock("internal class #T: #T.#L<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerializer, serdeMeta.variantName, serializationSymbol) + .call { + val modifier = if (serdeMeta.isStreaming) "suspend " else "" + writer.openBlock( + "override #Lfun serialize(context: #T, input: #T): #T {", + modifier, + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + RuntimeTypes.Http.Request.HttpRequestBuilder, + ) + .write("val builder = #T()", RuntimeTypes.Http.Request.HttpRequestBuilder) + .call { + renderHttpSerialize(ctx, op, writer) + } + .write("return builder") + .closeBlock("}") + } + .closeBlock("}") + } } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt index 13af7c2e04..ddb70b68dd 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/serde/CborSerializerGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.kotlin.codegen.core.KotlinWriter import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes import software.amazon.smithy.kotlin.codegen.core.withBlock +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.knowledge.SerdeIndex import software.amazon.smithy.kotlin.codegen.model.targetOrSelf import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator @@ -28,9 +29,20 @@ class CborSerializerGenerator( val serializationShape = serializationTarget.get().let { ctx.model.expectShape(it) } val serializationSymbol = ctx.symbolProvider.toSymbol(serializationShape) + val serializerResultSymbol = when { + ctx.settings.build.generateServiceProject -> KotlinTypes.ByteArray + else -> RuntimeTypes.Http.HttpBody + } return op.bodySerializer(ctx.settings) { writer -> addNestedDocumentSerializers(ctx, op, writer) - writer.withBlock("private fun #L(context: #T, input: #T): #T {", "}", op.bodySerializerName(), RuntimeTypes.Core.ExecutionContext, serializationSymbol, RuntimeTypes.Http.HttpBody) { + writer.withBlock( + "private fun #L(context: #T, input: #T): #T {", + "}", + op.bodySerializerName(), + RuntimeTypes.Core.ExecutionContext, + serializationSymbol, + serializerResultSymbol, + ) { call { renderSerializeOperationBody(ctx, op, members, writer) } @@ -52,7 +64,11 @@ class CborSerializerGenerator( val serializationshape = ctx.model.expectShape(serializationTarget.get()) writer.write("val serializer = #T()", RuntimeTypes.Serde.SerdeCbor.CborSerializer) renderSerializerBody(ctx, serializationshape, documentMembers, writer) - writer.write("return serializer.toHttpBody()") + if (ctx.settings.build.generateServiceProject) { + writer.write("return serializer.toByteArray()") + } else { + writer.write("return serializer.toHttpBody()") + } } private fun renderSerializerBody( diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt index 13cf84b233..f1a5d215ed 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/KtorStubGenerator.kt @@ -8,6 +8,7 @@ import software.amazon.smithy.kotlin.codegen.core.KotlinWriter import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes import software.amazon.smithy.kotlin.codegen.core.withBlock import software.amazon.smithy.kotlin.codegen.core.withInlineBlock +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.getTrait import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.traits.AuthTrait @@ -32,25 +33,52 @@ internal class KtorStubGenerator( ) : AbstractStubGenerator(ctx, delegator, fileManifest) { override fun renderServerFrameworkImplementation(writer: KotlinWriter) { - writer.addImport(RuntimeTypes.KtorServerNetty.Netty) - - writer.withBlock("internal class KtorServiceFramework () : ServiceFramework {", "}") { + writer.withBlock("internal fun #T.module(): Unit {", "}", RuntimeTypes.KtorServerCore.Application) { + withBlock("#T(#T) {", "}", RuntimeTypes.KtorServerCore.install, RuntimeTypes.KtorServerBodyLimit.RequestBodyLimit) { + write("bodyLimit { #T.requestBodyLimit }", ServiceTypes(pkgName).serviceFrameworkConfig) + } + write("#T()", ServiceTypes(pkgName).configureErrorHandling) + write("#T()", ServiceTypes(pkgName).configureAuthentication) + write("#T()", ServiceTypes(pkgName).configureRouting) + write("#T()", ServiceTypes(pkgName).configureLogging) + } + .write("") + writer.withBlock("internal class KtorServiceFramework() : ServiceFramework {", "}") { write("private var engine: #T<*, *>? = null", RuntimeTypes.KtorServerCore.EmbeddedServerType) + write("") + write("private val engineFactory = #T.engine.toEngineFactory()", ServiceTypes(pkgName).serviceFrameworkConfig) + write("") withBlock("override fun start() {", "}") { - withBlock( - "engine = #T(#T.engine.toEngineFactory(), port = #T.port) {", - "}.apply { start(wait = true) }", - RuntimeTypes.KtorServerCore.embeddedServer, - ServiceTypes(pkgName).serviceFrameworkConfig, - ServiceTypes(pkgName).serviceFrameworkConfig, - ) { - write("#T()", ServiceTypes(pkgName).configureErrorHandling) + withInlineBlock("engine = #T(", ")", RuntimeTypes.KtorServerCore.embeddedServer) { + write("engineFactory,") + withBlock("configure = {", "}") { + withBlock("#T {", "}", RuntimeTypes.KtorServerCore.connector) { + write("host = #S", "0.0.0.0") + write("port = #T.port", ServiceTypes(pkgName).serviceFrameworkConfig) + } + withBlock("when (this) {", "}") { + withBlock("is #T -> {", "}", RuntimeTypes.KtorServerNetty.Configuration) { + write("requestReadTimeoutSeconds = #T.requestReadTimeoutSeconds", ServiceTypes(pkgName).serviceFrameworkConfig) + write("responseWriteTimeoutSeconds = #T.responseWriteTimeoutSeconds", ServiceTypes(pkgName).serviceFrameworkConfig) + } - write("#T()", ServiceTypes(pkgName).configureAuthentication) - write("#T()", ServiceTypes(pkgName).configureRouting) - write("#T()", ServiceTypes(pkgName).configureLogging) + withBlock("is #T -> {", "}", RuntimeTypes.KtorServerCio.Configuration) { + write("connectionIdleTimeoutSeconds = #T.requestReadTimeoutSeconds", ServiceTypes(pkgName).serviceFrameworkConfig) + } + + withBlock("is #T -> {", "}", RuntimeTypes.KtorServerJettyJakarta.Configuration) { + write( + "idleTimeout = #T.requestReadTimeoutSeconds.#T", + ServiceTypes(pkgName).serviceFrameworkConfig, + KotlinTypes.Time.seconds, + ) + } + } + } } + write("{ #T() }", ServiceTypes(pkgName).module) + write("engine?.apply { start(wait = true) }") } write("") withBlock("final override fun close() {", "}") { @@ -60,7 +88,11 @@ internal class KtorStubGenerator( } } - override fun renderLogging() { + override fun renderUtils() { + renderLogging() + } + + private fun renderLogging() { delegator.useFileWriter("Logging.kt", "${ctx.settings.pkg.name}.utils") { writer -> writer.withBlock("internal fun #T.configureLogging() {", "}", RuntimeTypes.KtorServerCore.Application) { @@ -99,12 +131,20 @@ internal class KtorStubGenerator( } write("val log = #T.getLogger(#S)", RuntimeTypes.KtorLoggingSlf4j.LoggerFactory, ctx.settings.pkg.name) + withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStarting) { + write("log.info(#S)", "Server is starting...") + } + + withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStarted) { + write("log.info(#S)", "Server started – ready to accept requests.") + } + withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStopping) { - write("log.warn(#S)", "▶ Server is stopping – waiting for in-flight requests...") + write("log.warn(#S)", "Server is stopping – waiting for in-flight requests...") } withBlock("monitor.subscribe(#T) {", "}", RuntimeTypes.KtorServerCore.ApplicationStopped) { - write("log.info(#S)", "⏹ Server stopped cleanly.") + write("log.info(#S)", "Server stopped cleanly.") } } } @@ -180,9 +220,8 @@ internal class KtorStubGenerator( writer.withBlock("internal fun #T.configureRouting(): Unit {", "}", RuntimeTypes.KtorServerCore.Application) { withBlock("#T {", "}", RuntimeTypes.KtorServerRouting.routing) { withBlock("#T(#S) {", "}", RuntimeTypes.KtorServerRouting.get, "/") { - write(" #T.#T(#S)", RuntimeTypes.KtorServerCore.applicationCall, RuntimeTypes.KtorServerRouting.responseText, "hello world") + write(" #T.#T(#S)", RuntimeTypes.KtorServerCore.applicationCall, RuntimeTypes.KtorServerRouting.responseResponseText, "hello world") } - operations.filter { it.hasTrait(HttpTrait.ID) } .forEach { shape -> val httpTrait = shape.getTrait()!! @@ -207,6 +246,7 @@ internal class KtorStubGenerator( withBlock("#T (#S) {", "}", RuntimeTypes.KtorServerRouting.route, uri) { write("#T(#T) { $contentTypeGuard }", RuntimeTypes.KtorServerCore.install, ServiceTypes(pkgName).contentTypeGuard) + write("#T(#T) { $contentTypeGuard }", RuntimeTypes.KtorServerCore.install, ServiceTypes(pkgName).acceptTypeGuard) withBlock( "#W", "}", @@ -286,10 +326,7 @@ internal class KtorStubGenerator( RuntimeTypes.KtorServerCore.applicationCall, RuntimeTypes.KtorServerRouting.responseRespondBytes, ) { - write( - "bytes = response.body.#T() ?: ByteArray(0),", - RuntimeTypes.Http.readAll, - ) + write("bytes = response,") write("contentType = #T,", RuntimeTypes.KtorServerHttp.Cbor) write( "status = #T.fromValue($successCode),", @@ -300,13 +337,9 @@ internal class KtorStubGenerator( "#T.#T(", ")", RuntimeTypes.KtorServerCore.applicationCall, - RuntimeTypes.KtorServerRouting.responseRespond, + RuntimeTypes.KtorServerRouting.responseResponseText, ) { - write( - "bytes = response.body.#T()?.decodeToString() ?: #S,", - RuntimeTypes.Http.readAll, - "{}", - ) + write("text = response,") write("contentType = #T,", RuntimeTypes.KtorServerHttp.Json) write( "status = #T.fromValue($successCode),", @@ -319,6 +352,7 @@ internal class KtorStubGenerator( override fun renderPlugins() { renderErrorHandler() renderContentTypeGuard() + renderAcceptTypeGuard() } private fun renderErrorHandler() { @@ -349,9 +383,11 @@ internal class KtorStubGenerator( write("status: #T,", RuntimeTypes.KtorServerHttp.HttpStatusCode) } .withBlock("{", "}") { - write("val acceptsCbor = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestacceptItems, "application/cbor") - write("val acceptsJson = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestacceptItems, "application/json") - + write("val acceptsCbor = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestAcceptItems, "application/cbor") + write("val acceptsJson = request.#T().any { it.value == #S }", RuntimeTypes.KtorServerRouting.requestAcceptItems, "application/json") + write("") + write("val log = #T.getLogger(#S)", RuntimeTypes.KtorLoggingSlf4j.LoggerFactory, ctx.settings.pkg.name) + write("log.info(#S)", "Route Error Message: \${envelope.msg}") write("") withBlock("when {", "}") { withBlock("acceptsCbor -> {", "}") { @@ -362,14 +398,14 @@ internal class KtorStubGenerator( } } withBlock("acceptsJson -> {", "}") { - withBlock("#T(", ")", RuntimeTypes.KtorServerRouting.responseText) { + withBlock("#T(", ")", RuntimeTypes.KtorServerRouting.responseResponseText) { write("envelope.toJson(),") write("status = status,") write("contentType = #T", RuntimeTypes.KtorServerHttp.Json) } } withBlock("else -> {", "}") { - withBlock("#T(", ")", RuntimeTypes.KtorServerRouting.responseText) { + withBlock("#T(", ")", RuntimeTypes.KtorServerRouting.responseResponseText) { write("envelope.msg,") write("status = status") } @@ -388,7 +424,7 @@ internal class KtorStubGenerator( withBlock("status(#T.Unauthorized) { call, status ->", "}", RuntimeTypes.KtorServerHttp.HttpStatusCode) { write("val missing = call.request.headers[#S].isNullOrBlank()", "Authorization") write("val message = if (missing) #S else #S", "Missing bearer token", "Invalid or expired bearer token") - write("call.respondEnvelope( ErrorEnvelope(#T.Unauthorized.value, message), #T.Unauthorized )", RuntimeTypes.KtorServerHttp.HttpStatusCode, RuntimeTypes.KtorServerHttp.HttpStatusCode) + write("call.respondEnvelope( ErrorEnvelope(status.value, message), status )") } write("") withBlock("#T { call, cause ->", "}", RuntimeTypes.KtorServerStatusPage.exception) { @@ -402,6 +438,11 @@ internal class KtorStubGenerator( RuntimeTypes.KtorServerCore.BadRequestException, RuntimeTypes.KtorServerHttp.HttpStatusCode, ) + write( + "is #T -> #T.PayloadTooLarge", + RuntimeTypes.KtorServerBodyLimit.PayloadTooLargeException, + RuntimeTypes.KtorServerHttp.HttpStatusCode, + ) write("else -> #T.InternalServerError", RuntimeTypes.KtorServerHttp.HttpStatusCode) } write("") @@ -456,7 +497,67 @@ internal class KtorStubGenerator( withBlock("if (incoming == #T.Any || allowed.none { incoming.match(it) }) {", "}", RuntimeTypes.KtorServerHttp.ContentType) { withBlock("throw #T(", ")", ServiceTypes(pkgName).errorEnvelope) { write("#T.UnsupportedMediaType.value, ", RuntimeTypes.KtorServerHttp.HttpStatusCode) - write("#S", "Allowed Content-Type(s): \${allowed.joinToString()}") + write("#S", "Not acceptable Content‑Type found: '\${incoming}'. Accepted content types: \${allowed.joinToString()}") + } + } + } + } + } + } + + private fun renderAcceptTypeGuard() { + delegator.useFileWriter("AcceptTypeGuard.kt", "${ctx.settings.pkg.name}.plugins") { writer -> + + writer.withBlock( + "private fun #T.acceptedContentTypes(): List<#T> {", + "}", + RuntimeTypes.KtorServerRouting.requestApplicationRequest, + RuntimeTypes.KtorServerHttp.ContentType, + ) { + write("val raw = headers[#T.Accept] ?: return emptyList()", RuntimeTypes.KtorServerHttp.HttpHeaders) + write( + "return #T(raw).mapNotNull { it.value?.let(#T::parse) }", + RuntimeTypes.KtorServerHttp.parseAndSortHeader, + RuntimeTypes.KtorServerHttp.ContentType, + ) + } + + writer.withBlock("public class AcceptTypeGuardConfig {", "}") { + write("public var allow: List<#T> = emptyList()", RuntimeTypes.KtorServerHttp.ContentType) + write("") + withBlock("public fun json(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Json) + } + write("") + withBlock("public fun cbor(): Unit {", "}") { + write("allow = listOf(#T)", RuntimeTypes.KtorServerHttp.Cbor) + } + } + .write("") + + writer.withInlineBlock( + "public val AcceptTypeGuard: #T = #T(", + ")", + RuntimeTypes.KtorServerCore.ApplicationRouteScopedPlugin, + RuntimeTypes.KtorServerCore.ApplicationCreateRouteScopedPlugin, + ) { + write("name = #S,", "AcceptTypeGuard") + write("createConfiguration = ::AcceptTypeGuardConfig,") + } + .withBlock("{", "}") { + write("val allowed: List<#T> = pluginConfig.allow", RuntimeTypes.KtorServerHttp.ContentType) + write("require(allowed.isNotEmpty()) { #S }", "AcceptTypeGuard installed with empty allow-list.") + write("") + withBlock("onCall { call ->", "}") { + write("val accepted = call.request.acceptedContentTypes()") + write("if (accepted.isEmpty()) return@onCall") + write("") + write("val isOk = accepted.any { candidate -> allowed.any { candidate.match(it) } }") + + withBlock("if (!isOk) {", "}") { + withBlock("throw #T(", ")", ServiceTypes(pkgName).errorEnvelope) { + write("#T.NotAcceptable.value, ", RuntimeTypes.KtorServerHttp.HttpStatusCode) + write("#S", "Not acceptable Accept type found: '\${accepted}'. Accepted types: \${allowed.joinToString()}") } } } diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt index 4dfc9dd5d1..ad660d2244 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceStubGenerator.kt @@ -31,7 +31,7 @@ internal abstract class AbstractStubGenerator( renderServiceFrameworkConfig() renderServiceFramework() renderPlugins() - renderLogging() + renderUtils() renderAuthModule() renderConstraintValidators() renderPerOperationHandlers() @@ -68,21 +68,28 @@ internal abstract class AbstractStubGenerator( writer.write("") writer.withBlock("internal enum class ServiceEngine(val value: String) {", "}") { - write("NETTY(#S),", "netty") + write("NETTY_ENGINE(#S),", "netty") + write("CIO_ENGINE(#S),", "cio") + write("JETTY_JAKARTA_ENGINE(#S),", "jetty-jakarta") write(";") write("") write("override fun toString(): String = value") write("") withBlock("companion object {", "}") { - withBlock("fun fromValue(value: String): #T = when (value.lowercase()) {", "}", ServiceTypes(pkgName).serviceEngine) { - write("NETTY.value -> NETTY") - write("else -> throw IllegalArgumentException(#S)", "\$value is not a valid ServerFramework value, expected \$NETTY") + withBlock("fun fromValue(value: String): #T {", "}", ServiceTypes(pkgName).serviceEngine) { + write( + "return #T.entries.firstOrNull { it.value.equals(value.lowercase(), ignoreCase = true) } ?: throw IllegalArgumentException(#S)", + ServiceTypes(pkgName).serviceEngine, + "\$value is not a validContentType value, expected one of \${ServiceEngine.entries}", + ) } } write("") withBlock("fun toEngineFactory(): #T<*, *> {", "}", RuntimeTypes.KtorServerCore.ApplicationEngineFactory) { withBlock("return when(this) {", "}") { - write("NETTY -> #T", RuntimeTypes.KtorServerNetty.Netty) + write("NETTY_ENGINE -> #T as #T<*, *>", RuntimeTypes.KtorServerNetty.Netty, RuntimeTypes.KtorServerCore.ApplicationEngineFactory) + write("CIO_ENGINE -> #T as #T<*, *>", RuntimeTypes.KtorServerCio.CIO, RuntimeTypes.KtorServerCore.ApplicationEngineFactory) + write("JETTY_JAKARTA_ENGINE -> #T as #T<*, *>", RuntimeTypes.KtorServerJettyJakarta.Jetty, RuntimeTypes.KtorServerCore.ApplicationEngineFactory) } } } @@ -94,6 +101,9 @@ internal abstract class AbstractStubGenerator( withBlock("private data class Data(", ")") { write("val port: Int,") write("val engine: #T,", ServiceTypes(pkgName).serviceEngine) + write("val requestBodyLimit: Long,") + write("val requestReadTimeoutSeconds: Int,") + write("val responseWriteTimeoutSeconds: Int,") write("val closeGracePeriodMillis: Long,") write("val closeTimeoutMillis: Long,") write("val logLevel: #T,", ServiceTypes(pkgName).logLevel) @@ -101,6 +111,9 @@ internal abstract class AbstractStubGenerator( write("") write("val port: Int get() = backing?.port ?: notInitialised(#S)", "port") write("val engine: #T get() = backing?.engine ?: notInitialised(#S)", ServiceTypes(pkgName).serviceEngine, "engine") + write("val requestBodyLimit: Long get() = backing?.requestBodyLimit ?: notInitialised(#S)", "requestBodyLimit") + write("val requestReadTimeoutSeconds: Int get() = backing?.requestReadTimeoutSeconds ?: notInitialised(#S)", "requestReadTimeoutSeconds") + write("val responseWriteTimeoutSeconds: Int get() = backing?.responseWriteTimeoutSeconds ?: notInitialised(#S)", "responseWriteTimeoutSeconds") write("val closeGracePeriodMillis: Long get() = backing?.closeGracePeriodMillis ?: notInitialised(#S)", "closeGracePeriodMillis") write("val closeTimeoutMillis: Long get() = backing?.closeTimeoutMillis ?: notInitialised(#S)", "closeTimeoutMillis") write("val logLevel: #T get() = backing?.logLevel ?: notInitialised(#S)", ServiceTypes(pkgName).logLevel, "logLevel") @@ -108,13 +121,16 @@ internal abstract class AbstractStubGenerator( withInlineBlock("fun init(", ")") { write("port: Int,") write("engine: #T,", ServiceTypes(pkgName).serviceEngine) + write("requestBodyLimit: Long,") + write("requestReadTimeoutSeconds: Int,") + write("responseWriteTimeoutSeconds: Int,") write("closeGracePeriodMillis: Long,") write("closeTimeoutMillis: Long,") write("logLevel: #T,", ServiceTypes(pkgName).logLevel) } withBlock("{", "}") { write("check(backing == null) { #S }", "ServiceFrameworkConfig has already been initialised") - write("backing = Data(port, engine, closeGracePeriodMillis, closeTimeoutMillis, logLevel)") + write("backing = Data(port, engine, requestBodyLimit, requestReadTimeoutSeconds, responseWriteTimeoutSeconds, closeGracePeriodMillis, closeTimeoutMillis, logLevel)") } write("") withBlock("private fun notInitialised(prop: String): Nothing {", "}") { @@ -143,8 +159,8 @@ internal abstract class AbstractStubGenerator( /** Emits content-type guards, error handler plugins, … */ protected abstract fun renderPlugins() - /** Emits Logback XML + KLogger wiring. */ - protected abstract fun renderLogging() + /** Emits utils. */ + protected abstract fun renderUtils() /** Auth interfaces & installers (bearer, IAM, …). */ protected abstract fun renderAuthModule() @@ -162,6 +178,9 @@ internal abstract class AbstractStubGenerator( protected fun renderMainFile() { val portName = "port" val engineFactoryName = "engineFactory" + val requestBodyLimitName = "requestBodyLimit" + val requestReadTimeoutSecondsName = "requestReadTimeoutSeconds" + val responseWriteTimeoutSecondsName = "responseWriteTimeoutSeconds" val closeGracePeriodMillisName = "closeGracePeriodMillis" val closeTimeoutMillisName = "closeTimeoutMillis" val logLevelName = "logLevel" @@ -171,7 +190,10 @@ internal abstract class AbstractStubGenerator( write("val argMap: Map = args.asList().chunked(2).associate { (k, v) -> k.removePrefix(#S) to v }", "--") write("") write("val defaultPort = 8080") - write("val defaultEngine = #T.NETTY.value", ServiceTypes(pkgName).serviceEngine) + write("val defaultEngine = #T.NETTY_ENGINE.value", ServiceTypes(pkgName).serviceEngine) + write("val defaultRequestBodyLimit = 10L * 1024 * 1024") + write("val defaultRequestReadTimeoutSeconds = 30") + write("val defaultResponseWriteTimeoutSeconds = 30") write("val defaultCloseGracePeriodMillis = 1_000L") write("val defaultCloseTimeoutMillis = 5_000L") write("val defaultLogLevel = #T.INFO.value", ServiceTypes(pkgName).logLevel) @@ -179,6 +201,9 @@ internal abstract class AbstractStubGenerator( withBlock("#T.init(", ")", ServiceTypes(pkgName).serviceFrameworkConfig) { write("port = argMap[#S]?.toInt() ?: defaultPort, ", portName) write("engine = #T.fromValue(argMap[#S] ?: defaultEngine), ", ServiceTypes(pkgName).serviceEngine, engineFactoryName) + write("requestBodyLimit = argMap[#S]?.toLong() ?: defaultRequestBodyLimit, ", requestBodyLimitName) + write("requestReadTimeoutSeconds = argMap[#S]?.toInt() ?: defaultRequestReadTimeoutSeconds, ", requestReadTimeoutSecondsName) + write("responseWriteTimeoutSeconds = argMap[#S]?.toInt() ?: defaultResponseWriteTimeoutSeconds, ", responseWriteTimeoutSecondsName) write("closeGracePeriodMillis = argMap[#S]?.toLong() ?: defaultCloseGracePeriodMillis, ", closeGracePeriodMillisName) write("closeTimeoutMillis = argMap[#S]?.toLong() ?: defaultCloseTimeoutMillis, ", closeTimeoutMillisName) write("logLevel = #T.fromValue(argMap[#S] ?: defaultLogLevel), ", ServiceTypes(pkgName).logLevel, logLevelName) diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt index 45850309a1..366e0f5a8c 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/service/ServiceTypes.kt @@ -28,9 +28,9 @@ class ServiceTypes(val pkgName: String) { namespace = "$pkgName.framework" } - val errorEnvelope = buildSymbol { - name = "ErrorEnvelope" - namespace = "$pkgName.plugins" + val module = buildSymbol { + name = "module" + namespace = "$pkgName.framework" } val configureErrorHandling = buildSymbol { @@ -48,13 +48,23 @@ class ServiceTypes(val pkgName: String) { namespace = "$pkgName.utils" } + val configureAuthentication = buildSymbol { + name = "configureAuthentication" + namespace = "$pkgName.auth" + } + + val errorEnvelope = buildSymbol { + name = "ErrorEnvelope" + namespace = "$pkgName.plugins" + } + val contentTypeGuard = buildSymbol { name = "ContentTypeGuard" namespace = "$pkgName.plugins" } - val configureAuthentication = buildSymbol { - name = "configureAuthentication" - namespace = "$pkgName.auth" + val acceptTypeGuard = buildSymbol { + name = "AcceptTypeGuard" + namespace = "$pkgName.plugins" } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 3ef5eb2d53..edfc411abf 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -27,7 +27,7 @@ kotlin-compile-testing-version = "0.7.0" kotlinx-benchmark-version = "0.4.12" kotlinx-serialization-version = "1.7.3" docker-java-version = "3.4.0" -ktor-version = "3.1.1" +ktor-version = "3.2.2" kaml-version = "0.55.0" jsoup-version = "1.19.1" @@ -88,6 +88,7 @@ kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version. kotest-assertions-core-jvm = { module = "io.kotest:kotest-assertions-core-jvm", version.ref = "kotest-version" } kotlinx-benchmark-runtime = { module = "org.jetbrains.kotlinx:kotlinx-benchmark-runtime", version.ref = "kotlinx-benchmark-version" } kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-serialization-version" } +kotlinx-serialization-cbor = { module = "org.jetbrains.kotlinx:kotlinx-serialization-cbor", version.ref = "kotlinx-serialization-version" } docker-core = { module = "com.github.docker-java:docker-java-core", version.ref = "docker-java-version" } docker-transport-zerodep = { module = "com.github.docker-java:docker-java-transport-zerodep", version.ref = "docker-java-version" } diff --git a/settings.gradle.kts b/settings.gradle.kts index 2c8d3df112..2738eb0d29 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -120,6 +120,7 @@ include(":tests:codegen:paginator-tests") include(":tests:codegen:serde-tests") include(":tests:codegen:serde-codegen-support") include(":tests:codegen:waiter-tests") +include(":tests:codegen:service-codegen-tests") include(":tests:integration:slf4j-1x-consumer") include(":tests:integration:slf4j-2x-consumer") include(":tests:integration:slf4j-hybrid-consumer") diff --git a/tests/codegen/service-codegen-tests/build.gradle.kts b/tests/codegen/service-codegen-tests/build.gradle.kts new file mode 100644 index 0000000000..7732fb1e40 --- /dev/null +++ b/tests/codegen/service-codegen-tests/build.gradle.kts @@ -0,0 +1,69 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +import aws.sdk.kotlin.gradle.dsl.skipPublishing + +plugins { + id(libs.plugins.kotlin.jvm.get().pluginId) + alias(libs.plugins.kotlinx.serialization) + alias(libs.plugins.aws.kotlin.repo.tools.smithybuild) +} + +skipPublishing() + +val optinAnnotations = listOf("kotlin.RequiresOptIn") +kotlin.sourceSets.all { + optinAnnotations.forEach { languageSettings.optIn(it) } +} + +// Create a task to run the DefaultServiceGeneratorTestKt file +val runServiceGenerator by tasks.registering(JavaExec::class) { + group = "verification" + description = "Run the DefaultServiceGeneratorTestKt file" + classpath = sourceSets["main"].runtimeClasspath + mainClass.set("com.test.DefaultServiceGeneratorTestKt") +} + +tasks.test { + dependsOn(runServiceGenerator) + useJUnitPlatform() + testLogging { + events("passed", "skipped", "failed") + showStandardStreams = true + } +} + +kotlin { + compilerOptions { + freeCompilerArgs.addAll( + "-opt-in=kotlin.io.path.ExperimentalPathApi", + "-opt-in=kotlinx.serialization.ExperimentalSerializationApi", + ) + } +} + +dependencies { + + compileOnly(project(":codegen:smithy-kotlin-codegen")) + + implementation(project(":codegen:smithy-kotlin-codegen")) + implementation(project(":codegen:smithy-aws-kotlin-codegen")) + implementation(project(":codegen:smithy-kotlin-codegen-testutils")) + + implementation(libs.kotlinx.serialization.json) + implementation(libs.kotlinx.serialization.cbor) + + testImplementation(libs.junit.jupiter) + testImplementation(libs.kotest.assertions.core.jvm) + testImplementation(libs.kotlin.test) + testImplementation(libs.kotlin.test.junit5) + testImplementation(project(":codegen:smithy-kotlin-codegen-testutils")) + testImplementation(project(":codegen:smithy-kotlin-codegen")) + testImplementation(project(":codegen:smithy-aws-kotlin-codegen")) + + testImplementation(gradleTestKit()) + + testImplementation(libs.kotlinx.serialization.json) + testImplementation(libs.kotlinx.serialization.cbor) +} diff --git a/tests/codegen/service-codegen-tests/model/service-generator-test.smithy b/tests/codegen/service-codegen-tests/model/service-generator-test.smithy new file mode 100644 index 0000000000..ef630ccc8f --- /dev/null +++ b/tests/codegen/service-codegen-tests/model/service-generator-test.smithy @@ -0,0 +1,70 @@ +$version: "2.0" + +namespace com.test + +use smithy.protocols#rpcv2Cbor + +@rpcv2Cbor +@httpBearerAuth +service ServiceGeneratorTest { + version: "1.0.0" + operations: [ + PostTest, + AuthTest, + ErrorTest, + ] +} + + +@http(method: "POST", uri: "/post", code: 201) +@auth([]) +operation PostTest { + input: PostTestInput + output: PostTestOutput +} + +@input +structure PostTestInput { + input1: String + input2: Integer +} + +@output +structure PostTestOutput { + output1: String + output2: Integer +} + +@http(method: "POST", uri: "/auth", code: 201) +operation AuthTest { + input: AuthTestInput + output: AuthTestOutput +} + +@input +structure AuthTestInput { + input1: String +} + +@output +structure AuthTestOutput { + output1: String +} + +@http(method: "POST", uri: "/error", code: 200) +operation ErrorTest { + input: ErrorTestInput + output: ErrorTestOutput +} + +@input +structure ErrorTestInput { + input1: String +} + +@output +structure ErrorTestOutput { + output1: String +} + + diff --git a/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt new file mode 100644 index 0000000000..fb74aaeba8 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/main/kotlin/com/test/DefaultServiceGeneratorTest.kt @@ -0,0 +1,108 @@ +package com.test + +import software.amazon.smithy.build.FileManifest +import software.amazon.smithy.build.PluginContext +import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin +import software.amazon.smithy.model.loader.ModelAssembler +import software.amazon.smithy.model.node.Node +import software.amazon.smithy.model.node.ObjectNode +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +internal fun main() { + val modelPath: Path = Paths.get("model", "service-generator-test.smithy") + val defaultModel = ModelAssembler() + .discoverModels() + .addImport(modelPath) + .assemble() + .unwrap() + val serviceName = "ServiceGeneratorTest" + val packageName = "com.test" + + val packagePath = packageName.replace('.', '/') + + val settings: ObjectNode = ObjectNode.builder() + .withMember("service", Node.from("$packageName#$serviceName")) + .withMember( + "package", + ObjectNode.builder() + .withMember("name", Node.from(packageName)) + .withMember("version", Node.from("1.0.0")) + .build(), + ) + .withMember( + "build", + ObjectNode.builder() + .withMember("rootProject", true) + .withMember("generateServiceProject", true) + .withMember( + "optInAnnotations", + Node.arrayNode( + Node.from("aws.smithy.kotlin.runtime.InternalApi"), + Node.from("kotlinx.serialization.ExperimentalSerializationApi"), + ), + ) + .build(), + ) + .withMember( + "serviceStub", + ObjectNode.builder().withMember("framework", Node.from("ktor")).build(), + ) + .build() + val outputDir: Path = Paths.get("build", "generated-service").also { Files.createDirectories(it) } + val manifest: FileManifest = FileManifest.create(outputDir) + + val context: PluginContext = PluginContext.builder() + .model(defaultModel) + .fileManifest(manifest) + .settings(settings) + .build() + KotlinCodegenPlugin().execute(context) + + val postTestOperation = """ + package $packageName.operations + + import $packageName.model.PostTestRequest + import $packageName.model.PostTestResponse + + public fun handlePostTestRequest(req: PostTestRequest): PostTestResponse { + val response = PostTestResponse.Builder() + val input1 = req.input1 ?: "" + val input2 = req.input2 ?: 0 + response.output1 = input1 + " world!" + response.output2 = input2 + 1 + return response.build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/PostTestOperation.kt", postTestOperation) + + val errorTestOperation = """ + package $packageName.operations + + import $packageName.model.ErrorTestRequest + import $packageName.model.ErrorTestResponse + + public fun handleErrorTestRequest(req: ErrorTestRequest): ErrorTestResponse { + val variable: String? = null + val error = variable!!.length + return ErrorTestResponse.Builder().build() + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/operations/ErrorTestOperation.kt", errorTestOperation) + + val bearerValidation = """ + package $packageName.auth + + public fun bearerValidation(token: String): UserPrincipal? { + if (token == "correctToken") return UserPrincipal("Authenticated User") else return null + } + """.trimIndent() + manifest.writeFile("src/main/kotlin/$packagePath/auth/Validation.kt", bearerValidation) + + val settingGradleKts = """ + rootProject.name = "generated-project" + includeBuild("../../../../../") + """.trimIndent() + manifest.writeFile("settings.gradle.kts", settingGradleKts) +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt new file mode 100644 index 0000000000..91438e2a89 --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceFileTest.kt @@ -0,0 +1,40 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import java.nio.file.Path +import java.nio.file.Paths +import kotlin.io.path.exists +import kotlin.test.Test +import kotlin.test.assertTrue + +class ServiceFileTest { + val packageName = "com.test" + val packagePath = packageName.replace('.', '/') + + val projectDir: Path = Paths.get("build/generated-service") + + @Test + fun `generates service and all necessary files`() { + assertTrue(projectDir.resolve("build.gradle.kts").exists()) + assertTrue(projectDir.resolve("settings.gradle.kts").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/Main.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/Routing.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/config/ServiceFrameworkConfig.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/framework/ServiceFramework.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/plugins/ContentTypeGuard.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/plugins/ErrorHandler.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/utils/Logging.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/auth/Authentication.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/auth/Validation.kt").exists()) + + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/model/PostTestRequest.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/model/PostTestResponse.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/serde/PostTestOperationSerializer.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/serde/PostTestOperationDeserializer.kt").exists()) + assertTrue(projectDir.resolve("src/main/kotlin/$packagePath/operations/PostTestOperation.kt").exists()) + } +} diff --git a/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceGeneratorTest.kt b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceGeneratorTest.kt new file mode 100644 index 0000000000..60be43bb3c --- /dev/null +++ b/tests/codegen/service-codegen-tests/src/test/kotlin/com/test/ServiceGeneratorTest.kt @@ -0,0 +1,500 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.test + +import kotlinx.serialization.Serializable +import kotlinx.serialization.cbor.Cbor +import org.gradle.testkit.runner.GradleRunner +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import java.io.IOException +import java.net.ServerSocket +import java.net.Socket +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths +import java.util.concurrent.TimeUnit +import kotlin.io.path.exists +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertIs +import kotlin.test.assertTrue +import kotlin.test.fail + +@Serializable +data class MalformedPostTestRequest( + val input1: Int, + val input2: String, +) + +@Serializable +data class PostTestRequest( + val input1: String, + val input2: Int, +) + +@Serializable +data class PostTestResponse( + val output1: String? = null, + val output2: Int? = null, +) + +@Serializable +data class AuthTestRequest( + val input1: String, +) + +@Serializable +data class ErrorTestRequest( + val input1: String, +) + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class ServiceGeneratorTest { + val closeGracePeriodMillis: Long = 5_000L + val closeTimeoutMillis: Long = 1_000L + val requestBodyLimit: Long = 10L * 1024 * 1024 + val port: Int = ServerSocket(0).use { it.localPort } + + val baseUrl = "http://localhost:$port" + + val projectDir: Path = Paths.get("build/generated-service") + + private lateinit var proc: Process + + @BeforeAll + fun boot() { + proc = startService("netty", port, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit) + val ready = waitForPort(port, 180) + assertTrue(ready, "Service did not start within 180 s") + } + + @AfterAll + fun shutdown() = cleanupService(proc) + + @Test + fun `checks service with netty engine`() { + val nettyPort: Int = ServerSocket(0).use { it.localPort } + val nettyProc = startService("netty", nettyPort, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit) + val ready = waitForPort(nettyPort, 180) + assertTrue(ready, "Service did not start within 180 s") + cleanupService(nettyProc) + } + + @Test + fun `checks service with cio engine`() { + val cioPort: Int = ServerSocket(0).use { it.localPort } + val cioProc = startService("cio", cioPort, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit) + val ready = waitForPort(cioPort, 180) + assertTrue(ready, "Service did not start within 180 s") + cleanupService(cioProc) + } + + @Test + fun `checks service with jetty jakarta engine`() { + val jettyPort: Int = ServerSocket(0).use { it.localPort } + val jettyProc = startService("jetty-jakarta", jettyPort, closeGracePeriodMillis, closeTimeoutMillis, requestBodyLimit) + val ready = waitForPort(jettyPort, 180) + assertTrue(ready, "Service did not start within 180 s") + cleanupService(jettyProc) + } + + @Test + fun `checks correct POST request`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + + val body = cbor.decodeFromByteArray( + PostTestResponse.serializer(), + response.body(), + ) + + assertEquals("Hello world!", body.output1) + assertEquals(input2 + 1, body.output2) + } + + @Test + fun `checks unhandled runtime exception in handler`() { + val cbor = Cbor { } + val input1 = "Hello" + val requestBytes = cbor.encodeToByteArray( + ErrorTestRequest.serializer(), + ErrorTestRequest(input1), + ) + + val response = sendRequest( + "$baseUrl/error", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + "correctToken", + ) + assertIs>(response) + assertEquals(500, response.statusCode(), "Expected 500") + } + + @Test + fun `checks wrong content type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/json", + "application/cbor", + ) + assertIs>(response) + assertEquals(415, response.statusCode(), "Expected 415") + } + + @Test + fun `checks missing content type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + acceptType = "application/cbor", + ) + assertIs>(response) + assertEquals(415, response.statusCode(), "Expected 415") + } + + @Test + fun `checks wrong accept type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/json", + ) + assertIs>(response) + assertEquals(406, response.statusCode(), "Expected 406") + } + + @Test + fun `checks missing accept type`() { + val cbor = Cbor { } + val input1 = "Hello" + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks authentication with correct bearer token`() { + val cbor = Cbor { } + val input1 = "Hello" + val requestBytes = cbor.encodeToByteArray( + AuthTestRequest.serializer(), + AuthTestRequest(input1), + ) + + val response = sendRequest( + "$baseUrl/auth", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + "correctToken", + ) + assertIs>(response) + assertEquals(201, response.statusCode(), "Expected 201") + } + + @Test + fun `checks authentication with wrong bearer token`() { + val cbor = Cbor { } + val input1 = "Hello" + val requestBytes = cbor.encodeToByteArray( + AuthTestRequest.serializer(), + AuthTestRequest(input1), + ) + + val response = sendRequest( + "$baseUrl/auth", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + "wrongToken", + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + } + + @Test + fun `checks authentication without bearer token`() { + val cbor = Cbor { } + val input1 = "Hello" + val requestBytes = cbor.encodeToByteArray( + AuthTestRequest.serializer(), + AuthTestRequest(input1), + ) + + val response = sendRequest( + "$baseUrl/auth", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(401, response.statusCode(), "Expected 401") + } + + @Test + fun `checks malformed input`() { + val cbor = Cbor { } + val input1 = 123 + val input2 = "Hello" + val requestBytes = cbor.encodeToByteArray( + MalformedPostTestRequest.serializer(), + MalformedPostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(400, response.statusCode(), "Expected 400") + } + + @Test + fun `checks route not found`() { + val requestBytes = ByteArray(0) + val response = sendRequest( + "$baseUrl/does-not-exist", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(404, response.statusCode(), "Expected 404") + } + + @Test + fun `checks method not allowed`() { + val cbor = Cbor { } + val input1 = 123 + val input2 = "Hello" + val requestBytes = cbor.encodeToByteArray( + MalformedPostTestRequest.serializer(), + MalformedPostTestRequest(input1, input2), + ) + + val response = sendRequest( + "$baseUrl/post", + "PUT", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(405, response.statusCode(), "Expected 405") + } + + @Test + fun `checks request body limit`() { + val cbor = Cbor { } + val overLimitPayload = "x".repeat(requestBodyLimit.toInt() + 1) + val input2 = 617 + val requestBytes = cbor.encodeToByteArray( + PostTestRequest.serializer(), + PostTestRequest(overLimitPayload, input2), + ) + require(requestBytes.size > 10 * 1024 * 1024) + + val response = sendRequest( + "$baseUrl/post", + "POST", + requestBytes, + "application/cbor", + "application/cbor", + ) + assertIs>(response) + assertEquals(413, response.statusCode(), "Expected 413") + } +} + +internal fun ServiceGeneratorTest.startService( + engineFactory: String = "netty", + port: Int = 8080, + closeGracePeriodMillis: Long = 1000, + closeTimeoutMillis: Long = 1000, + requestBodyLimit: Long = 10L * 1024 * 1024, +): Process { + if (!Files.exists(projectDir.resolve("gradlew"))) { + GradleRunner.create() + .withProjectDir(projectDir.toFile()) + .withArguments( + "wrapper", + "--quiet", + ) + .build() + } + val isWindows = System.getProperty("os.name").startsWith("Windows", ignoreCase = true) + val gradleCmd = if (isWindows) "gradlew.bat" else "./gradlew" + val baseCmd = if (isWindows) listOf("cmd", "/c", gradleCmd) else listOf(gradleCmd) + + return ProcessBuilder( + baseCmd + listOf( + "--no-daemon", + "--quiet", + "run", + "--args=--engineFactory $engineFactory " + + "--port $port " + + "--closeGracePeriodMillis ${closeGracePeriodMillis.toInt()} " + + "--closeTimeoutMillis ${closeTimeoutMillis.toInt()} " + + "--requestBodyLimit $requestBodyLimit", + ), + ) + .directory(projectDir.toFile()) + .redirectErrorStream(true) + .start() +} + +internal fun ServiceGeneratorTest.cleanupService(proc: Process) { + val gracefulWindow = closeGracePeriodMillis + closeTimeoutMillis + val okExitCodes = if (isWindows()) { + setOf(0, 1, 143, -1, -1073741510) + } else { + setOf(0, 143) + } + + try { + proc.destroy() + val exited = proc.waitFor(gracefulWindow, TimeUnit.MILLISECONDS) + + if (!exited) { + proc.destroyForcibly() + fail("Service did not shut down within $gracefulWindow ms") + } + + assertTrue( + proc.exitValue() in okExitCodes, + "Service exited with ${proc.exitValue()} – shutdown not graceful?", + ) + } catch (e: Exception) { + proc.destroyForcibly() + throw e + } +} + +private fun isWindows() = System.getProperty("os.name").lowercase().contains("windows") + +internal fun waitForPort(port: Int, timeoutSec: Long = 180): Boolean { + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toNanos(timeoutSec) + while (System.currentTimeMillis() < deadline) { + try { + Socket("localhost", port).use { + return true // Port is available + } + } catch (e: IOException) { + Thread.sleep(100) + } + } + return false +} + +internal fun sendRequest( + url: String, + method: String, + data: Any? = null, + contentType: String? = null, + acceptType: String? = null, + bearerToken: String? = null, +): HttpResponse<*> { + val client = HttpClient.newHttpClient() + + val bodyPublisher = when (data) { + null -> HttpRequest.BodyPublishers.noBody() + is ByteArray -> HttpRequest.BodyPublishers.ofByteArray(data) + is String -> HttpRequest.BodyPublishers.ofString(data) + else -> throw IllegalArgumentException( + "Unsupported body type: ${data::class.qualifiedName}", + ) + } + + val request = HttpRequest.newBuilder() + .uri(URI.create(url)) + .apply { + contentType?.let { header("Content-Type", it) } + acceptType?.let { header("Accept", it) } + bearerToken?.let { header("Authorization", "Bearer $it") } + } + .method(method, bodyPublisher) + .build() + + val bodyHandler = when { + acceptType?.contains("json", ignoreCase = true) == true || + acceptType?.startsWith("text", ignoreCase = true) == true + -> HttpResponse.BodyHandlers.ofString() + else -> HttpResponse.BodyHandlers.ofByteArray() + } + + return client.send(request, bodyHandler) +}