Skip to content

Commit 9e32cbb

Browse files
committed
grpc-pb: Support repeated messages
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 585a43e commit 9e32cbb

File tree

6 files changed

+143
-58
lines changed

6 files changed

+143
-58
lines changed

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import test.recursive.Recursive
2121
import test.recursive.RecursiveInternal
2222
import test.recursive.RecursiveReq
2323
import test.recursive.invoke
24+
import test.submsg.*
2425
import kotlin.test.Test
2526
import kotlin.test.assertEquals
2627
import kotlin.test.assertFailsWith
@@ -65,19 +66,25 @@ class ProtosTest {
6566

6667
@Test
6768
fun testRepeatedProto() {
69+
val elem = { i: Int -> Repeated.Other { a = i } }
6870
val msg = Repeated {
6971
listFixed32 = listOf(1, 5, 3).map { it.toUInt() }
7072
listFixed32Packed = listOf(1, 2, 3).map { it.toUInt() }
7173
listInt32 = listOf(4, 7, 6)
7274
listInt32Packed = listOf(4, 5, 6)
7375
listString = listOf("a", "b", "c")
76+
listMessage = listOf(elem(1), elem(2), elem(3))
7477
}
7578

7679
val decoded = encodeDecode(msg, RepeatedInternal.CODEC)
7780

7881
assertEquals(msg.listInt32, decoded.listInt32)
7982
assertEquals(msg.listFixed32, decoded.listFixed32)
8083
assertEquals(msg.listString, decoded.listString)
84+
assertEquals(msg.listMessage.size, decoded.listMessage.size)
85+
for (i in msg.listMessage.indices) {
86+
assertEquals(msg.listMessage[i].a, decoded.listMessage[i].a)
87+
}
8188
}
8289

8390
@Test
@@ -254,7 +261,32 @@ class ProtosTest {
254261
NestedOuterInternal.InnerInternal.SuperInnerInternal.DuperInnerInternal.EvenMoreInnerInternal.CantBelieveItsSoInnerInternal.CODEC
255262
)
256263
assertEquals(123456789, decodedInner.num)
264+
}
265+
266+
@Test
267+
fun testMessageMerging() {
268+
269+
val buffer = Buffer()
270+
val encoder = WireEncoder(buffer)
271+
272+
val firstPart = Other {
273+
arg1 = "first"
274+
arg2 = "second"
275+
}
276+
val secondPart = Other {
277+
arg2 = "third"
278+
arg3 = "fourth"
279+
}
280+
281+
encoder.writeMessage(1, firstPart as OtherInternal) { encodeWith(encoder) }
282+
encoder.flush()
283+
encoder.writeMessage(1, secondPart as OtherInternal) { encodeWith(encoder) }
284+
encoder.flush()
257285

286+
val decoded = ReferenceInternal.CODEC.decode(buffer)
287+
assertEquals("first", decoded.other.arg1)
288+
assertEquals("third", decoded.other.arg2)
289+
assertEquals("fourth", decoded.other.arg3)
258290
}
259291

260292
}

grpc/grpc-core/src/commonTest/proto/exclude/reference.proto

Lines changed: 0 additions & 9 deletions
This file was deleted.

grpc/grpc-core/src/commonTest/proto/nested.proto

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,34 @@ package test.nested;
1111

1212
message NestedOuter {
1313
message Inner {
14-
// message InnerSubMsg {
15-
// optional bool flag = 1;
16-
// }
17-
//
18-
// enum InnerEnum {
19-
// INNER_ENUM_UNSPECIFIED = 0;
20-
// INNER_ENUM_FOO = 1;
21-
// }
14+
message InnerSubMsg {
15+
optional bool flag = 1;
16+
}
17+
18+
enum InnerEnum {
19+
INNER_ENUM_UNSPECIFIED = 0;
20+
INNER_ENUM_FOO = 1;
21+
}
2222

23-
// optional double double = 1;
24-
// optional float float = 2;
25-
// optional int32 int32 = 3;
26-
// optional int64 int64 = 4;
27-
// optional uint32 uint32 = 5;
28-
// optional uint64 uint64 = 6;
29-
// optional sint32 sint32 = 7;
30-
// optional sint64 sint64 = 8;
31-
// optional fixed32 fixed32 = 9;
32-
// optional fixed64 fixed64 = 10;
33-
// optional sfixed32 sfixed32 = 11;
34-
// optional sfixed64 sfixed64 = 12;
35-
// optional bool bool = 13;
36-
// optional string string = 14;
37-
// optional bytes bytes = 15;
38-
// optional InnerSubMsg inner_submsg = 16;
39-
// optional InnerEnum inner_enum = 17;
40-
// repeated int32 repeated_int32 = 18 [packed = true];
41-
// repeated InnerSubMsg repeated_inner_submsg = 19;
23+
optional double double = 1;
24+
optional float float = 2;
25+
optional int32 int32 = 3;
26+
optional int64 int64 = 4;
27+
optional uint32 uint32 = 5;
28+
optional uint64 uint64 = 6;
29+
optional sint32 sint32 = 7;
30+
optional sint64 sint64 = 8;
31+
optional fixed32 fixed32 = 9;
32+
optional fixed64 fixed64 = 10;
33+
optional sfixed32 sfixed32 = 11;
34+
optional sfixed64 sfixed64 = 12;
35+
optional bool bool = 13;
36+
optional string string = 14;
37+
optional bytes bytes = 15;
38+
optional InnerSubMsg inner_submsg = 16;
39+
optional InnerEnum inner_enum = 17;
40+
repeated int32 repeated_int32 = 18 [packed = true];
41+
repeated InnerSubMsg repeated_inner_submsg = 19;
4242
// map<string, string> string_map = 20;
4343

4444
message SuperInner {

grpc/grpc-core/src/commonTest/proto/repeated.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,9 @@ message Repeated {
88
repeated int32 listInt32 = 3 [packed = false];
99
repeated int32 listInt32Packed = 4 [packed = true];
1010
repeated string listString = 5;
11+
repeated Other listMessage = 6;
12+
13+
message Other {
14+
int32 a = 1;
15+
}
1116
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
syntax = "proto3";
2+
3+
package test.submsg;
4+
5+
message Other {
6+
optional string arg1 = 1;
7+
optional string arg2 = 2;
8+
optional string arg3 = 3;
9+
}
10+
11+
message Reference {
12+
Other other = 1;
13+
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -285,24 +285,21 @@ class ModelToKotlinCommonGenerator(
285285
) {
286286
when (val fieldType = field.type) {
287287
is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") {
288-
val raw = "decoder.read${field.type.decodeEncodeFuncName()}()"
289-
code("$lvalue = ${wrapperCtor(raw)}")
288+
generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor)
290289
}
291290

292291
is FieldType.List -> if (field.dec.isPacked) {
293292
whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") {
294-
code("$lvalue = decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()")
293+
generateDecodeFieldValue(fieldType, lvalue, isPacked = true, wrapperCtor = wrapperCtor)
295294
}
296295
} else {
297296
whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${fieldType.value.wireType.name}") {
298-
code("(msg.${field.name} as ArrayList).add(decoder.read${fieldType.value.decodeEncodeFuncName()}())")
297+
generateDecodeFieldValue(fieldType, lvalue, isPacked = false, wrapperCtor = wrapperCtor)
299298
}
300299
}
301300

302301
is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") {
303-
val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber"
304-
val raw = "$fromNum(decoder.read${field.type.decodeEncodeFuncName()}())"
305-
code("$lvalue = ${wrapperCtor(raw)}")
302+
generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor)
306303
}
307304

308305
is FieldType.OneOf -> {
@@ -317,20 +314,71 @@ class ModelToKotlinCommonGenerator(
317314
}
318315

319316
is FieldType.Message -> {
320-
val internalClassName = fieldType.dec.value.internalClassFullName()
321317
whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") {
322318
// check if the the current sub message object
323319
ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = {
324320
code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()")
325321
})
326-
code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)")
322+
generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor)
327323
}
328324
}
329325

330326
is FieldType.Map -> TODO()
331327
}
332328
}
333329

330+
private fun CodeGenerator.generateDecodeFieldValue(
331+
fieldType: FieldType,
332+
lvalue: String,
333+
isPacked: Boolean = false,
334+
wrapperCtor: (String) -> String = { it }
335+
) {
336+
when (fieldType) {
337+
is FieldType.IntegralType -> {
338+
val raw = "decoder.read${fieldType.decodeEncodeFuncName()}()"
339+
code("$lvalue = ${wrapperCtor(raw)}")
340+
}
341+
342+
is FieldType.List -> if (isPacked) {
343+
code("$lvalue = decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()")
344+
} else {
345+
when (val elemType = fieldType.value) {
346+
is FieldType.Message -> {
347+
code("val elem = ${elemType.dec.value.internalClassFullName()}()")
348+
generateDecodeFieldValue(fieldType.value, "elem", wrapperCtor = wrapperCtor)
349+
}
350+
351+
else -> generateDecodeFieldValue(fieldType.value, "val elem", wrapperCtor = wrapperCtor)
352+
}
353+
code("($lvalue as ArrayList).add(elem)")
354+
}
355+
356+
is FieldType.Enum -> {
357+
val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber"
358+
val raw = "$fromNum(decoder.read${fieldType.decodeEncodeFuncName()}())"
359+
code("$lvalue = ${wrapperCtor(raw)}")
360+
}
361+
362+
is FieldType.OneOf -> {
363+
fieldType.dec.variants.forEach { variant ->
364+
val variantName = "${fieldType.dec.name.safeFullName()}.${variant.name}"
365+
readMatchCase(
366+
field = variant,
367+
lvalue = lvalue,
368+
wrapperCtor = { "$variantName($it)" }
369+
)
370+
}
371+
}
372+
373+
is FieldType.Message -> {
374+
val internalClassName = fieldType.dec.value.internalClassFullName()
375+
code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)")
376+
}
377+
378+
is FieldType.Map -> TODO()
379+
}
380+
}
381+
334382
private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function(
335383
name = "encodeWith",
336384
modifiers = "internal",
@@ -374,7 +422,7 @@ class ModelToKotlinCommonGenerator(
374422
field.dec.isPacked && !field.packedFixedSize ->
375423
code(
376424
"encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar, fieldSize = ${
377-
field.valueSizeCall(valueVar)
425+
field.type.valueSizeCall(valueVar, number, true)
378426
})"
379427
)
380428

@@ -500,7 +548,7 @@ class ModelToKotlinCommonGenerator(
500548

501549

502550
private fun CodeGenerator.generateFieldComputeSizeCall(field: FieldDeclaration, variable: String) {
503-
val valueSize by lazy { field.valueSizeCall(variable) }
551+
val valueSize by lazy { field.type.valueSizeCall(variable, field.number, field.dec.isPacked) }
504552
val tagSize = tagSizeCall(field.number, field.type.wireType)
505553

506554
when (field.type) {
@@ -541,24 +589,20 @@ class ModelToKotlinCommonGenerator(
541589
}
542590
}
543591

544-
private fun FieldDeclaration.valueSizeCall(variable: String): String {
545-
val sizeFunName = type.decodeEncodeFuncName()?.decapitalize()
592+
private fun FieldType.valueSizeCall(variable: String, number: Int, isPacked: Boolean = false): String {
593+
val sizeFunName = decodeEncodeFuncName()?.decapitalize()
546594
val sizeFunc = "$PB_PKG.WireSize.$sizeFunName($variable)"
547595

548-
return when (type) {
596+
return when (this) {
549597
is FieldType.IntegralType -> sizeFunc
550598

551599
is FieldType.List -> when {
552-
dec.isPacked -> sizeFunc
600+
isPacked -> sizeFunc
553601
else -> {
554602
// calculate the size of the values within the list.
555-
val valueTypeSizeFunc = type.value.decodeEncodeFuncName()?.decapitalize()
556-
"$variable.sumOf { $PB_PKG.WireSize.$valueTypeSizeFunc(it) + ${
557-
tagSizeCall(
558-
number,
559-
type.value.wireType
560-
)
561-
} }"
603+
val valueSize = value.valueSizeCall("it", number)
604+
val tagSize = tagSizeCall(number, value.wireType)
605+
"$variable.sumOf { $valueSize + $tagSize }"
562606
}
563607
}
564608

0 commit comments

Comments
 (0)