Skip to content

Commit 9c58222

Browse files
committed
grpc-pb: Support recursive messages
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 4d2fce0 commit 9c58222

File tree

8 files changed

+188
-31
lines changed

8 files changed

+188
-31
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,39 @@ package kotlinx.rpc.grpc.pb
66

77
import kotlinx.rpc.grpc.utils.BitSet
88
import kotlinx.rpc.internal.utils.InternalRpcApi
9+
import kotlin.properties.ReadWriteProperty
10+
import kotlin.reflect.KProperty
911

1012
@InternalRpcApi
1113
public abstract class InternalMessage(fieldsWithPresence: Int) {
1214
public val presenceMask: BitSet = BitSet(fieldsWithPresence)
1315
public abstract val _size: Int
16+
}
17+
18+
public class MsgFieldDelegate<T : Any>(
19+
private val presenceIdx: Int? = null,
20+
private val defaultProvider: (() -> T)? = null
21+
) : ReadWriteProperty<InternalMessage, T> {
22+
23+
private var valueSet = false
24+
private var _value: T? = null
25+
26+
override operator fun getValue(thisRef: InternalMessage, property: KProperty<*>): T {
27+
if (!valueSet) {
28+
if (defaultProvider != null) {
29+
_value = defaultProvider.invoke()
30+
valueSet = true
31+
} else {
32+
error("Property ${property.name} not initialized")
33+
}
34+
}
35+
@Suppress("UNCHECKED_CAST")
36+
return _value as T
37+
}
38+
39+
override operator fun setValue(thisRef: InternalMessage, property: KProperty<*>, new: T) {
40+
presenceIdx?.let { thisRef.presenceMask[it] = true }
41+
_value = new
42+
valueSet = true
43+
}
1444
}

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,16 @@ import OuterInternal
1111
import invoke
1212
import kotlinx.io.Buffer
1313
import kotlinx.rpc.grpc.internal.MessageCodec
14-
import kotlinx.rpc.grpc.test.Enum
15-
import kotlinx.rpc.grpc.test.UsingEnum
16-
import kotlinx.rpc.grpc.test.UsingEnumInternal
14+
import kotlinx.rpc.grpc.test.*
1715
import kotlinx.rpc.grpc.test.common.*
18-
import kotlinx.rpc.grpc.test.invoke
1916
import kotlin.test.Test
2017
import kotlin.test.assertEquals
2118
import kotlin.test.assertFailsWith
2219
import kotlin.test.assertNull
2320

2421
class ProtosTest {
2522

26-
private fun <M> decodeEncode(
23+
private fun <M> encodeDecode(
2724
msg: M,
2825
codec: MessageCodec<M>
2926
): M {
@@ -53,7 +50,7 @@ class ProtosTest {
5350

5451
val msgObj = msg
5552

56-
val decoded = decodeEncode(msgObj, AllPrimitivesInternal.CODEC)
53+
val decoded = encodeDecode(msgObj, AllPrimitivesInternal.CODEC)
5754

5855
assertEquals(msg.double, decoded.double)
5956
}
@@ -68,7 +65,7 @@ class ProtosTest {
6865
listString = listOf("a", "b", "c")
6966
}
7067

71-
val decoded = decodeEncode(msg, RepeatedInternal.CODEC)
68+
val decoded = encodeDecode(msg, RepeatedInternal.CODEC)
7269

7370
assertEquals(msg.listInt32, decoded.listInt32)
7471
assertEquals(msg.listFixed32, decoded.listFixed32)
@@ -112,7 +109,7 @@ class ProtosTest {
112109
enum = Enum.ONE_SECOND
113110
}
114111

115-
val decodedMsg = decodeEncode(msg, UsingEnumInternal.CODEC)
112+
val decodedMsg = encodeDecode(msg, UsingEnumInternal.CODEC)
116113
assertEquals(Enum.ONE, decodedMsg.enum)
117114
assertEquals(Enum.ONE_SECOND, decodedMsg.enum)
118115
}
@@ -135,13 +132,13 @@ class ProtosTest {
135132
val msg1 = OneOfMsg {
136133
field = OneOfMsg.Field.Sint(23)
137134
}
138-
val decoded1 = decodeEncode(msg1, OneOfMsgInternal.CODEC)
135+
val decoded1 = encodeDecode(msg1, OneOfMsgInternal.CODEC)
139136
assertEquals(OneOfMsg.Field.Sint(23), decoded1.field)
140137

141138
val msg2 = OneOfMsg {
142139
field = OneOfMsg.Field.Fixed(21u)
143140
}
144-
val decoded2 = decodeEncode(msg2, OneOfMsgInternal.CODEC)
141+
val decoded2 = encodeDecode(msg2, OneOfMsgInternal.CODEC)
145142
assertEquals(OneOfMsg.Field.Fixed(21u), decoded2.field)
146143
}
147144

@@ -175,8 +172,40 @@ class ProtosTest {
175172
field = 12345678
176173
}
177174
}
178-
val decoded = decodeEncode(msg, OuterInternal.CODEC)
175+
val decoded = encodeDecode(msg, OuterInternal.CODEC)
179176
assertEquals(msg.inner.field, decoded.inner.field)
180177
}
181178

179+
@Test
180+
fun testRecursiveReqNotSet() {
181+
assertFailsWith<IllegalStateException>("RecursiveReq is missing required field: rec") {
182+
val msg = RecursiveReq {
183+
rec = RecursiveReq {
184+
rec = RecursiveReq {
185+
186+
}
187+
num = 3
188+
}
189+
}
190+
}
191+
}
192+
193+
@Test
194+
fun testRecursive() {
195+
val msg = Recursive {
196+
rec = Recursive {
197+
rec = Recursive {}
198+
num = 3
199+
}
200+
}
201+
202+
assertEquals(null, msg.rec.rec.rec.rec.num)
203+
assertEquals(3, msg.rec.num)
204+
205+
val decoded = encodeDecode(msg, RecursiveInternal.CODEC)
206+
207+
assertEquals(3, decoded.rec.num)
208+
assertEquals(null, decoded.rec.rec.rec.rec.num)
209+
}
210+
182211
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Protocol Buffers - Google's data interchange format
2+
// Copyright 2023 Google LLC. All rights reserved.
3+
//
4+
// Use of this source code is governed by a BSD-style
5+
// license that can be found in the LICENSE file or at
6+
// https://developers.google.com/open-source/licenses/bsd
7+
8+
syntax = "proto2";
9+
10+
package kotlinx.rpc.grpc.test;
11+
12+
message Outer {
13+
message Inner {
14+
message InnerSubMsg {
15+
optional bool flag = 1;
16+
}
17+
18+
enum InnerEnum {
19+
INNER_ENUM_UNSPECIFIED = 0;
20+
INNER_ENUM_FOO = 1;
21+
}
22+
23+
optional double double = 1;
24+
optional float float = 2;
25+
optional int32 int32 = 3;
26+
optional int64 int64 = 4;
27+
optional uint32 uint32 = 5;
28+
optional uint64 uint64 = 6;
29+
optional sint32 sint32 = 7;
30+
optional sint64 sint64 = 8;
31+
optional fixed32 fixed32 = 9;
32+
optional fixed64 fixed64 = 10;
33+
optional sfixed32 sfixed32 = 11;
34+
optional sfixed64 sfixed64 = 12;
35+
optional bool bool = 13;
36+
optional string string = 14;
37+
optional bytes bytes = 15;
38+
optional InnerSubMsg inner_submsg = 16;
39+
optional InnerEnum inner_enum = 17;
40+
repeated int32 repeated_int32 = 18 [packed = true];
41+
repeated InnerSubMsg repeated_inner_submsg = 19;
42+
// map<string, string> string_map = 20;
43+
44+
message SuperInner {
45+
message DuperInner {
46+
message EvenMoreInner {
47+
message CantBelieveItsSoInner {
48+
optional int32 num = 99;
49+
}
50+
51+
enum JustWayTooInner {
52+
JUST_WAY_TOO_INNER_UNSPECIFIED = 0;
53+
}
54+
}
55+
}
56+
}
57+
}
58+
// optional Inner inner = 1;
59+
// optional .kotlinx.rpc.grpc.test.Outer.Inner.SuperInner.DuperInner.EvenMoreInner
60+
// .CantBelieveItsSoInner deep = 2;
61+
//
62+
// optional .kotlinx.rpc.grpc.test.Outer.Inner.SuperInner.DuperInner.EvenMoreInner.JustWayTooInner
63+
// deep_enum = 4;
64+
65+
optional NotInside notinside = 3;
66+
}
67+
68+
message NotInside {
69+
optional int32 num = 1;
70+
}
71+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto2";
2+
3+
package kotlinx.rpc.grpc.test;
4+
5+
message RecursiveReq {
6+
required RecursiveReq rec = 1;
7+
optional int32 num = 2;
8+
}
9+
10+
message Recursive {
11+
optional Recursive rec = 1;
12+
optional int32 num = 2;
13+
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -181,20 +181,16 @@ class ModelToKotlinCommonGenerator(
181181
"= null"
182182
}
183183

184-
field.type is FieldType.Message -> "= ${field.type.dec.internalClassFullName()}()"
184+
field.type is FieldType.Message ->
185+
"by MsgFieldDelegate(PresenceIndices.${field.name}) { ${field.type.dec.value.internalClassFullName()}() }"
185186

186187
else -> {
187-
"= ${field.type.defaultValue}"
188+
val fieldPresence = if (field.presenceIdx != null) "PresenceIndices.${field.name}" else ""
189+
"by MsgFieldDelegate($fieldPresence) { ${field.type.defaultValue} }"
188190
}
189191
}
190192

191193
code("override var $fieldDeclaration $value")
192-
if (field.presenceIdx != null) {
193-
scope("set(value) ") {
194-
code("presenceMask[PresenceIndices.${field.name}] = true")
195-
code("field = value")
196-
}
197-
}
198194
newLine()
199195
}
200196

@@ -321,11 +317,11 @@ class ModelToKotlinCommonGenerator(
321317
}
322318

323319
is FieldType.Message -> {
324-
val internalClassName = fieldType.dec.internalClassFullName()
320+
val internalClassName = fieldType.dec.value.internalClassFullName()
325321
whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") {
326322
// check if the the current sub message object
327323
ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = {
328-
code("$lvalue = ${fieldType.dec.internalClassFullName()}()")
324+
code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()")
329325
})
330326
code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)")
331327
}
@@ -357,7 +353,9 @@ class ModelToKotlinCommonGenerator(
357353
writeFieldValue(field, field.name)
358354
})
359355
} else {
360-
writeFieldValue(field, field.name)
356+
ifBranch(condition = "presenceMask[${field.presenceIdx}]", ifBlock = {
357+
writeFieldValue(field, field.name)
358+
})
361359
}
362360
}
363361
}
@@ -380,7 +378,16 @@ class ModelToKotlinCommonGenerator(
380378
})"
381379
)
382380

383-
else -> code("$valueVar.forEach { encoder.write${encFunc!!}($number, it) }")
381+
fieldType.value is FieldType.Message -> scope("$valueVar.forEach") {
382+
code("encoder.writeMessage(fieldNr = ${field.number}, value = it.asInternal()) { encodeWith(it) }")
383+
}
384+
385+
else -> {
386+
require(encFunc != null) { "No encode function for list type: $fieldType" }
387+
scope("$valueVar.forEach") {
388+
code("encoder.write${encFunc}($number, it)")
389+
}
390+
}
384391
}
385392
}
386393

@@ -471,7 +478,9 @@ class ModelToKotlinCommonGenerator(
471478
generateFieldComputeSizeCall(field, fieldName)
472479
}
473480
} else {
474-
generateFieldComputeSizeCall(field, fieldName)
481+
scope("if (presenceMask[${field.presenceIdx}])") {
482+
generateFieldComputeSizeCall(field, fieldName)
483+
}
475484
}
476485
}
477486
code("return result")
@@ -614,7 +623,7 @@ class ModelToKotlinCommonGenerator(
614623
private fun FieldDeclaration.typeFqName(): String {
615624
return when (type) {
616625
is FieldType.Message -> {
617-
type.dec.name.safeFullName()
626+
type.dec.value.name.safeFullName()
618627
}
619628

620629
is FieldType.Enum -> type.dec.name.safeFullName()
@@ -627,7 +636,7 @@ class ModelToKotlinCommonGenerator(
627636

628637
is FieldType.List -> {
629638
val fqValue = when (val value = type.value) {
630-
is FieldType.Message -> value.dec.name
639+
is FieldType.Message -> value.dec.value.name
631640
is FieldType.IntegralType -> value.fqName
632641
else -> error("Unsupported type: $value")
633642
}
@@ -639,13 +648,13 @@ class ModelToKotlinCommonGenerator(
639648
val entry by type.entry
640649

641650
val fqKey = when (val key = entry.key) {
642-
is FieldType.Message -> key.dec.name
651+
is FieldType.Message -> key.dec.value.name
643652
is FieldType.IntegralType -> key.fqName
644653
else -> error("Unsupported type: $key")
645654
}
646655

647656
val fqValue = when (val value = entry.value) {
648-
is FieldType.Message -> value.dec.name
657+
is FieldType.Message -> value.dec.value.name
649658
is FieldType.IntegralType -> value.fqName
650659
else -> error("Unsupported type: $value")
651660
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import com.google.protobuf.Descriptors
99
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest
1010
import kotlinx.rpc.protobuf.model.*
1111

12+
private val nameCache = mutableMapOf<Descriptors.GenericDescriptor, FqName>()
1213
private val modelCache = mutableMapOf<Descriptors.GenericDescriptor, Any>()
1314

1415
/**
@@ -65,9 +66,10 @@ private fun DescriptorProtos.FileDescriptorProto.toDescriptor(
6566
* @return The fully qualified name represented as an instance of FqName, specific to the descriptor's context.
6667
*/
6768
private fun Descriptors.GenericDescriptor.fqName(): FqName {
69+
if (nameCache.containsKey(this)) return nameCache[this]!!
6870
val nameCapital = name.simpleProtoNameToKotlin(firstLetterUpper = true)
6971
val nameLower = name.simpleProtoNameToKotlin()
70-
return when (this) {
72+
val fqName = when (this) {
7173
is Descriptors.FileDescriptor -> FqName.Package.fromString(`package`)
7274
is Descriptors.Descriptor -> FqName.Declaration(nameCapital, containingType?.fqName() ?: file.fqName())
7375
is Descriptors.FieldDescriptor -> {
@@ -82,6 +84,8 @@ private fun Descriptors.GenericDescriptor.fqName(): FqName {
8284
is Descriptors.MethodDescriptor -> FqName.Declaration(nameLower, service?.fqName() ?: file.fqName())
8385
else -> error("Unknown generic descriptor: $this")
8486
}
87+
nameCache[this] = fqName
88+
return fqName
8589
}
8690

8791
/**
@@ -246,7 +250,7 @@ private fun Descriptors.FieldDescriptor.modelType(): FieldType {
246250
Descriptors.FieldDescriptor.Type.SINT32 -> FieldType.IntegralType.SINT32
247251
Descriptors.FieldDescriptor.Type.SINT64 -> FieldType.IntegralType.SINT64
248252
Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Enum(enumType.toModel())
249-
Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Message(messageType!!.toModel())
253+
Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Message(lazy { messageType!!.toModel() })
250254
Descriptors.FieldDescriptor.Type.GROUP -> error("GROUP type is unsupported")
251255
}
252256

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ sealed interface FieldType {
3737
override val wireType: WireType = WireType.VARINT
3838
}
3939

40-
data class Message(val dec: MessageDeclaration) : FieldType {
40+
data class Message(val dec: Lazy<MessageDeclaration>) : FieldType {
4141
override val defaultValue: String? = null
4242
override val wireType: WireType = WireType.LENGTH_DELIMITED
4343
}

0 commit comments

Comments
 (0)