Skip to content

Commit 3773447

Browse files
authored
fix(go/internal): refine JSON markdown detection (#3661)
1 parent 107be7a commit 3773447

File tree

6 files changed

+43
-19
lines changed

6 files changed

+43
-19
lines changed

go/ai/format_array.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func (a arrayHandler) ParseMessage(m *Message) (*Message, error) {
9191
}
9292

9393
var newParts []*Part
94-
lines := base.GetJsonObjectLines(accumulatedText.String())
94+
lines := base.GetJSONObjectLines(accumulatedText.String())
9595
for _, line := range lines {
9696
var schemaBytes []byte
9797
schemaBytes, err := json.Marshal(a.config.Schema["items"])

go/ai/format_jsonl.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) {
9292
}
9393

9494
var newParts []*Part
95-
lines := base.GetJsonObjectLines(accumulatedText.String())
95+
lines := base.GetJSONObjectLines(accumulatedText.String())
9696
for _, line := range lines {
9797
if j.config.Schema != nil {
9898
var schemaBytes []byte

go/ai/formatter_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import (
2323
)
2424

2525
func TestConstrainedGenerate(t *testing.T) {
26-
JSON := "\n{\"foo\": \"bar\"}\n"
26+
JSON := "{\"foo\": \"bar\"}"
2727
JSONmd := "```json" + JSON + "```"
2828

2929
modelOpts := ModelOptions{
@@ -532,7 +532,7 @@ func TestJsonParser(t *testing.T) {
532532
want: &Message{
533533
Role: RoleModel,
534534
Content: []*Part{
535-
NewJSONPart("\n{\"name\": \"John\", \"age\": 19}\n"),
535+
NewJSONPart("{\"name\": \"John\", \"age\": 19}"),
536536
},
537537
},
538538
wantErr: false,
@@ -573,7 +573,7 @@ func TestJsonParser(t *testing.T) {
573573
want: &Message{
574574
Role: RoleModel,
575575
Content: []*Part{
576-
NewJSONPart("\n{\"id\": 1}\n"),
576+
NewJSONPart("{\"id\": 1}"),
577577
},
578578
},
579579
wantErr: false,
@@ -1062,7 +1062,7 @@ func TestJsonParserStreaming(t *testing.T) {
10621062
want: &Message{
10631063
Role: RoleModel,
10641064
Content: []*Part{
1065-
NewJSONPart("\n{\"id\": 1}\n"),
1065+
NewJSONPart("{\"id\": 1}"),
10661066
},
10671067
},
10681068
wantErr: false,

go/ai/generate_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ func TestValidMessage(t *testing.T) {
242242
}
243243

244244
func TestGenerate(t *testing.T) {
245-
JSON := "\n{\"subject\": \"bananas\", \"location\": \"tropics\"}\n"
245+
JSON := "{\"subject\": \"bananas\", \"location\": \"tropics\"}"
246246
JSONmd := "```json" + JSON + "```"
247247

248248
bananaModel := DefineModel(r, "test/banana", &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) {
@@ -629,7 +629,8 @@ func TestGenerate(t *testing.T) {
629629
dynamicTool := NewTool("dynamicTestTool", "a tool that is dynamically registered",
630630
func(ctx *ToolContext, input struct {
631631
Message string
632-
}) (string, error) {
632+
},
633+
) (string, error) {
633634
return "Dynamic: " + input.Message, nil
634635
},
635636
)
@@ -815,7 +816,8 @@ func TestToolInterruptsAndResume(t *testing.T) {
815816
func(ctx *ToolContext, input struct {
816817
Value string
817818
Interrupt bool
818-
}) (string, error) {
819+
},
820+
) (string, error) {
819821
if input.Interrupt {
820822
return "", ctx.Interrupt(&InterruptOptions{
821823
Metadata: map[string]any{
@@ -833,7 +835,8 @@ func TestToolInterruptsAndResume(t *testing.T) {
833835
func(ctx *ToolContext, input struct {
834836
Action string
835837
Data string
836-
}) (string, error) {
838+
},
839+
) (string, error) {
837840
if ctx.Resumed != nil {
838841
resumedData, ok := ctx.Resumed["data"].(string)
839842
if ok {

go/internal/base/json.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,23 @@ func SchemaAsMap(s *jsonschema.Schema) map[string]any {
119119
return m
120120
}
121121

122-
var jsonMarkdownRegex = regexp.MustCompile("```(json)?((\n|.)*?)```")
122+
// jsonMarkdownRegex matches only fenced code blocks tagged with the "json" language identifier; capture group 1 holds the block's contents.
123+
var jsonMarkdownRegex = regexp.MustCompile("(?s)```json(.*?)```")
123124

124125
// ExtractJSONFromMarkdown returns the whitespace-trimmed contents of the first
125126
// ```json fenced code block in the markdown text md. If there is none, it returns md unchanged.
126127
func ExtractJSONFromMarkdown(md string) string {
127-
// TODO: improve this
128128
matches := jsonMarkdownRegex.FindStringSubmatch(md)
129-
if matches == nil {
129+
if len(matches) < 2 {
130130
return md
131131
}
132-
return matches[2]
132+
// capture group 1 matches the actual fenced JSON block
133+
return strings.TrimSpace(matches[1])
133134
}
134135

135-
// GetJsonObjectLines splits a string by newlines, trims whitespace from each line,
136+
// GetJSONObjectLines splits a string by newlines, trims whitespace from each line,
136137
// and returns a slice containing only the lines that start with '{'.
137-
func GetJsonObjectLines(text string) []string {
138+
func GetJSONObjectLines(text string) []string {
138139
jsonText := ExtractJSONFromMarkdown(text)
139140

140141
// Handle both actual "\n" newline strings, as well as newline bytes

go/internal/base/json_test.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,43 @@ func TestExtractJSONFromMarkdown(t *testing.T) {
4141
{
4242
desc: "simple markdown",
4343
in: "```foo bar```",
44-
want: "foo bar",
44+
want: "```foo bar```",
45+
},
46+
{
47+
desc: "empty markdown",
48+
in: "``` ```",
49+
want: "``` ```",
4550
},
4651
{
4752
desc: "json markdown",
4853
in: "```json{\"a\":1}```",
4954
want: "{\"a\":1}",
5055
},
5156
{
52-
desc: "json multipline markdown",
57+
desc: "json multiple line markdown",
5358
in: "```json\n{\"a\": 1}\n```",
54-
want: "\n{\"a\": 1}\n",
59+
want: "{\"a\": 1}",
5560
},
5661
{
5762
desc: "returns first of multiple blocks",
5863
in: "```json{\"a\":\n1}```\n```json\n{\"b\":\n1}```",
5964
want: "{\"a\":\n1}",
6065
},
66+
{
67+
desc: "yaml markdown",
68+
in: "```yaml\nkey: 1\nanother-key: 2```",
69+
want: "```yaml\nkey: 1\nanother-key: 2```",
70+
},
71+
{
72+
desc: "yaml + json markdown",
73+
in: "```yaml\nkey: 1\nanother-key: 2``` ```json\n{\"a\": 1}\n```",
74+
want: "{\"a\": 1}",
75+
},
76+
{
77+
desc: "json + yaml markdown",
78+
in: "```json\n{\"a\": 1}\n``` ```yaml\nkey: 1\nanother-key: 2```",
79+
want: "{\"a\": 1}",
80+
},
6181
}
6282
for _, tc := range tests {
6383
t.Run(tc.desc, func(t *testing.T) {

0 commit comments

Comments
 (0)