diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..01db160 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,7 @@ +version: "2" +linters: + exclusions: + rules: + # Tests copied from the stdlib are not meant to be linted. + - path: 'golang_(.+_)?test\.go' + source: "^" # regex diff --git a/benchmarks/Makefile b/benchmarks/Makefile index e14cb3e..8b36507 100644 --- a/benchmarks/Makefile +++ b/benchmarks/Makefile @@ -74,7 +74,7 @@ benchstat := ${GOPATH}/bin/benchstat all: $(benchstat): - go install golang.org/x/perf/cmd/benchstat + go install golang.org/x/perf/cmd/benchstat@latest $(benchmark.cmd.dir)/message.pb.go: $(benchmark.cmd.dir)/message.proto @protoc -I. \ diff --git a/json/codec.go b/json/codec.go index 32c078f..c6ee2f5 100644 --- a/json/codec.go +++ b/json/codec.go @@ -4,11 +4,13 @@ import ( "encoding" "encoding/json" "fmt" + "maps" "math/big" "reflect" "sort" "strconv" "strings" + "sync" "sync/atomic" "time" "unicode" @@ -36,8 +38,19 @@ type encoder struct { // encoder starts tracking pointers it has seen as an attempt to detect // whether it has entered a pointer cycle and needs to error before the // goroutine runs out of stack space. + // + // This relies on encoder being passed as a value, + // and encoder methods calling each other in a traditional stack + // (not using trampoline techniques), + // since ptrDepth is never decremented. ptrDepth uint32 - ptrSeen map[unsafe.Pointer]struct{} + ptrSeen cycleMap +} + +type cycleMap map[unsafe.Pointer]struct{} + +var cycleMapPool = sync.Pool{ + New: func() any { return make(cycleMap) }, } type decoder struct { @@ -73,12 +86,9 @@ func cacheLoad() map[unsafe.Pointer]codec { func cacheStore(typ reflect.Type, cod codec, oldCodecs map[unsafe.Pointer]codec) { newCodecs := make(map[unsafe.Pointer]codec, len(oldCodecs)+1) + maps.Copy(newCodecs, oldCodecs) newCodecs[typeid(typ)] = cod - for t, c := range oldCodecs { - newCodecs[t] = c - } - cache.Store(&newCodecs) } @@ -205,7 +215,7 @@ func constructCodec(t reflect.Type, seen map[reflect.Type]*structType, canAddr b c = constructUnsupportedTypeCodec(t) } - p := reflect.PtrTo(t) + p := reflect.PointerTo(t) if canAddr { switch { @@ -291,7 +301,7 @@ func constructSliceCodec(t reflect.Type, seen map[reflect.Type]*structType) code // Go 1.7+ behavior: slices of byte types (and aliases) may override the // default encoding and decoding behaviors by implementing marshaler and // unmarshaler interfaces. - p := reflect.PtrTo(e) + p := reflect.PointerTo(e) c := codec{} switch { @@ -391,7 +401,7 @@ func constructMapCodec(t reflect.Type, seen map[reflect.Type]*structType) codec kc := codec{} vc := constructCodec(v, seen, false) - if k.Implements(textMarshalerType) || reflect.PtrTo(k).Implements(textUnmarshalerType) { + if k.Implements(textMarshalerType) || reflect.PointerTo(k).Implements(textUnmarshalerType) { kc.encode = constructTextMarshalerEncodeFunc(k, false) kc.decode = constructTextUnmarshalerDecodeFunc(k, true) @@ -570,6 +580,7 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se anonymous = f.Anonymous tag = false omitempty = false + omitzero = false stringify = false unexported = len(f.PkgPath) != 0 ) @@ -595,6 +606,8 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se switch tag { case "omitempty": omitempty = true + case "omitzero": + omitzero = true case "string": stringify = true } @@ -677,9 +690,11 @@ func appendStructFields(fields []structField, t reflect.Type, offset uintptr, se fields = append(fields, structField{ codec: codec, offset: offset + f.Offset, - empty: emptyFuncOf(f.Type), + isEmpty: emptyFuncOf(f.Type), + isZero: zeroFuncOf(f.Type), tag: tag, omitempty: omitempty, + omitzero: omitzero, name: name, index: i << 32, typ: f.Type, @@ -897,6 +912,18 @@ func isValidTag(s string) bool { return true } +func zeroFuncOf(t reflect.Type) emptyFunc { + if t.Implements(isZeroerType) { + return func(p unsafe.Pointer) bool { + return unsafeToAny(t, p).(isZeroer).IsZero() + } + } + + return func(p unsafe.Pointer) bool { + return reflectDeref(t, p).IsZero() + } +} + func emptyFuncOf(t reflect.Type) emptyFunc { switch t { case bytesType, rawMessageType: @@ -910,7 +937,7 @@ func emptyFuncOf(t reflect.Type) emptyFunc { } case reflect.Map: - return func(p unsafe.Pointer) bool { return reflect.NewAt(t, p).Elem().Len() == 0 } + return func(p unsafe.Pointer) bool { return reflectDeref(t, p).Len() == 0 } case reflect.Slice: return func(p unsafe.Pointer) bool { return (*slice)(p).len == 0 } @@ -955,6 +982,14 @@ func emptyFuncOf(t reflect.Type) emptyFunc { return func(unsafe.Pointer) bool { return false } } +func reflectDeref(t reflect.Type, p unsafe.Pointer) reflect.Value { + return reflect.NewAt(t, p).Elem() +} + +func unsafeToAny(t reflect.Type, p unsafe.Pointer) any { + return reflectDeref(t, p).Interface() +} + type iface struct { typ unsafe.Pointer ptr unsafe.Pointer @@ -972,15 +1007,16 @@ type structType struct { ficaseIndex map[string]*structField keyset []byte typ reflect.Type - inlined bool } type structField struct { codec codec offset uintptr - empty emptyFunc + isEmpty emptyFunc + isZero emptyFunc tag bool omitempty bool + omitzero bool json string html string name string @@ -1066,53 +1102,56 @@ type sliceHeader struct { Cap int } +type isZeroer interface{ IsZero() bool } + var ( nullType = reflect.TypeOf(nil) - boolType = reflect.TypeOf(false) - - intType = reflect.TypeOf(int(0)) - int8Type = reflect.TypeOf(int8(0)) - int16Type = reflect.TypeOf(int16(0)) - int32Type = reflect.TypeOf(int32(0)) - int64Type = reflect.TypeOf(int64(0)) - - uintType = reflect.TypeOf(uint(0)) - uint8Type = reflect.TypeOf(uint8(0)) - uint16Type = reflect.TypeOf(uint16(0)) - uint32Type = reflect.TypeOf(uint32(0)) - uint64Type = reflect.TypeOf(uint64(0)) - uintptrType = reflect.TypeOf(uintptr(0)) - - float32Type = reflect.TypeOf(float32(0)) - float64Type = reflect.TypeOf(float64(0)) - - bigIntType = reflect.TypeOf(new(big.Int)) - numberType = reflect.TypeOf(json.Number("")) - stringType = reflect.TypeOf("") - stringsType = reflect.TypeOf([]string(nil)) - bytesType = reflect.TypeOf(([]byte)(nil)) - durationType = reflect.TypeOf(time.Duration(0)) - timeType = reflect.TypeOf(time.Time{}) - rawMessageType = reflect.TypeOf(RawMessage(nil)) - - numberPtrType = reflect.PtrTo(numberType) - durationPtrType = reflect.PtrTo(durationType) - timePtrType = reflect.PtrTo(timeType) - rawMessagePtrType = reflect.PtrTo(rawMessageType) - - sliceInterfaceType = reflect.TypeOf(([]any)(nil)) - sliceStringType = reflect.TypeOf(([]any)(nil)) - mapStringInterfaceType = reflect.TypeOf((map[string]any)(nil)) - mapStringRawMessageType = reflect.TypeOf((map[string]RawMessage)(nil)) - mapStringStringType = reflect.TypeOf((map[string]string)(nil)) - mapStringStringSliceType = reflect.TypeOf((map[string][]string)(nil)) - mapStringBoolType = reflect.TypeOf((map[string]bool)(nil)) - - interfaceType = reflect.TypeOf((*any)(nil)).Elem() - jsonMarshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() - jsonUnmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() - textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() - textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + boolType = reflect.TypeFor[bool]() + + intType = reflect.TypeFor[int]() + int8Type = reflect.TypeFor[int8]() + int16Type = reflect.TypeFor[int16]() + int32Type = reflect.TypeFor[int32]() + int64Type = reflect.TypeFor[int64]() + + uintType = reflect.TypeFor[uint]() + uint8Type = reflect.TypeFor[uint8]() + uint16Type = reflect.TypeFor[uint16]() + uint32Type = reflect.TypeFor[uint32]() + uint64Type = reflect.TypeFor[uint64]() + uintptrType = reflect.TypeFor[uintptr]() + + float32Type = reflect.TypeFor[float32]() + float64Type = reflect.TypeFor[float64]() + + bigIntType = reflect.TypeFor[*big.Int]() + numberType = reflect.TypeFor[json.Number]() + stringType = reflect.TypeFor[string]() + stringsType = reflect.TypeFor[[]string]() + bytesType = reflect.TypeFor[[]byte]() + durationType = reflect.TypeFor[time.Duration]() + timeType = reflect.TypeFor[time.Time]() + rawMessageType = reflect.TypeFor[RawMessage]() + + numberPtrType = reflect.PointerTo(numberType) + durationPtrType = reflect.PointerTo(durationType) + timePtrType = reflect.PointerTo(timeType) + rawMessagePtrType = reflect.PointerTo(rawMessageType) + + sliceInterfaceType = reflect.TypeFor[[]any]() + sliceStringType = reflect.TypeFor[[]any]() + mapStringInterfaceType = reflect.TypeFor[map[string]any]() + mapStringRawMessageType = reflect.TypeFor[map[string]RawMessage]() + mapStringStringType = reflect.TypeFor[map[string]string]() + mapStringStringSliceType = reflect.TypeFor[map[string][]string]() + mapStringBoolType = reflect.TypeFor[map[string]bool]() + + interfaceType = reflect.TypeFor[any]() + jsonMarshalerType = reflect.TypeFor[Marshaler]() + jsonUnmarshalerType = reflect.TypeFor[Unmarshaler]() + textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]() + textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() + isZeroerType = reflect.TypeFor[isZeroer]() bigIntDecoder = constructJSONUnmarshalerDecodeFunc(bigIntType, false) ) diff --git a/json/decode.go b/json/decode.go index b44dde6..c87f01e 100644 --- a/json/decode.go +++ b/json/decode.go @@ -1410,7 +1410,7 @@ func (d decoder) decodeMaybeEmptyInterface(b []byte, p unsafe.Pointer, t reflect return d.decodeUnmarshalTypeError(b, p, t) } -func (d decoder) decodeUnmarshalTypeError(b []byte, p unsafe.Pointer, t reflect.Type) ([]byte, error) { +func (d decoder) decodeUnmarshalTypeError(b []byte, _ unsafe.Pointer, t reflect.Type) ([]byte, error) { v, b, _, err := d.parseValue(b) if err != nil { return b, err @@ -1500,7 +1500,7 @@ func (d decoder) decodeTextUnmarshaler(b []byte, p unsafe.Pointer, t reflect.Typ value = "array" } - return b, &UnmarshalTypeError{Value: value, Type: reflect.PtrTo(t)} + return b, &UnmarshalTypeError{Value: value, Type: reflect.PointerTo(t)} } func (d decoder) prependField(key, field string) string { diff --git a/json/encode.go b/json/encode.go index 2a6da07..5d0c029 100644 --- a/json/encode.go +++ b/json/encode.go @@ -748,7 +748,11 @@ func (e encoder) encodeStruct(b []byte, p unsafe.Pointer, st *structType) ([]byt f := &st.fields[i] v := unsafe.Pointer(uintptr(p) + f.offset) - if f.omitempty && f.empty(v) { + switch { + case f.omitempty && f.isEmpty(v): + continue + + case f.omitzero && f.isZero(v): continue } @@ -794,22 +798,21 @@ func (e encoder) encodeEmbeddedStructPointer(b []byte, p unsafe.Pointer, t refle } func (e encoder) encodePointer(b []byte, p unsafe.Pointer, t reflect.Type, encode encodeFunc) ([]byte, error) { - if p = *(*unsafe.Pointer)(p); p != nil { - if e.ptrDepth++; e.ptrDepth >= startDetectingCyclesAfter { - if _, seen := e.ptrSeen[p]; seen { - // TODO: reconstruct the reflect.Value from p + t so we can set - // the erorr's Value field? - return b, &UnsupportedValueError{Str: fmt.Sprintf("encountered a cycle via %s", t)} - } - if e.ptrSeen == nil { - e.ptrSeen = make(map[unsafe.Pointer]struct{}) - } - e.ptrSeen[p] = struct{}{} - defer delete(e.ptrSeen, p) - } - return encode(e, b, p) + // p was a pointer to the actual user data pointer: + // dereference it to operate on the user data pointer. + p = *(*unsafe.Pointer)(p) + if p == nil { + return e.encodeNull(b, nil) } - return e.encodeNull(b, nil) + + err := checkRefCycle(&e, t, p) + if err != nil { + return b, err + } + + defer freeRefCycleInfo(&e, p) + + return encode(e, b, p) } func (e encoder) encodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) { @@ -968,3 +971,46 @@ func appendCompactEscapeHTML(dst []byte, src []byte) []byte { return dst } + +// checkRefCycle returns an error if a reference cycle was detected. +func checkRefCycle(e *encoder, t reflect.Type, p unsafe.Pointer) error { + e.ptrDepth++ + if e.ptrDepth < startDetectingCyclesAfter { + return nil + } + + _, seen := e.ptrSeen[p] + if seen { + v := reflect.NewAt(t, p) + return &UnsupportedValueError{ + Value: v, + Str: fmt.Sprintf("encountered a cycle via %s", t), + } + } + + if e.ptrSeen == nil { + e.ptrSeen = cycleMapPool.Get().(cycleMap) + } + + e.ptrSeen[p] = struct{}{} + + return nil +} + +func freeRefCycleInfo(e *encoder, p unsafe.Pointer) { + if e.ptrSeen == nil { + // The map hasn't yet been allocated (not enough recursion depth), + // so there's not any work to do in this function. + return + } + + delete(e.ptrSeen, p) + if len(e.ptrSeen) != 0 { + // There are other keys in the map, so we can't release it into the pool. + return + } + + m := e.ptrSeen + e.ptrSeen = nil + cycleMapPool.Put(m) +} diff --git a/json/golang_encode_test.go b/json/golang_encode_test.go index 5e334a6..86f4f8d 100644 --- a/json/golang_encode_test.go +++ b/json/golang_encode_test.go @@ -136,21 +136,85 @@ func TestEncodeRenamedByteSlice(t *testing.T) { } } -var unsupportedValues = []any{ - math.NaN(), - math.Inf(-1), - math.Inf(1), +type SamePointerNoCycle struct { + Ptr1, Ptr2 *SamePointerNoCycle +} + +var samePointerNoCycle = &SamePointerNoCycle{} + +type PointerCycle struct { + Ptr *PointerCycle +} + +var pointerCycle = &PointerCycle{} + +type PointerCycleIndirect struct { + Ptrs []any +} + +type RecursiveSlice []RecursiveSlice + +var ( + pointerCycleIndirect = &PointerCycleIndirect{} + mapCycle = make(map[string]any) + sliceCycle = []any{nil} + sliceNoCycle = []any{nil, nil} + recursiveSliceCycle = []RecursiveSlice{nil} +) + +func init() { + ptr := &SamePointerNoCycle{} + samePointerNoCycle.Ptr1 = ptr + samePointerNoCycle.Ptr2 = ptr + + pointerCycle.Ptr = pointerCycle + pointerCycleIndirect.Ptrs = []any{pointerCycleIndirect} + + mapCycle["x"] = mapCycle + sliceCycle[0] = sliceCycle + sliceNoCycle[1] = sliceNoCycle[:1] + for i := startDetectingCyclesAfter; i > 0; i-- { + sliceNoCycle = []any{sliceNoCycle} + } + recursiveSliceCycle[0] = recursiveSliceCycle +} + +func TestSamePointerNoCycle(t *testing.T) { + if _, err := Marshal(samePointerNoCycle); err != nil { + t.Fatalf("Marshal error: %v", err) + } +} + +func TestSliceNoCycle(t *testing.T) { + if _, err := Marshal(sliceNoCycle); err != nil { + t.Fatalf("Marshal error: %v", err) + } } func TestUnsupportedValues(t *testing.T) { - for _, v := range unsupportedValues { - if _, err := Marshal(v); err != nil { - if _, ok := err.(*UnsupportedValueError); !ok { - t.Errorf("for %v, got %T want UnsupportedValueError", v, err) + tests := []struct { + CaseName + in any + }{ + {Name(""), math.NaN()}, + {Name(""), math.Inf(-1)}, + {Name(""), math.Inf(1)}, + {Name(""), pointerCycle}, + {Name(""), pointerCycleIndirect}, + {Name(""), mapCycle}, + {Name(""), sliceCycle}, + {Name(""), recursiveSliceCycle}, + } + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + if _, err := Marshal(tt.in); err != nil { + if _, ok := err.(*UnsupportedValueError); !ok { + t.Errorf("%s: Marshal error:\n\tgot: %T\n\twant: %T", tt.Where, err, new(UnsupportedValueError)) + } + } else { + t.Errorf("%s: Marshal error: got nil, want non-nil", tt.Where) } - } else { - t.Errorf("for %v, expected error", v) - } + }) } } diff --git a/json/golang_shim_test.go b/json/golang_shim_test.go index 5a19b7f..90e4fa9 100644 --- a/json/golang_shim_test.go +++ b/json/golang_shim_test.go @@ -4,7 +4,10 @@ package json import ( "bytes" + "fmt" + "path" "reflect" + "runtime" "sync" "testing" ) @@ -68,3 +71,30 @@ func errorWithPrefixes(t *testing.T, prefixes []any, format string, elements ... } t.Errorf(fullFormat, allElements...) } + +// ============================================================================= +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// CaseName is a case name annotated with a file and line. +type CaseName struct { + Name string + Where CasePos +} + +// Name annotates a case name with the file and line of the caller. +func Name(s string) (c CaseName) { + c.Name = s + runtime.Callers(2, c.Where.pc[:]) + return c +} + +// CasePos represents a file and line number. +type CasePos struct{ pc [1]uintptr } + +func (pos CasePos) String() string { + frames := runtime.CallersFrames(pos.pc[:]) + frame, _ := frames.Next() + return fmt.Sprintf("%s:%d", path.Base(frame.File), frame.Line) +} diff --git a/json/json.go b/json/json.go index 11ec69c..028fd1f 100644 --- a/json/json.go +++ b/json/json.go @@ -15,7 +15,7 @@ import ( type Delim = json.Delim // InvalidUTF8Error is documented at https://golang.org/pkg/encoding/json/#InvalidUTF8Error -type InvalidUTF8Error = json.InvalidUTF8Error +type InvalidUTF8Error = json.InvalidUTF8Error //nolint:staticcheck // compat. // InvalidUnmarshalError is documented at https://golang.org/pkg/encoding/json/#InvalidUnmarshalError type InvalidUnmarshalError = json.InvalidUnmarshalError @@ -39,7 +39,7 @@ type SyntaxError = json.SyntaxError type Token = json.Token // UnmarshalFieldError is documented at https://golang.org/pkg/encoding/json/#UnmarshalFieldError -type UnmarshalFieldError = json.UnmarshalFieldError +type UnmarshalFieldError = json.UnmarshalFieldError //nolint:staticcheck // compat. // UnmarshalTypeError is documented at https://golang.org/pkg/encoding/json/#UnmarshalTypeError type UnmarshalTypeError = json.UnmarshalTypeError diff --git a/json/json_test.go b/json/json_test.go index b40e000..8256be2 100644 --- a/json/json_test.go +++ b/json/json_test.go @@ -240,7 +240,7 @@ var testValues = [...]any{ A string `json:"name"` B string `json:"-"` C string `json:",omitempty"` - D map[string]any `json:",string"` + D map[string]any `json:",string"` //nolint:staticcheck // intentional e string }{A: "Luke", D: map[string]any{"answer": float64(42)}}, struct{ point }{point{1, 2}}, @@ -880,12 +880,11 @@ func TestDecodeLines(t *testing.T) { t.Run(test.desc, func(t *testing.T) { d := NewDecoder(test.reader) var count int - var err error for { var o obj - err = d.Decode(&o) + err := d.Decode(&o) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } @@ -904,10 +903,6 @@ func TestDecodeLines(t *testing.T) { count++ } - if err != nil && err != io.EOF { - t.Error(err) - } - if count != test.expectCount { t.Errorf("expected %d objects, got %d", test.expectCount, count) } diff --git a/json/parse.go b/json/parse.go index 949e7f3..d0ee221 100644 --- a/json/parse.go +++ b/json/parse.go @@ -21,11 +21,6 @@ const ( cr = '\r' ) -const ( - escape = '\\' - quote = '"' -) - func internalParseFlags(b []byte) (flags ParseFlags) { // Don't consider surrounding whitespace b = skipSpaces(b)