diff --git a/arshal_default.go b/arshal_default.go index 6481629..6eaec64 100644 --- a/arshal_default.go +++ b/arshal_default.go @@ -786,6 +786,9 @@ func makeMapArshaler(t reflect.Type) *arshaler { k.SetIterKey(iter) err := marshalKey(enc, k, mo) if err != nil { + if errors.Is(err, errSkipMember) { + continue + } if mo.Flags.Get(jsonflags.CallMethodsWithLegacySemantics) && errors.Is(err, jsontext.ErrNonStringName) && nillableLegacyKey && k.IsNil() { err = enc.WriteToken(jsontext.String("")) @@ -799,6 +802,14 @@ func makeMapArshaler(t reflect.Type) *arshaler { } v.SetIterValue(iter) if err := marshalVal(enc, v, mo); err != nil { + if errors.Is(err, errSkipMember) { + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixString(xe.Buf) + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixByte(xe.Buf, ',') + xe.Tokens.Last-- + continue + } return err } } @@ -817,6 +828,14 @@ func makeMapArshaler(t reflect.Type) *arshaler { k.SetString(name) v.Set(va.MapIndex(k.Value)) if err := marshalVal(enc, v, mo); err != nil { + if errors.Is(err, errSkipMember) { + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixString(xe.Buf) + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixByte(xe.Buf, ',') + xe.Tokens.Last-- + continue + } return err } } @@ -838,6 +857,9 @@ func makeMapArshaler(t reflect.Type) *arshaler { v.SetIterValue(iter) err := marshalKey(enc, k, mo) if err != nil { + if errors.Is(err, errSkipMember) { + continue + } if mo.Flags.Get(jsonflags.CallMethodsWithLegacySemantics) && errors.Is(err, jsontext.ErrNonStringName) && nillableLegacyKey && k.IsNil() { err = enc.WriteToken(jsontext.String("")) @@ -863,6 +885,14 @@ func makeMapArshaler(t reflect.Type) *arshaler { return err } if err := marshalVal(enc, member.val, mo); err != nil { + if errors.Is(err, errSkipMember) { + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixString(xe.Buf) + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixByte(xe.Buf, ',') + xe.Tokens.Last-- + continue + } return err } } @@ -979,6 +1009,10 @@ func makeMapArshaler(t reflect.Type) *arshaler { seen.SetMapIndex(k.Value, reflect.Zero(emptyStructType)) } if err != nil { + if errors.Is(err, errSkipMember) { + dec.ReadToken() + continue + } if isFatalError(err, uo.Flags) { return err } @@ -1141,6 +1175,14 @@ func makeStructArshaler(t reflect.Type) *arshaler { mo.Flags = flagsOriginal mo.Format = "" if err != nil { + if errors.Is(err, errSkipMember) { + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixString(xe.Buf) + xe.Buf = jsonwire.TrimSuffixWhitespace(xe.Buf) + xe.Buf = jsonwire.TrimSuffixByte(xe.Buf, ',') + xe.Tokens.Last-- + continue + } return err } @@ -1304,6 +1346,10 @@ func makeStructArshaler(t reflect.Type) *arshaler { uo.Flags = flagsOriginal uo.Format = "" if err != nil { + if errors.Is(err, errSkipMember) { + dec.ReadToken() + continue // skip member + } if isFatalError(err, uo.Flags) { return err } @@ -1444,6 +1490,9 @@ func makeSliceArshaler(t reflect.Type) *arshaler { for i := range n { v := addressableValue{va.Index(i), false} // indexed slice element is always addressable if err := marshal(enc, v, mo); err != nil { + if errors.Is(err, errSkipMember) { + continue + } return err } } @@ -1499,6 +1548,10 @@ func makeSliceArshaler(t reflect.Type) *arshaler { v.SetZero() } if err := unmarshal(dec, v, uo); err != nil { + if errors.Is(err, errSkipMember) { + dec.ReadToken() + continue // skip member + } if isFatalError(err, uo.Flags) { va.SetLen(i) return err @@ -1550,6 +1603,9 @@ func makeArrayArshaler(t reflect.Type) *arshaler { for i := range n { v := addressableValue{va.Index(i), va.forcedAddr} // indexed array element is addressable if array is addressable if err := marshal(enc, v, mo); err != nil { + if errors.Is(err, errSkipMember) { + continue + } return err } } @@ -1595,6 +1651,10 @@ func makeArrayArshaler(t reflect.Type) *arshaler { v.SetZero() } if err := unmarshal(dec, v, uo); err != nil { + if errors.Is(err, errSkipMember) { + dec.ReadToken() + continue // skip member + } if isFatalError(err, uo.Flags) { return err } diff --git a/arshal_funcs.go b/arshal_funcs.go index 0b7e82f..21477be 100644 --- a/arshal_funcs.go +++ b/arshal_funcs.go @@ -25,6 +25,14 @@ import ( // [jsontext.Encoder.WriteToken] since such methods mutate the state. var SkipFunc = errors.New("json: skip function") +// SkipMember may be returned by [MarshalToFunc] and [UnmarshalFromFunc] functions. +// +// Any function that returns SkipMember must not cause observable side effects +// on the provided [jsontext.Encoder] or [jsontext.Decoder]. +// When a member (json property) is skipped, the corresponding member name (key name) +// is also removed from the byte stream. +var errSkipMember = errors.New("json: skip member") + var errSkipMutation = errors.New("must not read or write any tokens when skipping") var errNonSingularValue = errors.New("must read or write exactly one value") diff --git a/arshal_methods.go b/arshal_methods.go index 10efe14..e13563f 100644 --- a/arshal_methods.go +++ b/arshal_methods.go @@ -193,6 +193,9 @@ func makeMethodArshaler(fncs *arshaler, t reflect.Type) *arshaler { prevDepth, prevLength := xe.Tokens.DepthLength() xe.Flags.Set(jsonflags.WithinArshalCall | 1) err := va.Addr().Interface().(MarshalerTo).MarshalJSONTo(enc) + if errors.Is(err, errSkipMember) { // should this apply to v1 and v2? + return err + } xe.Flags.Set(jsonflags.WithinArshalCall | 0) currDepth, currLength := xe.Tokens.DepthLength() if (prevDepth != currDepth || prevLength+1 != currLength) && err == nil { @@ -283,6 +286,9 @@ func makeMethodArshaler(fncs *arshaler, t reflect.Type) *arshaler { prevDepth, prevLength := xd.Tokens.DepthLength() xd.Flags.Set(jsonflags.WithinArshalCall | 1) err := va.Addr().Interface().(UnmarshalerFrom).UnmarshalJSONFrom(dec) + if errors.Is(err, errSkipMember) { + return err + } xd.Flags.Set(jsonflags.WithinArshalCall | 0) currDepth, currLength := xd.Tokens.DepthLength() if (prevDepth != currDepth || prevLength+1 != currLength) && err == nil { diff --git a/arshal_test.go b/arshal_test.go index c543dfa..7da4511 100644 --- a/arshal_test.go +++ b/arshal_test.go @@ -560,6 +560,16 @@ type ( pointerAlwaysZero string pointerNeverZero string + skipStruct struct { + namedString string + } + skipString string + skipInt int + skipFloat64 float64 + skipBool bool + skipMap map[string]string + skipSlice []string + valueStringer struct{} pointerStringer struct{} @@ -698,6 +708,49 @@ func addr[T any](v T) *T { return &v } +func (*skipString) MarshalJSONTo(enc *jsontext.Encoder) error { + return errSkipMember +} +func (*skipString) UnmarshalJSONFrom(*jsontext.Decoder) error { + return errSkipMember +} +func (*skipInt) MarshalJSONTo(enc *jsontext.Encoder) error { + return errSkipMember +} +func (*skipInt) UnmarshalJSONFrom(*jsontext.Decoder) error { + return errSkipMember +} +func (*skipFloat64) MarshalJSONTo(enc *jsontext.Encoder) error { + return errSkipMember +} +func (*skipFloat64) UnmarshalJSONFrom(*jsontext.Decoder) error { + return errSkipMember +} +func (*skipBool) MarshalJSONTo(enc *jsontext.Encoder) error { + return errSkipMember +} +func (*skipBool) UnmarshalJSONFrom(*jsontext.Decoder) error { + return errSkipMember +} +func (*skipStruct) MarshalJSONTo(enc *jsontext.Encoder) error { + return errSkipMember +} +func (*skipStruct) UnmarshalJSONFrom(*jsontext.Decoder) error { + return errSkipMember +} +func (*skipMap) MarshalJSONTo(enc *jsontext.Encoder) error { + return errSkipMember +} +func (*skipMap) UnmarshalJSONFrom(*jsontext.Decoder) error { + return errSkipMember +} +func (*skipSlice) MarshalJSONTo(enc *jsontext.Encoder) error { + return errSkipMember +} +func (*skipSlice) UnmarshalJSONFrom(*jsontext.Decoder) error { + return errSkipMember +} + func mustParseTime(layout, value string) time.Time { t, err := time.Parse(layout, value) if err != nil { @@ -4544,6 +4597,60 @@ func TestMarshal(t *testing.T) { opts: []Options{invalidFormatOption}, in: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), want: `"2000-01-01T00:00:00Z"`, + }, { + name: jsontest.Name("Skip/Struct"), + in: skipStruct{ + namedString: "this is a named string", + }, + want: ``, + wantErr: errSkipMember, + }, { + name: jsontest.Name("Skip/StructMember"), + in: struct { + X1 string + Y1 skipString + X2 int + Y2 skipInt + X3 float64 + Y3 skipFloat64 + X4 bool + Y4 skipBool + }{ + "this is X", + "skipString", + 5, + 5, + 0.123, + 0.123, + true, + true, + }, + want: `{"X1":"this is X","X2":5,"X3":0.123,"X4":true}`, + }, { + name: jsontest.Name("Skip/Map"), + in: struct { + skipMap + }{ + skipMap{"hello": "world"}, + }, + want: ``, + wantErr: errSkipMember, + }, { + name: jsontest.Name("Skip/MapMember"), + in: map[string]any{ + "X1": "X1 should be visible", + "X2": skipString("X2 should not be visible"), + }, + want: `{"X1":"X1 should be visible"}`, + }, { + name: jsontest.Name("Skip/Slice"), + in: skipSlice{"hello", "world!"}, + want: "", + wantErr: errSkipMember, + }, { + name: jsontest.Name("Skip/SliceMember"), + in: []any{"Visible", skipString("Hidden")}, + want: "[\"Visible\"]", }} for _, tt := range tests { @@ -9110,6 +9217,48 @@ func TestUnmarshal(t *testing.T) { inBuf: `"2000-01-01T00:00:00Z"`, inVal: addr(time.Time{}), want: addr(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + }, { + name: jsontest.Name("Skip/Struct"), + inBuf: `{"X":"this is namedString"}`, + inVal: new(skipStruct), + want: addr(skipStruct{}), + wantErr: errSkipMember, + }, { + name: jsontest.Name("Skip/StructMember"), + inBuf: `{"X":"hello", "Y": "should not show"}`, + inVal: new(struct { + X string + Y skipString + }), + want: addr(struct { + X string + Y skipString + }{X: "hello"}), + }, { + name: jsontest.Name("Skip/MapMember"), + inBuf: `{"K1":"hello", "K2": "this is wrong", "K3": 3}`, + inVal: addr(map[string]any{ + "K1": string("should not show"), + "K2": skipString("this is correct"), + "K3": 0, + }), + want: addr(map[string]any{ + "K1": string("hello"), + "K2": skipString("this is correct"), + "K3": 3, + }), + }, { + name: jsontest.Name("Skip/ArrayMember"), + inBuf: `["X", "Y", "Z"]`, + inVal: addr([]skipString{}), + // using make is needed to solve issue + // *[]json.skipString, &[%!t(json.skipString=) %!t(json.skipString=) %!t(json.skipString=)] + want: addr(make([]skipString, 3)), + }, { + name: jsontest.Name("Skip/SliceMember"), + inBuf: `["X", "Y", "Z"]`, + inVal: addr([]skipString{}), + want: addr(make([]skipString, 3)), }} for _, tt := range tests {