Skip to content

json: support omitzero #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
version: "2"
linters:
exclusions:
rules:
# Tests copied from the stdlib are not meant to be linted.
- path: 'golang_(.+_)?test\.go'
source: "^" # regex
2 changes: 1 addition & 1 deletion benchmarks/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ benchstat := ${GOPATH}/bin/benchstat
all:

$(benchstat):
go install golang.org/x/perf/cmd/benchstat
go install golang.org/x/perf/cmd/benchstat@latest

$(benchmark.cmd.dir)/message.pb.go: $(benchmark.cmd.dir)/message.proto
@protoc -I. \
Expand Down
153 changes: 96 additions & 57 deletions json/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"encoding"
"encoding/json"
"fmt"
"maps"
"math/big"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unicode"
Expand Down Expand Up @@ -36,8 +38,19 @@ type encoder struct {
// encoder starts tracking pointers it has seen as an attempt to detect
// whether it has entered a pointer cycle and needs to error before the
// goroutine runs out of stack space.
//
// This relies on encoder being passed as a value,
// and encoder methods calling each other in a traditional stack
// (not using trampoline techniques),
// since ptrDepth is never decremented.
ptrDepth uint32
ptrSeen map[unsafe.Pointer]struct{}
ptrSeen cycleMap
}

type cycleMap map[unsafe.Pointer]struct{}

var cycleMapPool = sync.Pool{
New: func() any { return make(cycleMap) },
}

type decoder struct {
Expand Down Expand Up @@ -73,12 +86,9 @@ func cacheLoad() map[unsafe.Pointer]codec {

func cacheStore(typ reflect.Type, cod codec, oldCodecs map[unsafe.Pointer]codec) {
newCodecs := make(map[unsafe.Pointer]codec, len(oldCodecs)+1)
maps.Copy(newCodecs, oldCodecs)
newCodecs[typeid(typ)] = cod

for t, c := range oldCodecs {
newCodecs[t] = c
}

cache.Store(&newCodecs)
}

Expand Down Expand Up @@ -205,7 +215,7 @@ func constructCodec(t reflect.Type, seen map[reflect.Type]*structType, canAddr b
c = constructUnsupportedTypeCodec(t)
}

p := reflect.PtrTo(t)
p := reflect.PointerTo(t)

if canAddr {
switch {
Expand Down Expand Up @@ -291,7 +301,7 @@ func constructSliceCodec(t reflect.Type, seen map[reflect.Type]*structType) code
// Go 1.7+ behavior: slices of byte types (and aliases) may override the
// default encoding and decoding behaviors by implementing marshaler and
// unmarshaler interfaces.
p := reflect.PtrTo(e)
p := reflect.PointerTo(e)
c := codec{}

switch {
Expand Down Expand Up @@ -391,7 +401,7 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec
kc := codec{}
vc := constructCodec(v, seen, false)

if k.Implements(textMarshalerType) || reflect.PtrTo(k).Implements(textUnmarshalerType) {
if k.Implements(textMarshalerType) || reflect.PointerTo(k).Implements(textUnmarshalerType) {
kc.encode = constructTextMarshalerEncodeFunc(k, false)
kc.decode = constructTextUnmarshalerDecodeFunc(k, true)

Expand Down Expand Up @@ -570,6 +580,7 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
anonymous = f.Anonymous
tag = false
omitempty = false
omitzero = false
stringify = false
unexported = len(f.PkgPath) != 0
)
Expand All @@ -595,6 +606,8 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
switch tag {
case "omitempty":
omitempty = true
case "omitzero":
omitzero = true
case "string":
stringify = true
}
Expand Down Expand Up @@ -677,9 +690,11 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se
fields = append(fields, structField{
codec: codec,
offset: offset + f.Offset,
empty: emptyFuncOf(f.Type),
isEmpty: emptyFuncOf(f.Type),
isZero: zeroFuncOf(f.Type),
tag: tag,
omitempty: omitempty,
omitzero: omitzero,
name: name,
index: i << 32,
typ: f.Type,
Expand Down Expand Up @@ -897,6 +912,18 @@ func isValidTag(s string) bool {
return true
}

func zeroFuncOf(t reflect.Type) emptyFunc {
if t.Implements(isZeroerType) {
return func(p unsafe.Pointer) bool {
return unsafeToAny(t, p).(isZeroer).IsZero()
}
}

return func(p unsafe.Pointer) bool {
return reflectDeref(t, p).IsZero()
}
}

func emptyFuncOf(t reflect.Type) emptyFunc {
switch t {
case bytesType, rawMessageType:
Expand All @@ -910,7 +937,7 @@ func emptyFuncOf(t reflect.Type) emptyFunc {
}

case reflect.Map:
return func(p unsafe.Pointer) bool { return reflect.NewAt(t, p).Elem().Len() == 0 }
return func(p unsafe.Pointer) bool { return reflectDeref(t, p).Len() == 0 }

case reflect.Slice:
return func(p unsafe.Pointer) bool { return (*slice)(p).len == 0 }
Expand Down Expand Up @@ -955,6 +982,14 @@ func emptyFuncOf(t reflect.Type) emptyFunc {
return func(unsafe.Pointer) bool { return false }
}

func reflectDeref(t reflect.Type, p unsafe.Pointer) reflect.Value {
return reflect.NewAt(t, p).Elem()
}

func unsafeToAny(t reflect.Type, p unsafe.Pointer) any {
return reflectDeref(t, p).Interface()
}

type iface struct {
typ unsafe.Pointer
ptr unsafe.Pointer
Expand All @@ -972,15 +1007,16 @@ type structType struct {
ficaseIndex map[string]*structField
keyset []byte
typ reflect.Type
inlined bool
}

type structField struct {
codec codec
offset uintptr
empty emptyFunc
isEmpty emptyFunc
isZero emptyFunc
tag bool
omitempty bool
omitzero bool
json string
html string
name string
Expand Down Expand Up @@ -1066,53 +1102,56 @@ type sliceHeader struct {
Cap int
}

type isZeroer interface{ IsZero() bool }

var (
nullType = reflect.TypeOf(nil)
boolType = reflect.TypeOf(false)

intType = reflect.TypeOf(int(0))
int8Type = reflect.TypeOf(int8(0))
int16Type = reflect.TypeOf(int16(0))
int32Type = reflect.TypeOf(int32(0))
int64Type = reflect.TypeOf(int64(0))

uintType = reflect.TypeOf(uint(0))
uint8Type = reflect.TypeOf(uint8(0))
uint16Type = reflect.TypeOf(uint16(0))
uint32Type = reflect.TypeOf(uint32(0))
uint64Type = reflect.TypeOf(uint64(0))
uintptrType = reflect.TypeOf(uintptr(0))

float32Type = reflect.TypeOf(float32(0))
float64Type = reflect.TypeOf(float64(0))

bigIntType = reflect.TypeOf(new(big.Int))
numberType = reflect.TypeOf(json.Number(""))
stringType = reflect.TypeOf("")
stringsType = reflect.TypeOf([]string(nil))
bytesType = reflect.TypeOf(([]byte)(nil))
durationType = reflect.TypeOf(time.Duration(0))
timeType = reflect.TypeOf(time.Time{})
rawMessageType = reflect.TypeOf(RawMessage(nil))

numberPtrType = reflect.PtrTo(numberType)
durationPtrType = reflect.PtrTo(durationType)
timePtrType = reflect.PtrTo(timeType)
rawMessagePtrType = reflect.PtrTo(rawMessageType)

sliceInterfaceType = reflect.TypeOf(([]any)(nil))
sliceStringType = reflect.TypeOf(([]any)(nil))
mapStringInterfaceType = reflect.TypeOf((map[string]any)(nil))
mapStringRawMessageType = reflect.TypeOf((map[string]RawMessage)(nil))
mapStringStringType = reflect.TypeOf((map[string]string)(nil))
mapStringStringSliceType = reflect.TypeOf((map[string][]string)(nil))
mapStringBoolType = reflect.TypeOf((map[string]bool)(nil))

interfaceType = reflect.TypeOf((*any)(nil)).Elem()
jsonMarshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
jsonUnmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
boolType = reflect.TypeFor[bool]()

intType = reflect.TypeFor[int]()
int8Type = reflect.TypeFor[int8]()
int16Type = reflect.TypeFor[int16]()
int32Type = reflect.TypeFor[int32]()
int64Type = reflect.TypeFor[int64]()

uintType = reflect.TypeFor[uint]()
uint8Type = reflect.TypeFor[uint8]()
uint16Type = reflect.TypeFor[uint16]()
uint32Type = reflect.TypeFor[uint32]()
uint64Type = reflect.TypeFor[uint64]()
uintptrType = reflect.TypeFor[uintptr]()

float32Type = reflect.TypeFor[float32]()
float64Type = reflect.TypeFor[float64]()

bigIntType = reflect.TypeFor[*big.Int]()
numberType = reflect.TypeFor[json.Number]()
stringType = reflect.TypeFor[string]()
stringsType = reflect.TypeFor[[]string]()
bytesType = reflect.TypeFor[[]byte]()
durationType = reflect.TypeFor[time.Duration]()
timeType = reflect.TypeFor[time.Time]()
rawMessageType = reflect.TypeFor[RawMessage]()

numberPtrType = reflect.PointerTo(numberType)
durationPtrType = reflect.PointerTo(durationType)
timePtrType = reflect.PointerTo(timeType)
rawMessagePtrType = reflect.PointerTo(rawMessageType)

sliceInterfaceType = reflect.TypeFor[[]any]()
sliceStringType = reflect.TypeFor[[]any]()
mapStringInterfaceType = reflect.TypeFor[map[string]any]()
mapStringRawMessageType = reflect.TypeFor[map[string]RawMessage]()
mapStringStringType = reflect.TypeFor[map[string]string]()
mapStringStringSliceType = reflect.TypeFor[map[string][]string]()
mapStringBoolType = reflect.TypeFor[map[string]bool]()

interfaceType = reflect.TypeFor[any]()
jsonMarshalerType = reflect.TypeFor[Marshaler]()
jsonUnmarshalerType = reflect.TypeFor[Unmarshaler]()
textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]()
textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
isZeroerType = reflect.TypeFor[isZeroer]()

bigIntDecoder = constructJSONUnmarshalerDecodeFunc(bigIntType, false)
)
Expand Down
4 changes: 2 additions & 2 deletions json/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1410,7 +1410,7 @@ func (d decoder) decodeMaybeEmptyInterface(b []byte, p unsafe.Pointer, t reflect
return d.decodeUnmarshalTypeError(b, p, t)
}

func (d decoder) decodeUnmarshalTypeError(b []byte, p unsafe.Pointer, t reflect.Type) ([]byte, error) {
func (d decoder) decodeUnmarshalTypeError(b []byte, _ unsafe.Pointer, t reflect.Type) ([]byte, error) {
v, b, _, err := d.parseValue(b)
if err != nil {
return b, err
Expand Down Expand Up @@ -1500,7 +1500,7 @@ func (d decoder) decodeTextUnmarshaler(b []byte, p unsafe.Pointer, t reflect.Typ
value = "array"
}

return b, &UnmarshalTypeError{Value: value, Type: reflect.PtrTo(t)}
return b, &UnmarshalTypeError{Value: value, Type: reflect.PointerTo(t)}
}

func (d decoder) prependField(key, field string) string {
Expand Down
Loading
Loading