diff --git a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts index 5994aa8e4..77e18b3a8 100644 --- a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts +++ b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts @@ -5,7 +5,6 @@ import org.jetbrains.kotlin.gradle.dsl.KotlinCommonCompilerOptions import org.jetbrains.kotlin.gradle.dsl.KotlinProjectExtension import org.jetbrains.kotlin.gradle.dsl.KotlinVersion -import org.jetbrains.kotlin.gradle.plugin.KotlinCompilation import util.withKotlinJvmExtension import util.withKotlinKmpExtension diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt new file mode 100644 index 000000000..f9d07f017 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt @@ -0,0 +1,13 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.pb + +import kotlinx.rpc.grpc.utils.BitSet +import kotlinx.rpc.internal.utils.InternalRpcApi + +@InternalRpcApi +public abstract class InternalMessage(fieldsWithPresence: Int) { + public val presenceMask: BitSet = BitSet(fieldsWithPresence) +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt new file mode 100644 index 000000000..85520ca04 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt @@ -0,0 +1,68 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.utils + +import kotlinx.rpc.internal.utils.InternalRpcApi + +/** + * A fixed-sized vector of bits, allowing one to set/clear/read bits from it by a bit index. + */ +@InternalRpcApi +public class BitSet(public val size: Int) { + private val data: LongArray = LongArray((size + 63) ushr 6) + + /** Sets the bit at [index] to 1. */ + public operator fun set(index: Int, value: Boolean) { + if (!value) return clear(index) + require(index in 0 until size) { "Index $index out‑of‑bounds for length $size" } + val word = index ushr 6 + val mask = 1L shl (index and 63) + data[word] = data[word] or mask + } + + /** Clears the bit at [index] (sets to 0). */ + public fun clear(index: Int) { + require(index >= 0 && index < size) { "Index $index out of bounds for length $size" } + val word = index ushr 6 + data[word] = data[word] and (1L shl (index and 63)).inv() + } + + /** Returns true if the bit at [index] is set. */ + public operator fun get(index: Int): Boolean { + require(index >= 0 && index < size) { "Index $index out of bounds for length $size" } + val word = index ushr 6 + return (data[word] ushr (index and 63) and 1L) != 0L + } + + /** Clears all bits. */ + public fun clearAll() { + data.fill(0L) + } + + /** Returns the number of bits set to 1. */ + public fun cardinality(): Int { + var sum = 0 + for (w in data) { + sum += w.countOneBits() + } + return sum + } + + /** Returns true if all bits are set. */ + public fun allSet(): Boolean { + val fullWords = size ushr 6 + // check full 64-bit words + for (i in 0 until fullWords) { + if (data[i] != -1L) return false + } + // check leftover bits + val rem = size and 63 + if (rem != 0) { + val mask = (-1L ushr (64 - rem)) + if (data[fullWords] != mask) return false + } + return true + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt new file mode 100644 index 000000000..0644f91c6 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt @@ -0,0 +1,305 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlinx.rpc.grpc.utils.BitSet +import kotlin.test.* + +class BitSetTest { + + @Test + fun testConstructor() { + // Test with size 0 + val bitSet0 = BitSet(0) + assertEquals(0, bitSet0.size) + assertEquals(0, bitSet0.cardinality()) + + // Test with small size + val bitSet10 = BitSet(10) + assertEquals(10, bitSet10.size) + assertEquals(0, bitSet10.cardinality()) + + // Test with size that spans multiple words + val bitSet100 = BitSet(100) + assertEquals(100, bitSet100.size) + assertEquals(0, bitSet100.cardinality()) + + // Test with size at word boundary + val bitSet64 = BitSet(64) + assertEquals(64, bitSet64.size) + assertEquals(0, bitSet64.cardinality()) + + // Test with size just over word boundary + val bitSet65 = BitSet(65) + assertEquals(65, bitSet65.size) + assertEquals(0, bitSet65.cardinality()) + } + + @Test + fun testSetAndGet() { + val bitSet = BitSet(100) + + // Initially all bits should be unset + for (i in 0 until 100) { + assertFalse(bitSet[i], "Bit $i should be initially unset") + } + + // Set some bits + bitSet[0] = true + bitSet[1] = true + bitSet[63] = true + bitSet[64] = true + bitSet[99] = true + + // Verify the bits are set + assertTrue(bitSet[0], "Bit 0 should be set") + assertTrue(bitSet[1], "Bit 1 should be set") + assertTrue(bitSet[63], "Bit 63 should be set") + assertTrue(bitSet[64], "Bit 64 should be set") + assertTrue(bitSet[99], "Bit 99 should be set") + + // Verify other bits are still unset + assertFalse(bitSet[2], "Bit 2 should be unset") + assertFalse(bitSet[62], "Bit 62 should be unset") + assertFalse(bitSet[65], "Bit 65 should be unset") + assertFalse(bitSet[98], "Bit 98 should be unset") + } + + @Test + fun testClear() { + val bitSet = BitSet(100) + + // Set all bits + for (i in 0 until 100) { + bitSet[i] = true + } + + // Verify all bits are set + for (i in 0 until 100) { + assertTrue(bitSet[i], "Bit $i should be set") + } + + // Clear some bits + bitSet[0] = false + bitSet[1] = false + bitSet[63] = false + bitSet[64] = false + bitSet[99] = false + + // Verify the bits are cleared + assertFalse(bitSet[0], "Bit 0 should be cleared") + assertFalse(bitSet[1], "Bit 1 should be cleared") + assertFalse(bitSet[63], "Bit 63 should be cleared") + assertFalse(bitSet[64], "Bit 64 should be cleared") + assertFalse(bitSet[99], "Bit 99 should be cleared") + + // Verify other bits are still set + assertTrue(bitSet[2], "Bit 2 should still be set") + assertTrue(bitSet[62], "Bit 62 should still be set") + assertTrue(bitSet[65], "Bit 65 should still be set") + assertTrue(bitSet[98], "Bit 98 should still be set") + } + + @Test + fun testClearAll() { + val bitSet = BitSet(100) + + // Set all bits + for (i in 0 until 100) { + bitSet[i] = true + } + + // Verify all bits are set + for (i in 0 until 100) { + assertTrue(bitSet[i], "Bit $i should be set") + } + + // Clear all bits + bitSet.clearAll() + + // Verify all bits are cleared + for (i in 0 until 100) { + assertFalse(bitSet[i], "Bit $i should be cleared after clearAll") + } + } + + @Test + fun testCardinality() { + val bitSet = BitSet(100) + assertEquals(0, bitSet.cardinality(), "Initial cardinality should be 0") + + // Set some bits + bitSet[0] = true + assertEquals(1, bitSet.cardinality(), "Cardinality should be 1 after setting 1 bit") + + bitSet[63] = true + assertEquals(2, bitSet.cardinality(), "Cardinality should be 2 after setting 2 bits") + + bitSet[64] = true + assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after setting 3 bits") + + bitSet[99] = true + assertEquals(4, bitSet.cardinality(), "Cardinality should be 4 after setting 4 bits") + + // Clear a bit + bitSet.clear(0) + assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after clearing 1 bit") + + // Set a bit that's already set + bitSet[63] = true + assertEquals(3, bitSet.cardinality(), "Cardinality should still be 3 after setting an already set bit") + + // Clear all bits + bitSet.clearAll() + assertEquals(0, bitSet.cardinality(), "Cardinality should be 0 after clearAll") + } + + @Test + fun testAllSet() { + // Test with empty BitSet + val emptyBitSet = BitSet(0) + assertTrue(emptyBitSet.allSet(), "Empty BitSet should return true for allSet") + + // Test with small BitSet + val smallBitSet = BitSet(5) + assertFalse(smallBitSet.allSet(), "New BitSet should return false for allSet") + + smallBitSet[0] = true + smallBitSet[1] = true + smallBitSet[2] = true + smallBitSet[3] = true + smallBitSet[4] = true + assertTrue(smallBitSet.allSet(), "BitSet with all bits set should return true for allSet") + + smallBitSet.clear(2) + assertFalse(smallBitSet.allSet(), "BitSet with one bit cleared should return false for allSet") + + // Test with BitSet that spans multiple words + val largeBitSet = BitSet(100) + assertFalse(largeBitSet.allSet(), "New large BitSet should return false for allSet") + + for (i in 0 until 100) { + largeBitSet[i] = true + } + assertTrue(largeBitSet.allSet(), "Large BitSet with all bits set should return true for allSet") + + largeBitSet.clear(63) + assertFalse(largeBitSet.allSet(), "Large BitSet with one bit cleared should return false for allSet") + + // Test with BitSet at word boundary + val wordBoundaryBitSet = BitSet(64) + assertFalse(wordBoundaryBitSet.allSet(), "New word boundary BitSet should return false for allSet") + + for (i in 0 until 64) { + wordBoundaryBitSet[i] = true + } + assertTrue(wordBoundaryBitSet.allSet(), "Word boundary BitSet with all bits set should return true for allSet") + } + + @Test + fun testEdgeCases() { + val bitSet = BitSet(100) + + // Test setting and getting at boundaries + bitSet[0] = true + assertTrue(bitSet[0], "Should be able to set and get bit 0") + + bitSet[99] = true + assertTrue(bitSet[99], "Should be able to set and get bit at size-1") + + // Test clearing at boundaries + bitSet.clear(0) + assertFalse(bitSet[0], "Should be able to clear bit 0") + + bitSet.clear(99) + assertFalse(bitSet[99], "Should be able to clear bit at size-1") + + // Test out of bounds access + assertFailsWith { + bitSet[100] = true + } + + assertFailsWith { + bitSet.clear(100) + } + + assertFailsWith { + bitSet[100] + } + + assertFailsWith { + bitSet[-1] = true + } + + assertFailsWith { + bitSet.clear(-1) + } + + assertFailsWith { + bitSet[-1] + } + } + + @Test + fun testWordBoundaries() { + // Test BitSet with size at word boundaries + for (size in listOf(63, 64, 65, 127, 128, 129)) { + val bitSet = BitSet(size) + + // Set all bits + for (i in 0 until size) { + bitSet[i] = true + } + + // Verify all bits are set + for (i in 0 until size) { + assertTrue(bitSet[i], "Bit $i should be set in BitSet of size $size") + } + + // Verify cardinality + assertEquals(size, bitSet.cardinality(), "Cardinality should equal size for fully set BitSet") + + // Verify allSet + assertTrue(bitSet.allSet(), "allSet should return true for fully set BitSet") + + // Clear all bits + bitSet.clearAll() + + // Verify all bits are cleared + for (i in 0 until size) { + assertFalse(bitSet[i], "Bit $i should be cleared in BitSet of size $size after clearAll") + } + + // Verify cardinality + assertEquals(0, bitSet.cardinality(), "Cardinality should be 0 after clearAll") + + // Verify allSet + assertFalse(bitSet.allSet(), "allSet should return false after clearAll") + } + } + + @Test + fun testLargeCardinality() { + // Test with a large BitSet to verify cardinality calculation + val size = 1000 + val bitSet = BitSet(size) + + // Set every other bit + for (i in 0 until size step 2) { + bitSet[i] = true + } + + // Verify cardinality + assertEquals(size / 2, bitSet.cardinality(), "Cardinality should be half the size when every other bit is set") + + // Set all bits + for (i in 0 until size) { + bitSet[i] = true + } + + // Verify cardinality + assertEquals(size, bitSet.cardinality(), "Cardinality should equal size when all bits are set") + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index f12e932fa..86426cff0 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -5,32 +5,25 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer +import kotlinx.rpc.grpc.internal.MessageCodec import kotlinx.rpc.grpc.test.common.* import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith class ProtosTest { - private fun decodeEncode( - msg: T, - enc: T.(WireEncoder) -> Unit, - dec: (WireDecoder) -> T? - ): T? { - val buffer = Buffer() - val encoder = WireEncoder(buffer) - - msg.enc(encoder) - encoder.flush() - - return WireDecoder(buffer).use { - dec(it) - } + private fun decodeEncode( + msg: M, + codec: MessageCodec + ): M { + val source = codec.encode(msg) + return codec.decode(source) } - @Test fun testAllPrimitiveProto() { - val msg = AllPrimitivesCommon { + val msg = AllPrimitives { int32 = 12 int64 = 1234567890123456789L uint32 = 12345u @@ -48,24 +41,48 @@ class ProtosTest { bytes = byteArrayOf(1, 2, 3) } - val decoded = decodeEncode(msg, { encodeWith(it) }, AllPrimitivesCommon::decodeWith) + val msgObj = msg + + val decoded = decodeEncode(msgObj, AllPrimitivesInternal.CODEC) - assertEquals(msg.double, decoded?.double) + assertEquals(msg.double, decoded.double) } @Test fun testRepeatedProto() { - val msg = RepeatedCommon { - listFixed32 = listOf(1, 2, 3).map { it.toUInt() } - listInt32 = listOf(4, 5, 6) + val msg = Repeated { + listFixed32 = listOf(1, 5, 3).map { it.toUInt() } + listFixed32Packed = listOf(1, 2, 3).map { it.toUInt() } + listInt32 = listOf(4, 7, 6) + listInt32Packed = listOf(4, 5, 6) listString = listOf("a", "b", "c") } - val decoded = decodeEncode(msg, { encodeWith(it) }, RepeatedCommon::decodeWith) + val decoded = decodeEncode(msg, RepeatedInternal.CODEC) + + assertEquals(msg.listInt32, decoded.listInt32) + assertEquals(msg.listFixed32, decoded.listFixed32) + assertEquals(msg.listString, decoded.listString) + } + + @Test + fun testPresenceCheckProto() { + + // Check a missing required field in a user-constructed message + assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { + PresenceCheck {} + } + + // Test missing field during decoding of an encoded message + val buffer = Buffer() + val encoder = WireEncoder(buffer) + encoder.writeFloat(2, 1f) + encoder.flush() - assertEquals(msg.listInt32, decoded?.listInt32) - assertEquals(msg.listFixed32, decoded?.listFixed32) - assertEquals(msg.listString, decoded?.listString) + assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { + PresenceCheckInternal.CODEC.decode(buffer) + } } -} + +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/all_primitives.proto b/grpc/grpc-core/src/commonTest/proto/all_primitives.proto index 14772b74a..99b12cd57 100644 --- a/grpc/grpc-core/src/commonTest/proto/all_primitives.proto +++ b/grpc/grpc-core/src/commonTest/proto/all_primitives.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package kotlinx.rpc.grpc.test.common; -message AllPrimitivesCommon { +message AllPrimitives { double double = 1; float float = 2; int32 int32 = 3; @@ -18,4 +18,4 @@ message AllPrimitivesCommon { optional bool bool = 13; optional string string = 14; optional bytes bytes = 15; -} +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/presence_check.proto b/grpc/grpc-core/src/commonTest/proto/presence_check.proto new file mode 100644 index 000000000..f428ce2c9 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/presence_check.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package kotlinx.rpc.grpc.test.common; + +message PresenceCheck { + required int32 RequiredPresence = 1; + optional float OptionalPresence = 2; +} + diff --git a/grpc/grpc-core/src/commonTest/proto/repeated.proto b/grpc/grpc-core/src/commonTest/proto/repeated.proto index 55d1a2a19..e80c7b445 100644 --- a/grpc/grpc-core/src/commonTest/proto/repeated.proto +++ b/grpc/grpc-core/src/commonTest/proto/repeated.proto @@ -2,12 +2,10 @@ syntax = "proto3"; package kotlinx.rpc.grpc.test.common; -message RepeatedCommon { - repeated fixed32 listFixed32 = 1 [packed = true]; - repeated int32 listInt32 = 2 [packed = false]; - repeated string listString = 3; - - message InnerClass { - - } -} +message Repeated { + repeated fixed32 listFixed32 = 1 [packed = false]; + repeated fixed32 listFixed32Packed = 2 [packed = true]; + repeated int32 listInt32 = 3 [packed = false]; + repeated int32 listInt32Packed = 4 [packed = true]; + repeated string listString = 5; +} \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt index b1a6cdeff..de3f31e20 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt @@ -233,6 +233,7 @@ open class CodeGenerator( fun clazz( name: String, modifiers: String = "", + constructorModifiers: String = "", constructorArgs: List> = emptyList(), superTypes: List = emptyList(), annotations: List = emptyList(), @@ -258,8 +259,12 @@ open class CodeGenerator( "$arg$defaultString" } + val constructorModifiersTransformed = if (constructorModifiers.isEmpty()) "" else + " ${constructorModifiers.trim()} constructor " + when { shouldPutArgsOnNewLines && constructorArgsTransformed.isNotEmpty() -> { + append(constructorModifiersTransformed) append("(") newLine() withNextIndent { @@ -271,10 +276,15 @@ open class CodeGenerator( } constructorArgsTransformed.isNotEmpty() -> { + append(constructorModifiersTransformed) append("(") append(constructorArgsTransformed.joinToString(", ")) append(")") } + + constructorModifiersTransformed.isNotEmpty() -> { + append("$constructorModifiersTransformed()") + } } val superString = superTypes diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index fadeaebcb..8c79cca77 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -11,6 +11,9 @@ import kotlinx.rpc.protobuf.model.* import org.slf4j.Logger private const val RPC_INTERNAL_PACKAGE_SUFFIX = "_rpc_internal" +private const val MSG_INTERNAL_SUFFIX = "Internal" +private const val PB_PKG = "kotlinx.rpc.grpc.pb" +private const val INTERNAL_RPC_API_ANNO = "kotlinx.rpc.internal.utils.InternalRpcApi" class ModelToKotlinCommonGenerator( private val model: Model, @@ -67,7 +70,7 @@ class ModelToKotlinCommonGenerator( this@generateInternalKotlinFile.packageName.safeFullName() .packageNameSuffixed(RPC_INTERNAL_PACKAGE_SUFFIX) - fileOptIns = listOf("ExperimentalRpcApi::class", "InternalRpcApi::class") + fileOptIns = listOf("ExperimentalRpcApi::class", "$INTERNAL_RPC_API_ANNO::class") dependencies.forEach { dependency -> importPackage(dependency.packageName.safeFullName()) @@ -75,11 +78,6 @@ class ModelToKotlinCommonGenerator( generateInternalDeclaredEntities(this@generateInternalKotlinFile) - import("kotlinx.rpc.internal.utils.*") - import("kotlinx.coroutines.flow.*") - import("kotlinx.rpc.grpc.pb.*") - - additionalInternalImports.forEach { import(it) } @@ -100,9 +98,14 @@ class ModelToKotlinCommonGenerator( fileDeclaration.messageDeclarations.forEach { generateMessageConstructor(it) - generateMessageDecoder(it) + } + + fileDeclaration.messageDeclarations.forEach { + generateRequiredCheck(it) generateMessageEncoder(it) + generateMessageDecoder(it) } + } private fun MessageDeclaration.fields() = actualFields.map { @@ -140,11 +143,19 @@ class ModelToKotlinCommonGenerator( @Suppress("detekt.CyclomaticComplexMethod") private fun CodeGenerator.generateInternalMessage(declaration: MessageDeclaration) { + val internalClassName = declaration.internalClassName() clazz( - name = "${declaration.name.simpleName}Builder", + name = internalClassName, + annotations = listOf("@$INTERNAL_RPC_API_ANNO"), declarationType = DeclarationType.Class, - superTypes = listOf(declaration.name.safeFullName()), + superTypes = listOf( + declaration.name.safeFullName(), + "$PB_PKG.InternalMessage(fieldsWithPresence = ${declaration.presenceMaskSize})" + ), ) { + + generatePresenceIndicesObject(declaration) + declaration.fields().forEach { (fieldDeclaration, field) -> val value = when { field.nullable -> { @@ -162,32 +173,84 @@ class ModelToKotlinCommonGenerator( } code("override var $fieldDeclaration $value") + if (field.presenceIdx != null) { + scope("set(value) ") { + code("presenceMask[PresenceIndices.${field.name}] = true") + code("field = value") + } + } newLine() } declaration.nestedDeclarations.forEach { nested -> generateInternalMessage(nested) } + + scope("companion object") { + generateCodecObject(declaration) + } + } + } + + private fun CodeGenerator.generatePresenceIndicesObject(declaration: MessageDeclaration) { + if (declaration.presenceMaskSize == 0) { + return + } + scope("private object PresenceIndices") { + declaration.fields().forEach { (_, field) -> + if (field.presenceIdx != null) { + code("const val ${field.name} = ${field.presenceIdx}") + newLine() + } + } + } + } + + private fun CodeGenerator.generateCodecObject(declaration: MessageDeclaration) { + val msgFqName = declaration.name.safeFullName() + val downCastErrorStr = + "\${value::class.simpleName} implements ${msgFqName}, which is prohibited." + val sourceFqName = "kotlinx.io.Source" + val bufferFqName = "kotlinx.io.Buffer" + scope("val CODEC = object : kotlinx.rpc.grpc.internal.MessageCodec<$msgFqName>") { + function("encode", modifiers = "override", args = "value: $msgFqName", returnType = sourceFqName) { + code("val msg = value as? ${declaration.internalClassFullName()} ?: error { \"$downCastErrorStr\" }") + code("val buffer = $bufferFqName()") + code("val encoder = $PB_PKG.WireEncoder(buffer)") + code("msg.encodeWith(encoder)") + code("encoder.flush()") + code("return buffer") + } + + function("decode", modifiers = "override", args = "stream: $sourceFqName", returnType = msgFqName) { + scope("$PB_PKG.WireDecoder(stream as $bufferFqName).use") { + code("return ${declaration.internalClassFullName()}.decodeWith(it)") + } + } } } private fun CodeGenerator.generateMessageConstructor(declaration: MessageDeclaration) = function( name = "invoke", modifiers = "operator", - args = "body: ${declaration.name.safeFullName("Builder")}.() -> Unit", + args = "body: ${declaration.internalClassFullName()}.() -> Unit", contextReceiver = "${declaration.name.safeFullName()}.Companion", returnType = declaration.name.safeFullName(), ) { - code("return ${declaration.name.safeFullName("Builder")}().apply(body)") + code("val msg = ${declaration.internalClassFullName()}().apply(body)") + // check if the user set all required fields + code("msg.checkRequiredFields()") + code("return msg") } private fun CodeGenerator.generateMessageDecoder(declaration: MessageDeclaration) = function( name = "decodeWith", - args = "decoder: WireDecoder", - contextReceiver = "${declaration.name.safeFullName()}.Companion", - returnType = "${declaration.name.safeFullName()}?" + modifiers = "private", + args = "decoder: $PB_PKG.WireDecoder", + contextReceiver = "${declaration.internalClassFullName()}.Companion", + returnType = declaration.internalClassName() ) { - code("val msg = ${declaration.name.safeFullName("Builder")}()") + code("val msg = ${declaration.internalClassFullName()}()") whileBlock("!decoder.hadError()") { code("val tag = decoder.readTag() ?: break // EOF, we read the whole message") whenBlock { @@ -197,9 +260,11 @@ class ModelToKotlinCommonGenerator( } ifBranch( condition = "decoder.hadError()", - ifBlock = { code("return null") } + ifBlock = { code("error(\"Error during decoding of ${declaration.name.simpleName}\")") } ) + code("msg.checkRequiredFields()") + // TODO: Make a lists immutable code("return msg") } @@ -208,16 +273,16 @@ class ModelToKotlinCommonGenerator( val encFuncName = field.type.decodeEncodeFuncName() val assignment = "msg.${field.name} =" when (val fieldType = field.type) { - is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.${field.type.wireType.name}") { + is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") { code("$assignment decoder.read$encFuncName()") } is FieldType.List -> if (field.dec.isPacked) { - whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.LENGTH_DELIMITED") { + whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { code("$assignment decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()") } } else { - whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.LENGTH_DELIMITED") { + whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${fieldType.value.wireType.name}") { code("(msg.${field.name} as ArrayList).add(decoder.read${fieldType.value.decodeEncodeFuncName()}())") } } @@ -230,9 +295,14 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function( name = "encodeWith", - args = "encoder: WireEncoder", - contextReceiver = declaration.name.safeFullName(), + modifiers = "private", + args = "encoder: $PB_PKG.WireEncoder", + contextReceiver = declaration.internalClassFullName(), ) { + if (declaration.fields().isEmpty()) { + code("// no fields to encode") + return@function + } declaration.fields().forEach { (_, field) -> val fieldName = field.name @@ -252,13 +322,13 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.writeValue(variable: String): String { return when (val fieldType = type) { - is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}($number, $variable)" + is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)" is FieldType.List -> when { dec.isPacked && packedFixedSize -> - "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $variable)" + "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)" dec.isPacked && !packedFixedSize -> - "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $variable, ${ + "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable, fieldSize = ${ wireSizeCall( variable ) @@ -274,8 +344,33 @@ class ModelToKotlinCommonGenerator( } } + + /** + * Generates a function to check for the presence of all required fields in a message declaration. + */ + private fun CodeGenerator.generateRequiredCheck(declaration: MessageDeclaration) = function( + name = "checkRequiredFields", + modifiers = "private", + contextReceiver = declaration.internalClassFullName(), + ) { + val requiredFields = declaration.actualFields + .filter { it.dec.isRequired } + + if (requiredFields.isEmpty()) { + code("// no fields to check") + return@function + } + + requiredFields.forEach { field -> + ifBranch(condition = "!presenceMask[${field.presenceIdx}]", ifBlock = { + code("error(\"${declaration.name.simpleName} is missing required field: ${field.name}\")") + }) + } + } + + private fun FieldDeclaration.wireSizeCall(variable: String): String { - val sizeFunc = "WireSize.${type.decodeEncodeFuncName().replaceFirstChar { it.lowercase() }}($variable)" + val sizeFunc = "$PB_PKG.WireSize.${type.decodeEncodeFuncName().replaceFirstChar { it.lowercase() }}($variable)" return when (val fieldType = type) { is FieldType.IntegralType -> when { fieldType.wireType == WireType.FIXED32 -> "32" @@ -464,8 +559,17 @@ class ModelToKotlinCommonGenerator( } } } + + private fun MessageDeclaration.internalClassFullName(): String { + return name.safeFullName(MSG_INTERNAL_SUFFIX) + } + + private fun MessageDeclaration.internalClassName(): String { + return name.simpleName + MSG_INTERNAL_SUFFIX + } } + private fun String.packageNameSuffixed(suffix: String): String { return if (isEmpty()) suffix else "$this.$suffix" } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index 3f1702701..a357d5944 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -111,10 +111,18 @@ private fun Descriptors.FileDescriptor.toModel(): FileDeclaration = cached { } private fun Descriptors.Descriptor.toModel(): MessageDeclaration = cached { - val regularFields = fields.filter { field -> field.realContainingOneof == null }.map { it.toModel() } + var currPresenceIdx = 0 + val regularFields = fields + // only fields that are not part of a oneOf declaration + .filter { field -> field.realContainingOneof == null } + .map { + val presenceIdx = if (it.hasPresence()) currPresenceIdx++ else null + it.toModel(presenceIdx = presenceIdx) + } return MessageDeclaration( name = fqName(), + presenceMaskSize = currPresenceIdx, actualFields = regularFields, // get all oneof declarations that are not created from an optional in proto3 https://github.com/googleapis/api-linter/issues/1323 oneOfDeclarations = oneofs.filter { it.fields[0].realContainingOneof != null }.map { it.toModel() }, @@ -125,11 +133,12 @@ private fun Descriptors.Descriptor.toModel(): MessageDeclaration = cached { ) } -private fun Descriptors.FieldDescriptor.toModel(): FieldDeclaration = cached { +private fun Descriptors.FieldDescriptor.toModel(presenceIdx: Int? = null): FieldDeclaration = cached { toProto().hasProto3Optional() return FieldDeclaration( name = fqName().simpleName, type = modelType(), + presenceIdx = presenceIdx, doc = null, dec = this, ) diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index 57267380a..8eff14079 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -23,6 +23,7 @@ data class FileDeclaration( data class MessageDeclaration( val name: FqName, + val presenceMaskSize: Int, val actualFields: List, // excludes oneOf fields, but includes oneOf itself val oneOfDeclarations: List, val enumDeclarations: List, @@ -62,12 +63,15 @@ data class FieldDeclaration( val name: String, val type: FieldType, val doc: String?, - val dec: Descriptors.FieldDescriptor + val dec: Descriptors.FieldDescriptor, + // defines the index in the presenceMask of the Message. + // this cannot be the number, as only fields with hasPresence == true are part of the presenceMask + val presenceIdx: Int? = null ) { val packedFixedSize = type.wireType == WireType.FIXED64 || type.wireType == WireType.FIXED32 // aligns with edition settings and backward compatibility with proto2 and proto3 - val nullable: Boolean = dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() + val nullable: Boolean = dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() && !dec.isRepeated val number: Int = dec.number }