diff --git a/go/ai/format_array.go b/go/ai/format_array.go index 3f2322c60b..a063231e45 100644 --- a/go/ai/format_array.go +++ b/go/ai/format_array.go @@ -91,7 +91,7 @@ func (a arrayHandler) ParseMessage(m *Message) (*Message, error) { } var newParts []*Part - lines := base.GetJsonObjectLines(accumulatedText.String()) + lines := base.GetJSONObjectLines(accumulatedText.String()) for _, line := range lines { var schemaBytes []byte schemaBytes, err := json.Marshal(a.config.Schema["items"]) diff --git a/go/ai/format_jsonl.go b/go/ai/format_jsonl.go index 6bc14c347e..f45c180129 100644 --- a/go/ai/format_jsonl.go +++ b/go/ai/format_jsonl.go @@ -92,7 +92,7 @@ func (j jsonlHandler) ParseMessage(m *Message) (*Message, error) { } var newParts []*Part - lines := base.GetJsonObjectLines(accumulatedText.String()) + lines := base.GetJSONObjectLines(accumulatedText.String()) for _, line := range lines { if j.config.Schema != nil { var schemaBytes []byte diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go index fd2866d240..1b55f1bad7 100644 --- a/go/ai/formatter_test.go +++ b/go/ai/formatter_test.go @@ -23,7 +23,7 @@ import ( ) func TestConstrainedGenerate(t *testing.T) { - JSON := "\n{\"foo\": \"bar\"}\n" + JSON := "{\"foo\": \"bar\"}" JSONmd := "```json" + JSON + "```" modelOpts := ModelOptions{ @@ -532,7 +532,7 @@ func TestJsonParser(t *testing.T) { want: &Message{ Role: RoleModel, Content: []*Part{ - NewJSONPart("\n{\"name\": \"John\", \"age\": 19}\n"), + NewJSONPart("{\"name\": \"John\", \"age\": 19}"), }, }, wantErr: false, @@ -573,7 +573,7 @@ func TestJsonParser(t *testing.T) { want: &Message{ Role: RoleModel, Content: []*Part{ - NewJSONPart("\n{\"id\": 1}\n"), + NewJSONPart("{\"id\": 1}"), }, }, wantErr: false, @@ -1062,7 +1062,7 @@ func TestJsonParserStreaming(t *testing.T) { want: &Message{ Role: RoleModel, Content: []*Part{ - NewJSONPart("\n{\"id\": 1}\n"), + NewJSONPart("{\"id\": 1}"), }, }, wantErr: false, diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go index 22ed9ca728..bd7bd17c3d 100644 --- a/go/ai/generate_test.go +++ b/go/ai/generate_test.go @@ -242,7 +242,7 @@ func TestValidMessage(t *testing.T) { } func TestGenerate(t *testing.T) { - JSON := "\n{\"subject\": \"bananas\", \"location\": \"tropics\"}\n" + JSON := "{\"subject\": \"bananas\", \"location\": \"tropics\"}" JSONmd := "```json" + JSON + "```" bananaModel := DefineModel(r, "test/banana", &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamCallback) (*ModelResponse, error) { @@ -629,7 +629,8 @@ func TestGenerate(t *testing.T) { dynamicTool := NewTool("dynamicTestTool", "a tool that is dynamically registered", func(ctx *ToolContext, input struct { Message string - }) (string, error) { + }, + ) (string, error) { return "Dynamic: " + input.Message, nil }, ) @@ -815,7 +816,8 @@ func TestToolInterruptsAndResume(t *testing.T) { func(ctx *ToolContext, input struct { Value string Interrupt bool - }) (string, error) { + }, + ) (string, error) { if input.Interrupt { return "", ctx.Interrupt(&InterruptOptions{ Metadata: map[string]any{ @@ -833,7 +835,8 @@ func TestToolInterruptsAndResume(t *testing.T) { func(ctx *ToolContext, input struct { Action string Data string - }) (string, error) { + }, + ) (string, error) { if ctx.Resumed != nil { resumedData, ok := ctx.Resumed["data"].(string) if ok { diff --git a/go/internal/base/json.go b/go/internal/base/json.go index a891c765e0..0a930fb2c6 100644 --- a/go/internal/base/json.go +++ b/go/internal/base/json.go @@ -119,22 +119,23 @@ func SchemaAsMap(s *jsonschema.Schema) map[string]any { return m } -var jsonMarkdownRegex = regexp.MustCompile("```(json)?((\n|.)*?)```") +// jsonMarkdownRegex specifically looks for "json" language identifier +var jsonMarkdownRegex = regexp.MustCompile("(?s)```json(.*?)```") // ExtractJSONFromMarkdown returns the contents of the first fenced code block in // the markdown text md. If there is none, it returns md. func ExtractJSONFromMarkdown(md string) string { - // TODO: improve this matches := jsonMarkdownRegex.FindStringSubmatch(md) - if matches == nil { + if len(matches) < 2 { return md } - return matches[2] + // capture group 1 matches the actual fenced JSON block + return strings.TrimSpace(matches[1]) } -// GetJsonObjectLines splits a string by newlines, trims whitespace from each line, +// GetJSONObjectLines splits a string by newlines, trims whitespace from each line, // and returns a slice containing only the lines that start with '{'. -func GetJsonObjectLines(text string) []string { +func GetJSONObjectLines(text string) []string { jsonText := ExtractJSONFromMarkdown(text) // Handle both actual "\n" newline strings, as well as newline bytes diff --git a/go/internal/base/json_test.go b/go/internal/base/json_test.go index 0acecd00d3..86f79ebf60 100644 --- a/go/internal/base/json_test.go +++ b/go/internal/base/json_test.go @@ -41,7 +41,12 @@ func TestExtractJSONFromMarkdown(t *testing.T) { { desc: "simple markdown", in: "```foo bar```", - want: "foo bar", + want: "```foo bar```", + }, + { + desc: "empty markdown", + in: "``` ```", + want: "``` ```", }, { desc: "json markdown", @@ -49,15 +54,30 @@ func TestExtractJSONFromMarkdown(t *testing.T) { want: "{\"a\":1}", }, { - desc: "json multipline markdown", + desc: "json multiple line markdown", in: "```json\n{\"a\": 1}\n```", - want: "\n{\"a\": 1}\n", + want: "{\"a\": 1}", }, { desc: "returns first of multiple blocks", in: "```json{\"a\":\n1}```\n```json\n{\"b\":\n1}```", want: "{\"a\":\n1}", }, + { + desc: "yaml markdown", + in: "```yaml\nkey: 1\nanother-key: 2```", + want: "```yaml\nkey: 1\nanother-key: 2```", + }, + { + desc: "yaml + json markdown", + in: "```yaml\nkey: 1\nanother-key: 2``` ```json\n{\"a\": 1}\n```", + want: "{\"a\": 1}", + }, + { + desc: "json + yaml markdown", + in: "```json\n{\"a\": 1}\n``` ```yaml\nkey: 1\nanother-key: 2```", + want: "{\"a\": 1}", + }, } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) {