Skip to content

Commit 5f48bfb

Browse files
committed
grpc-pb: Support message in oneof
Signed-off-by: Johannes Zottele <[email protected]>
1 parent b733ebc commit 5f48bfb

File tree

4 files changed

+53
-15
lines changed

4 files changed

+53
-15
lines changed

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ import test.recursive.RecursiveInternal
2222
import test.recursive.RecursiveReq
2323
import test.recursive.invoke
2424
import test.submsg.*
25-
import kotlin.test.Test
26-
import kotlin.test.assertEquals
27-
import kotlin.test.assertFailsWith
28-
import kotlin.test.assertNull
25+
import kotlin.test.*
2926

3027
class ProtosTest {
3128

@@ -155,6 +152,15 @@ class ProtosTest {
155152
}
156153
val decoded2 = encodeDecode(msg2, OneOfMsgInternal.CODEC)
157154
assertEquals(OneOfMsg.Field.Fixed(21u), decoded2.field)
155+
156+
val msg3 = OneOfMsg {
157+
field = OneOfMsg.Field.Other(Other { arg2 = "test" })
158+
}
159+
val decoded3 = encodeDecode(msg3, OneOfMsgInternal.CODEC)
160+
assertIs<OneOfMsg.Field.Other>(decoded3.field)
161+
assertNull((decoded3.field as OneOfMsg.Field.Other).value.arg1)
162+
assertEquals("test", (decoded3.field as OneOfMsg.Field.Other).value.arg2)
163+
assertNull((decoded3.field as OneOfMsg.Field.Other).value.arg3)
158164
}
159165

160166
@Test
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import "sub_message.proto";
2+
13
message OneOfMsg {
24
oneof field {
35
int32 sint = 2;
46
fixed64 fixed = 3;
7+
test.submsg.Other other = 4;
58
}
69
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -282,44 +282,68 @@ class ModelToKotlinCommonGenerator(
282282
private fun CodeGenerator.readMatchCase(
283283
field: FieldDeclaration,
284284
lvalue: String = "msg.${field.name}",
285-
wrapperCtor: (String) -> String = { it }
285+
wrapperCtor: (String) -> String = { it },
286+
beforeValueDecoding: CodeGenerator.() -> Unit = {},
286287
) {
287288
when (val fieldType = field.type) {
288289
is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") {
290+
beforeValueDecoding()
289291
generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor)
290292
}
291293

292294
is FieldType.List -> if (field.dec.isPacked) {
293295
whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") {
296+
beforeValueDecoding()
294297
generateDecodeFieldValue(fieldType, lvalue, isPacked = true, wrapperCtor = wrapperCtor)
295298
}
296299
} else {
297300
whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${fieldType.value.wireType.name}") {
301+
beforeValueDecoding()
298302
generateDecodeFieldValue(fieldType, lvalue, isPacked = false, wrapperCtor = wrapperCtor)
299303
}
300304
}
301305

302306
is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") {
307+
beforeValueDecoding()
303308
generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor)
304309
}
305310

306311
is FieldType.OneOf -> {
307312
fieldType.dec.variants.forEach { variant ->
308313
val variantName = "${fieldType.dec.name.safeFullName()}.${variant.name}"
309-
readMatchCase(
310-
field = variant,
311-
lvalue = lvalue,
312-
wrapperCtor = { "$variantName($it)" }
313-
)
314+
if (variant.type is FieldType.Message) {
315+
// in case of a message, we must construct an empty message before reading the message
316+
readMatchCase(
317+
field = variant,
318+
lvalue = "field.value",
319+
beforeValueDecoding = {
320+
beforeValueDecoding()
321+
scope("val field = ($lvalue as? $variantName) ?: $variantName(${variant.type.internalCtor()}).also") {
322+
// write the constructed oneof variant to the field
323+
code("$lvalue = it")
324+
}
325+
})
326+
} else {
327+
readMatchCase(
328+
field = variant,
329+
lvalue = lvalue,
330+
wrapperCtor = { "$variantName($it)" },
331+
beforeValueDecoding = beforeValueDecoding
332+
)
333+
}
314334
}
315335
}
316336

317337
is FieldType.Message -> {
318338
whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") {
319-
// check if the the current sub message object
320-
ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = {
321-
code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()")
322-
})
339+
if (field.presenceIdx != null) {
340+
// check if the current sub message object was already set, if not, set a new one
341+
// to set the field's presence tracker to true
342+
ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = {
343+
code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()")
344+
})
345+
}
346+
beforeValueDecoding()
323347
generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor)
324348
}
325349
}
@@ -820,6 +844,9 @@ class ModelToKotlinCommonGenerator(
820844
}
821845
}
822846

847+
private fun FieldType.Message.internalCtor() =
848+
dec.value.internalClassFullName() + "()"
849+
823850
private fun MessageDeclaration.internalClassFullName(): String {
824851
return name.safeFullName(MSG_INTERNAL_SUFFIX)
825852
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ data class FieldDeclaration(
7575
) {
7676
val packedFixedSize = type.wireType == WireType.FIXED64 || type.wireType == WireType.FIXED32
7777

78+
val isPartOfOneof: Boolean = dec.realContainingOneof != null
79+
7880
// aligns with edition settings and backward compatibility with proto2 and proto3
7981
val nullable: Boolean = (dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue()
8082
&& !dec.isRepeated // repeated fields cannot be nullable (just empty)
81-
&& dec.realContainingOneof == null // upper conditions would match oneof inner fields
83+
&& !isPartOfOneof // upper conditions would match oneof inner fields
8284
&& type !is FieldType.Message // messages must not be null (to conform protobuf standards)
8385
)
8486
|| type is FieldType.OneOf // all OneOf fields are nullable

0 commit comments

Comments
 (0)