|
1 | 1 | package grpc |
2 | 2 |
|
3 | | -import "fmt" |
| 3 | +import ( |
| 4 | + "google.golang.org/grpc/encoding" |
| 5 | + "google.golang.org/grpc/mem" |
| 6 | + |
| 7 | + // Guarantee that the built-in proto is called registered before this one |
| 8 | + // so that it can be replaced. |
| 9 | + _ "google.golang.org/grpc/encoding/proto" |
| 10 | +) |
4 | 11 |
|
5 | 12 | // Name is the name registered for the proto compressor. |
6 | 13 | const Name = "proto" |
7 | 14 |
|
8 | | -type Codec struct{} |
9 | | - |
10 | 15 | type vtprotoMessage interface { |
11 | | - MarshalVT() ([]byte, error) |
| 16 | + MarshalToVT(data []byte) (int, error) |
12 | 17 | UnmarshalVT([]byte) error |
| 18 | + SizeVT() int |
| 19 | +} |
| 20 | + |
| 21 | +type Codec struct { |
| 22 | + fallback encoding.CodecV2 |
13 | 23 | } |
14 | 24 |
|
15 | | -func (Codec) Marshal(v interface{}) ([]byte, error) { |
16 | | - vt, ok := v.(vtprotoMessage) |
17 | | - if !ok { |
18 | | - return nil, fmt.Errorf("failed to marshal, message is %T (missing vtprotobuf helpers)", v) |
| 25 | +func (Codec) Name() string { return Name } |
| 26 | + |
| 27 | +func (c *Codec) Marshal(v any) (data mem.BufferSlice, err error) { |
| 28 | + if m, ok := v.(vtprotoMessage); ok { |
| 29 | + size := m.SizeVT() |
| 30 | + if mem.IsBelowBufferPoolingThreshold(size) { |
| 31 | + buf := make([]byte, 0, size) |
| 32 | + if _, err := m.MarshalToVT(buf[:0]); err != nil { |
| 33 | + return nil, err |
| 34 | + } |
| 35 | + data = append(data, mem.SliceBuffer(buf)) |
| 36 | + } else { |
| 37 | + pool := mem.DefaultBufferPool() |
| 38 | + buf := pool.Get(size) |
| 39 | + if _, err := m.MarshalToVT((*buf)[:0]); err != nil { |
| 40 | + pool.Put(buf) |
| 41 | + return nil, err |
| 42 | + } |
| 43 | + data = append(data, mem.NewBuffer(buf, pool)) |
| 44 | + } |
| 45 | + return data, nil |
19 | 46 | } |
20 | | - return vt.MarshalVT() |
| 47 | + |
| 48 | + return c.fallback.Marshal(v) |
21 | 49 | } |
22 | 50 |
|
23 | | -func (Codec) Unmarshal(data []byte, v interface{}) error { |
24 | | - vt, ok := v.(vtprotoMessage) |
25 | | - if !ok { |
26 | | - return fmt.Errorf("failed to unmarshal, message is %T (missing vtprotobuf helpers)", v) |
| 51 | +func (c *Codec) Unmarshal(data mem.BufferSlice, v any) error { |
| 52 | + if m, ok := v.(vtprotoMessage); ok { |
| 53 | + buf := data.MaterializeToBuffer(mem.DefaultBufferPool()) |
| 54 | + defer buf.Free() |
| 55 | + return m.UnmarshalVT(buf.ReadOnlyData()) |
27 | 56 | } |
28 | | - return vt.UnmarshalVT(data) |
| 57 | + |
| 58 | + return c.fallback.Unmarshal(data, v) |
29 | 59 | } |
30 | 60 |
|
31 | | -func (Codec) Name() string { |
32 | | - return Name |
| 61 | +func init() { |
| 62 | + encoding.RegisterCodecV2(&Codec{ |
| 63 | + fallback: encoding.GetCodecV2("proto"), |
| 64 | + }) |
33 | 65 | } |
0 commit comments