Skip to content

Commit 4f4b66a

Browse files
jbasamthanawalla
authored andcommitted
jsonschema: inference ignores invalid types (#147)
Add ForLax[T], which ignores invalid types in schema inference instead of returning an error. This allows additional customization of a schema after inference does what it can. For example, an interface type where all the possible implementations are known can be described with "oneof". For #136.
1 parent 2ab004f commit 4f4b66a

File tree

2 files changed

+162
-90
lines changed

2 files changed

+162
-90
lines changed

jsonschema/infer.go

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,30 @@ import (
4747
func For[T any]() (*Schema, error) {
4848
// TODO: consider skipping incompatible fields, instead of failing.
4949
seen := make(map[reflect.Type]bool)
50-
s, err := forType(reflect.TypeFor[T](), seen)
50+
s, err := forType(reflect.TypeFor[T](), seen, false)
5151
if err != nil {
5252
var z T
5353
return nil, fmt.Errorf("For[%T](): %w", z, err)
5454
}
5555
return s, nil
5656
}
5757

58-
func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
58+
// ForLax behaves like [For], except that it ignores struct fields with invalid types instead of
59+
// returning an error. That allows callers to adjust the resulting schema using custom knowledge.
60+
// For example, an interface type where all the possible implementations are known
61+
// can be described with "oneof".
62+
func ForLax[T any]() (*Schema, error) {
63+
// TODO: consider skipping incompatible fields, instead of failing.
64+
seen := make(map[reflect.Type]bool)
65+
s, err := forType(reflect.TypeFor[T](), seen, true)
66+
if err != nil {
67+
var z T
68+
return nil, fmt.Errorf("ForLax[%T](): %w", z, err)
69+
}
70+
return s, nil
71+
}
72+
73+
func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, error) {
5974
// Follow pointers: the schema for *T is almost the same as for T, except that
6075
// an explicit JSON "null" is allowed for the pointer.
6176
allowNull := false
@@ -96,20 +111,33 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
96111

97112
case reflect.Map:
98113
if t.Key().Kind() != reflect.String {
114+
if lax {
115+
return nil, nil // ignore
116+
}
99117
return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
100118
}
119+
if t.Key().Kind() != reflect.String {
120+
}
101121
s.Type = "object"
102-
s.AdditionalProperties, err = forType(t.Elem(), seen)
122+
s.AdditionalProperties, err = forType(t.Elem(), seen, lax)
103123
if err != nil {
104124
return nil, fmt.Errorf("computing map value schema: %v", err)
105125
}
126+
if lax && s.AdditionalProperties == nil {
127+
// Ignore if the element type is invalid.
128+
return nil, nil
129+
}
106130

107131
case reflect.Slice, reflect.Array:
108132
s.Type = "array"
109-
s.Items, err = forType(t.Elem(), seen)
133+
s.Items, err = forType(t.Elem(), seen, lax)
110134
if err != nil {
111135
return nil, fmt.Errorf("computing element schema: %v", err)
112136
}
137+
if lax && s.Items == nil {
138+
// Ignore if the element type is invalid.
139+
return nil, nil
140+
}
113141
if t.Kind() == reflect.Array {
114142
s.MinItems = Ptr(t.Len())
115143
s.MaxItems = Ptr(t.Len())
@@ -132,10 +160,14 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
132160
if s.Properties == nil {
133161
s.Properties = make(map[string]*Schema)
134162
}
135-
fs, err := forType(field.Type, seen)
163+
fs, err := forType(field.Type, seen, lax)
136164
if err != nil {
137165
return nil, err
138166
}
167+
if lax && fs == nil {
168+
// Skip fields of invalid type.
169+
continue
170+
}
139171
if tag, ok := field.Tag.Lookup("jsonschema"); ok {
140172
if tag == "" {
141173
return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name)
@@ -152,6 +184,10 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
152184
}
153185

154186
default:
187+
if lax {
188+
// Ignore.
189+
return nil, nil
190+
}
155191
return nil, fmt.Errorf("type %v is unsupported by jsonschema", t)
156192
}
157193
if allowNull && s.Type != "" {

jsonschema/infer_test.go

Lines changed: 121 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,14 @@ import (
1313
"github.com/modelcontextprotocol/go-sdk/jsonschema"
1414
)
1515

16-
func forType[T any]() *jsonschema.Schema {
17-
s, err := jsonschema.For[T]()
16+
func forType[T any](lax bool) *jsonschema.Schema {
17+
var s *jsonschema.Schema
18+
var err error
19+
if lax {
20+
s, err = jsonschema.ForLax[T]()
21+
} else {
22+
s, err = jsonschema.For[T]()
23+
}
1824
if err != nil {
1925
panic(err)
2026
}
@@ -28,104 +34,134 @@ func TestFor(t *testing.T) {
2834
B int `jsonschema:"bdesc"`
2935
}
3036

31-
tests := []struct {
37+
type test struct {
3238
name string
3339
got *jsonschema.Schema
3440
want *jsonschema.Schema
35-
}{
36-
{"string", forType[string](), &schema{Type: "string"}},
37-
{"int", forType[int](), &schema{Type: "integer"}},
38-
{"int16", forType[int16](), &schema{Type: "integer"}},
39-
{"uint32", forType[int16](), &schema{Type: "integer"}},
40-
{"float64", forType[float64](), &schema{Type: "number"}},
41-
{"bool", forType[bool](), &schema{Type: "boolean"}},
42-
{"intmap", forType[map[string]int](), &schema{
43-
Type: "object",
44-
AdditionalProperties: &schema{Type: "integer"},
45-
}},
46-
{"anymap", forType[map[string]any](), &schema{
47-
Type: "object",
48-
AdditionalProperties: &schema{},
49-
}},
50-
{
51-
"struct",
52-
forType[struct {
53-
F int `json:"f" jsonschema:"fdesc"`
54-
G []float64
55-
P *bool `jsonschema:"pdesc"`
56-
Skip string `json:"-"`
57-
NoSkip string `json:",omitempty"`
58-
unexported float64
59-
unexported2 int `json:"No"`
60-
}](),
61-
&schema{
62-
Type: "object",
63-
Properties: map[string]*schema{
64-
"f": {Type: "integer", Description: "fdesc"},
65-
"G": {Type: "array", Items: &schema{Type: "number"}},
66-
"P": {Types: []string{"null", "boolean"}, Description: "pdesc"},
67-
"NoSkip": {Type: "string"},
41+
}
42+
43+
tests := func(lax bool) []test {
44+
return []test{
45+
{"string", forType[string](lax), &schema{Type: "string"}},
46+
{"int", forType[int](lax), &schema{Type: "integer"}},
47+
{"int16", forType[int16](lax), &schema{Type: "integer"}},
48+
{"uint32", forType[int16](lax), &schema{Type: "integer"}},
49+
{"float64", forType[float64](lax), &schema{Type: "number"}},
50+
{"bool", forType[bool](lax), &schema{Type: "boolean"}},
51+
{"intmap", forType[map[string]int](lax), &schema{
52+
Type: "object",
53+
AdditionalProperties: &schema{Type: "integer"},
54+
}},
55+
{"anymap", forType[map[string]any](lax), &schema{
56+
Type: "object",
57+
AdditionalProperties: &schema{},
58+
}},
59+
{
60+
"struct",
61+
forType[struct {
62+
F int `json:"f" jsonschema:"fdesc"`
63+
G []float64
64+
P *bool `jsonschema:"pdesc"`
65+
Skip string `json:"-"`
66+
NoSkip string `json:",omitempty"`
67+
unexported float64
68+
unexported2 int `json:"No"`
69+
}](lax),
70+
&schema{
71+
Type: "object",
72+
Properties: map[string]*schema{
73+
"f": {Type: "integer", Description: "fdesc"},
74+
"G": {Type: "array", Items: &schema{Type: "number"}},
75+
"P": {Types: []string{"null", "boolean"}, Description: "pdesc"},
76+
"NoSkip": {Type: "string"},
77+
},
78+
Required: []string{"f", "G", "P"},
79+
AdditionalProperties: falseSchema(),
6880
},
69-
Required: []string{"f", "G", "P"},
70-
AdditionalProperties: falseSchema(),
7181
},
72-
},
73-
{
74-
"no sharing",
75-
forType[struct{ X, Y int }](),
76-
&schema{
77-
Type: "object",
78-
Properties: map[string]*schema{
79-
"X": {Type: "integer"},
80-
"Y": {Type: "integer"},
82+
{
83+
"no sharing",
84+
forType[struct{ X, Y int }](lax),
85+
&schema{
86+
Type: "object",
87+
Properties: map[string]*schema{
88+
"X": {Type: "integer"},
89+
"Y": {Type: "integer"},
90+
},
91+
Required: []string{"X", "Y"},
92+
AdditionalProperties: falseSchema(),
8193
},
82-
Required: []string{"X", "Y"},
83-
AdditionalProperties: falseSchema(),
8494
},
85-
},
86-
{
87-
"nested and embedded",
88-
forType[struct {
89-
A S
90-
S
91-
}](),
92-
&schema{
93-
Type: "object",
94-
Properties: map[string]*schema{
95-
"A": {
96-
Type: "object",
97-
Properties: map[string]*schema{
98-
"B": {Type: "integer", Description: "bdesc"},
95+
{
96+
"nested and embedded",
97+
forType[struct {
98+
A S
99+
S
100+
}](lax),
101+
&schema{
102+
Type: "object",
103+
Properties: map[string]*schema{
104+
"A": {
105+
Type: "object",
106+
Properties: map[string]*schema{
107+
"B": {Type: "integer", Description: "bdesc"},
108+
},
109+
Required: []string{"B"},
110+
AdditionalProperties: falseSchema(),
99111
},
100-
Required: []string{"B"},
101-
AdditionalProperties: falseSchema(),
102-
},
103-
"S": {
104-
Type: "object",
105-
Properties: map[string]*schema{
106-
"B": {Type: "integer", Description: "bdesc"},
112+
"S": {
113+
Type: "object",
114+
Properties: map[string]*schema{
115+
"B": {Type: "integer", Description: "bdesc"},
116+
},
117+
Required: []string{"B"},
118+
AdditionalProperties: falseSchema(),
107119
},
108-
Required: []string{"B"},
109-
AdditionalProperties: falseSchema(),
110120
},
121+
Required: []string{"A", "S"},
122+
AdditionalProperties: falseSchema(),
111123
},
112-
Required: []string{"A", "S"},
113-
AdditionalProperties: falseSchema(),
114124
},
115-
},
125+
}
116126
}
117127

118-
for _, test := range tests {
119-
t.Run(test.name, func(t *testing.T) {
120-
if diff := cmp.Diff(test.want, test.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
121-
t.Fatalf("ForType mismatch (-want +got):\n%s", diff)
122-
}
123-
// These schemas should all resolve.
124-
if _, err := test.got.Resolve(nil); err != nil {
125-
t.Fatalf("Resolving: %v", err)
126-
}
127-
})
128+
run := func(t *testing.T, tt test) {
129+
if diff := cmp.Diff(tt.want, tt.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
130+
t.Fatalf("ForType mismatch (-want +got):\n%s", diff)
131+
}
132+
// These schemas should all resolve.
133+
if _, err := tt.got.Resolve(nil); err != nil {
134+
t.Fatalf("Resolving: %v", err)
135+
}
128136
}
137+
138+
t.Run("strict", func(t *testing.T) {
139+
for _, test := range tests(false) {
140+
t.Run(test.name, func(t *testing.T) { run(t, test) })
141+
}
142+
})
143+
144+
laxTests := append(tests(true), test{
145+
"ignore",
146+
forType[struct {
147+
A int
148+
B map[int]int
149+
C func()
150+
}](true),
151+
&schema{
152+
Type: "object",
153+
Properties: map[string]*schema{
154+
"A": {Type: "integer"},
155+
},
156+
Required: []string{"A"},
157+
AdditionalProperties: falseSchema(),
158+
},
159+
})
160+
t.Run("lax", func(t *testing.T) {
161+
for _, test := range laxTests {
162+
t.Run(test.name, func(t *testing.T) { run(t, test) })
163+
}
164+
})
129165
}
130166

131167
func forErr[T any]() error {

0 commit comments

Comments
 (0)