Skip to content

Commit c69309d

Browse files
committed
Fix marshalling of empty oneOf messages
Fixes #61
1 parent e7d7219 commit c69309d

File tree

13 files changed

+211
-9
lines changed

13 files changed

+211
-9
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package conformance
2+
3+
import (
4+
"github.com/stretchr/testify/require"
5+
"google.golang.org/protobuf/encoding/protojson"
6+
"google.golang.org/protobuf/encoding/prototext"
7+
"google.golang.org/protobuf/proto"
8+
"strings"
9+
"testing"
10+
)
11+
12+
func roundTripUpstream(b []byte) ([]byte, error) {
13+
msg := &TestAllTypesProto3{}
14+
if err := proto.Unmarshal(b, msg); err != nil {
15+
return nil, err
16+
}
17+
res, err := proto.Marshal(msg)
18+
if err != nil {
19+
return nil, err
20+
}
21+
return res, nil
22+
}
23+
24+
func roundTripVtprotobuf(b []byte) ([]byte, error) {
25+
msg := &TestAllTypesProto3{}
26+
if err := msg.UnmarshalVT(b); err != nil {
27+
return nil, err
28+
}
29+
res, err := msg.MarshalVT()
30+
if err != nil {
31+
return nil, err
32+
}
33+
return res, nil
34+
}
35+
36+
func FuzzProto(f *testing.F) {
37+
f.Fuzz(func(t *testing.T, b []byte) {
38+
u, uerr := roundTripUpstream(b)
39+
v, verr := roundTripVtprotobuf(b)
40+
if verr != nil && strings.Contains(verr.Error(), "wrong wireType") {
41+
t.Skip()
42+
}
43+
if uerr != nil && strings.Contains(uerr.Error(), "cannot parse invalid wire-format data") {
44+
t.Skip()
45+
}
46+
if (uerr != nil) != (verr != nil) {
47+
t.Fatalf("upstream err: %v (%v), vtprotobuf err: %v (%v)", uerr, u, verr, v)
48+
}
49+
vt := &TestAllTypesProto3{}
50+
_ = vt.UnmarshalVT(b)
51+
t.Logf("vtprotobuf: %v, %v", protojson.Format(vt), prototext.Format(vt))
52+
us := &TestAllTypesProto3{}
53+
_ = proto.Unmarshal(b, us)
54+
55+
t.Logf("upstream: %v, %v", protojson.Format(us), prototext.Format(us))
56+
require.Equal(t, u, v)
57+
})
58+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package conformance
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
"google.golang.org/protobuf/proto"
8+
)
9+
10+
func TestEmptyOneoff(t *testing.T) {
11+
// Regression test for https://github.com/planetscale/vtprotobuf/issues/61
12+
msg := &TestAllTypesProto3{OneofField: &TestAllTypesProto3_OneofNestedMessage{}}
13+
upstream, _ := proto.Marshal(msg)
14+
vt, _ := msg.MarshalVTStrict()
15+
require.Equal(t, upstream, vt)
16+
}

conformance/internal/conformance/test_messages_proto2_vtproto.pb.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

conformance/internal/conformance/test_messages_proto3_vtproto.pb.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
go test fuzz v1
2+
[]byte("\xe30$")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
go test fuzz v1
2+
[]byte("8\xb30")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
go test fuzz v1
2+
[]byte("\x80\xff\x000")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
go test fuzz v1
2+
[]byte("X\xb30")

features/marshal/marshalto.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"strings"
1313

1414
"github.com/planetscale/vtprotobuf/generator"
15-
1615
"google.golang.org/protobuf/compiler/protogen"
1716
"google.golang.org/protobuf/encoding/protowire"
1817
"google.golang.org/protobuf/reflect/protoreflect"
@@ -520,7 +519,12 @@ func (p *marshal) field(oneof bool, numGen *counter, field *protogen.Field) {
520519
default:
521520
panic("not implemented")
522521
}
523-
if repeated || nullable {
522+
if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() {
523+
p.P("} else {")
524+
p.P("i = protohelpers.EncodeVarint(dAtA, i, 0)")
525+
p.encodeKey(fieldNumber, wireType)
526+
p.P("}")
527+
} else if repeated || nullable {
524528
p.P(`}`)
525529
}
526530
}
@@ -676,7 +680,7 @@ func (p *marshal) message(message *protogen.Message) {
676680
p.P(`}`)
677681
p.P()
678682

679-
//Generate MarshalToVT methods for oneof fields
683+
// Generate MarshalToVT methods for oneof fields
680684
for _, field := range message.Fields {
681685
if field.Oneof == nil || field.Oneof.Desc.IsSynthetic() {
682686
continue
@@ -709,7 +713,6 @@ func (p *marshal) marshalBackwardSize(varInt bool) {
709713
if varInt {
710714
p.encodeVarint(`size`)
711715
}
712-
713716
}
714717

715718
func (p *marshal) marshalBackward(varName string, varInt bool, message *protogen.Message) {

features/size/size.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ package size
88
import (
99
"strconv"
1010

11+
"github.com/planetscale/vtprotobuf/generator"
1112
"google.golang.org/protobuf/compiler/protogen"
1213
"google.golang.org/protobuf/encoding/protowire"
1314
"google.golang.org/protobuf/reflect/protoreflect"
14-
15-
"github.com/planetscale/vtprotobuf/generator"
1615
)
1716

1817
func init() {
@@ -266,7 +265,9 @@ func (p *size) field(oneof bool, field *protogen.Field, sizeName string) {
266265
default:
267266
panic("not implemented")
268267
}
269-
if repeated || nullable {
268+
if oneof && field.Desc.Kind() == protoreflect.MessageKind && !field.Desc.IsMap() && !field.Desc.IsList() {
269+
p.P("} else { n += 3 }")
270+
} else if repeated || nullable {
270271
p.P(`}`)
271272
}
272273
}
@@ -310,8 +311,6 @@ func (p *size) message(message *protogen.Message) {
310311
}
311312
p.P(`}`)
312313
} else {
313-
//if _, ok := oneofs[fieldname]; !ok {
314-
//oneofs[fieldname] = struct{}{}
315314
p.P(`if vtmsg, ok := m.`, fieldname, `.(interface{ SizeVT() int }); ok {`)
316315
p.P(`n+=vtmsg.`, sizeName, `()`)
317316
p.P(`}`)

0 commit comments

Comments
 (0)