diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index e85f9e03342..2370f2ae3f1 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,8 @@ import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinition; import com.azure.ai.openai.models.ChatCompletionsFunctionToolDefinitionFunction; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat; +import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema; import com.azure.ai.openai.models.ChatCompletionsOptions; import com.azure.ai.openai.models.ChatCompletionsResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; @@ -901,7 +903,14 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { * @return Azure response format */ private ChatCompletionsResponseFormat toAzureResponseFormat(AzureOpenAiResponseFormat responseFormat) { - if (responseFormat == AzureOpenAiResponseFormat.JSON) { + if (responseFormat.getType() == AzureOpenAiResponseFormat.Type.JSON_SCHEMA) { + ChatCompletionsJsonSchemaResponseFormatJsonSchema jsonSchema = new ChatCompletionsJsonSchemaResponseFormatJsonSchema( + responseFormat.getJsonSchema().getName()); + jsonSchema.setSchema(BinaryData.fromObject(responseFormat.getJsonSchema().getSchema())); + jsonSchema.setStrict(responseFormat.getJsonSchema().getStrict()); + return new ChatCompletionsJsonSchemaResponseFormat(jsonSchema); + } + else if (responseFormat.getType() == AzureOpenAiResponseFormat.Type.JSON_OBJECT) { return new ChatCompletionsJsonResponseFormat(); } return new ChatCompletionsTextResponseFormat(); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java index fd83532ec77..2c90b46ba1f 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,27 +16,233 @@ package org.springframework.ai.azure.openai; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.model.ModelOptionsUtils; + +import java.util.Map; +import java.util.Objects; + /** * Utility enumeration for representing the response format that may be requested from the - * Azure OpenAI model. Please check OpenAI - * API documentation for more details. + * Azure OpenAI model. Please check + * Azure + * OpenAI API documentation for more details. + * + * @author Jonghoon Park */ -public enum AzureOpenAiResponseFormat { - - // default value used by OpenAI - TEXT, - /* - * From the OpenAI API documentation: Compatability: Compatible with GPT-4 Turbo and - * all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Caveats: This enables JSON - * mode, which guarantees the message the model generates is valid JSON. Important: - * when using JSON mode, you must also instruct the model to produce JSON yourself via - * a system or user message. Without this, the model may generate an unending stream - * of whitespace until the generation reaches the token limit, resulting in a - * long-running and seemingly "stuck" request. Also note that the message content may - * be partially cut off if finish_reason="length", which indicates the generation - * exceeded max_tokens or the conversation exceeded the max context length. +@JsonInclude(JsonInclude.Include.NON_NULL) +public class AzureOpenAiResponseFormat { + + /** + * Type Must be one of 'text', 'json_object' or 'json_schema'. + */ + @JsonProperty("type") + private Type type; + + public AzureOpenAiResponseFormat() { + + } + + /** + * JSON schema object that describes the format of the JSON object. Only applicable + * when type is 'json_schema'. + */ + @JsonProperty("json_schema") + private JsonSchema jsonSchema; + + public Type getType() { + return this.type; + } + + public JsonSchema getJsonSchema() { + return this.jsonSchema; + } + + public static Builder builder() { + return new Builder(); + } + + private String schema; + + private AzureOpenAiResponseFormat(Type type, JsonSchema jsonSchema) { + this.type = type; + this.jsonSchema = jsonSchema; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + AzureOpenAiResponseFormat that = (AzureOpenAiResponseFormat) o; + return this.type == that.type && Objects.equals(this.jsonSchema, that.jsonSchema); + } + + @Override + public int hashCode() { + return Objects.hash(this.type, this.jsonSchema); + } + + @Override + public String toString() { + return "AzureOpenAiResponseFormat{" + "type=" + this.type + ", jsonSchema=" + this.jsonSchema + '}'; + } + + public static final class Builder { + + private Type type; + + private JsonSchema jsonSchema; + + private Builder() { + } + + public Builder type(Type type) { + this.type = type; + return this; + } + + public Builder jsonSchema(JsonSchema jsonSchema) { + this.jsonSchema = jsonSchema; + return this; + } + + public Builder jsonSchema(String jsonSchema) { + this.jsonSchema = JsonSchema.builder().schema(jsonSchema).build(); + return this; + } + + public AzureOpenAiResponseFormat build() { + return new AzureOpenAiResponseFormat(this.type, this.jsonSchema); + } + + } + + public enum Type { + + /** + * Generates a text response. (default) + */ + @JsonProperty("text") + TEXT, + + /** + * Enables JSON mode, which guarantees the message the model generates is valid + * JSON. + */ + @JsonProperty("json_object") + JSON_OBJECT, + + /** + * Enables Structured Outputs which guarantees the model will match your supplied + * JSON schema. + */ + @JsonProperty("json_schema") + JSON_SCHEMA + + } + + /** + * JSON schema object that describes the format of the JSON object. Applicable for the + * 'json_schema' type only. */ - JSON + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class JsonSchema { + + @JsonProperty("name") + private String name; + + @JsonProperty("schema") + private Map schema; + + @JsonProperty("strict") + private Boolean strict; + + public JsonSchema() { + + } + + public String getName() { + return this.name; + } + + public Map getSchema() { + return this.schema; + } + + public Boolean getStrict() { + return this.strict; + } + + private JsonSchema(String name, Map schema, Boolean strict) { + this.name = name; + this.schema = schema; + this.strict = strict; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public int hashCode() { + return Objects.hash(this.name, this.schema, this.strict); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JsonSchema that = (JsonSchema) o; + return Objects.equals(this.name, that.name) && Objects.equals(this.schema, that.schema) + && Objects.equals(this.strict, that.strict); + } + + public static final class Builder { + + private String name = "custom_schema"; + + private Map schema; + + private Boolean strict = true; + + private Builder() { + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder schema(Map schema) { + this.schema = schema; + return this; + } + + public Builder schema(String schema) { + this.schema = ModelOptionsUtils.jsonToMap(schema); + return this; + } + + public Builder strict(Boolean strict) { + this.strict = strict; + return this; + } + + public JsonSchema build() { + return new JsonSchema(this.name, this.schema, this.strict); + } + + } + + } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index 46dcf5547d4..d1311d4141e 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,7 +68,7 @@ public void createRequestWithChatOptions() { .logprobs(true) .topLogprobs(5) .enhancements(mockAzureChatEnhancementConfiguration) - .responseFormat(AzureOpenAiResponseFormat.TEXT) + .responseFormat(AzureOpenAiResponseFormat.builder().type(AzureOpenAiResponseFormat.Type.TEXT).build()) .build(); var client = AzureOpenAiChatModel.builder() @@ -114,7 +114,8 @@ public void createRequestWithChatOptions() { .logprobs(true) .topLogprobs(4) .enhancements(anotherMockAzureChatEnhancementConfiguration) - .responseFormat(AzureOpenAiResponseFormat.JSON) + .responseFormat( + AzureOpenAiResponseFormat.builder().type(AzureOpenAiResponseFormat.Type.JSON_OBJECT).build()) .build(); requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content", runtimeOptions)); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelResponseFormatIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelResponseFormatIT.java new file mode 100644 index 00000000000..4618923c20e --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelResponseFormatIT.java @@ -0,0 +1,243 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai; + +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.OpenAIServiceVersion; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.http.policy.HttpLogOptions; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.core.JacksonException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonghoon Park + */ +@SpringBootTest(classes = AzureOpenAiChatModelResponseFormatIT.TestConfiguration.class) +@RequiresAzureCredentials +public class AzureOpenAiChatModelResponseFormatIT { + + private static ObjectMapper MAPPER = new ObjectMapper().enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS); + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + @Autowired + private AzureOpenAiChatModel chatModel; + + public static boolean isValidJson(String json) { + try { + MAPPER.readTree(json); + } + catch (JacksonException e) { + return false; + } + return true; + } + + @Test + void jsonObject() { + + Prompt prompt = new Prompt("List 8 planets. Use JSON response", AzureOpenAiChatOptions.builder() + .responseFormat( + AzureOpenAiResponseFormat.builder().type(AzureOpenAiResponseFormat.Type.JSON_OBJECT).build()) + .build()); + + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + + String content = response.getResult().getOutput().getText(); + + logger.info("Response content: {}", content); + + assertThat(isValidJson(content)).isTrue(); + } + + @Test + void jsonSchema() { + + var jsonSchema = """ + { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation", "output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps", "final_answer"], + "additionalProperties": false + } + """; + + Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", + AzureOpenAiChatOptions.builder() + .responseFormat(AzureOpenAiResponseFormat.builder() + .type(AzureOpenAiResponseFormat.Type.JSON_SCHEMA) + .jsonSchema(jsonSchema) + .build()) + .build()); + + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + + String content = response.getResult().getOutput().getText(); + + logger.info("Response content: {}", content); + + assertThat(isValidJson(content)).isTrue(); + } + + @Test + void jsonSchemaBeanConverter() { + + @JsonPropertyOrder({ "steps", "final_answer" }) + record MathReasoning(@JsonProperty(required = true, value = "steps") Steps steps, + @JsonProperty(required = true, value = "final_answer") String finalAnswer) { + + record Steps(@JsonProperty(required = true, value = "items") Items[] items) { + + @JsonPropertyOrder({ "output", "explanation" }) + record Items(@JsonProperty(required = true, value = "explanation") String explanation, + @JsonProperty(required = true, value = "output") String output) { + + } + + } + + } + + var outputConverter = new BeanOutputConverter<>(MathReasoning.class); + // @formatter:off + // CHECKSTYLE:OFF + var expectedJsonSchema = """ + { + "$schema" : "https://json-schema.org/draft/2020-12/schema", + "type" : "object", + "properties" : { + "steps" : { + "type" : "object", + "properties" : { + "items" : { + "type" : "array", + "items" : { + "type" : "object", + "properties" : { + "output" : { + "type" : "string" + }, + "explanation" : { + "type" : "string" + } + }, + "required" : [ "output", "explanation" ], + "additionalProperties" : false + } + } + }, + "required" : [ "items" ], + "additionalProperties" : false + }, + "final_answer" : { + "type" : "string" + } + }, + "required" : [ "steps", "final_answer" ], + "additionalProperties" : false + }"""; + // @formatter:on + // CHECKSTYLE:ON + var jsonSchema1 = outputConverter.getJsonSchema(); + + assertThat(jsonSchema1).isNotNull(); + assertThat(jsonSchema1).isEqualTo(expectedJsonSchema); + + Prompt prompt = new Prompt("how can I solve 8x + 7 = -23", + AzureOpenAiChatOptions.builder() + .responseFormat(AzureOpenAiResponseFormat.builder() + .type(AzureOpenAiResponseFormat.Type.JSON_SCHEMA) + .jsonSchema(jsonSchema1) + .build()) + .build()); + + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + + String content = response.getResult().getOutput().getText(); + + logger.info("Response content: {}", content); + + assertThat(isValidJson(content)).isTrue(); + + // Check if the order is correct as specified in the schema. Steps should come + // first before final answer. + assertThat(content.startsWith("{\"steps\":{\"items\":[")); + + MathReasoning mathReasoning = outputConverter.convert(content); + + assertThat(mathReasoning).isNotNull(); + logger.info(mathReasoning.toString()); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public OpenAIClientBuilder openAIClientBuilder() { + return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .serviceVersion(OpenAIServiceVersion.V2025_01_01_PREVIEW) + .httpLogOptions(new HttpLogOptions() + .setLogLevel(com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS)); + } + + @Bean + public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { + return AzureOpenAiChatModel.builder() + .openAIClientBuilder(openAIClientBuilder) + .defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build()) + .build(); + } + + } + +} diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java index b3a8bfd6d74..35a10728ee2 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -36,7 +36,9 @@ class AzureOpenAiChatOptionsTests { @Test void testBuilderWithAllFields() { - AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.builder() + .type(AzureOpenAiResponseFormat.Type.TEXT) + .build(); ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); streamOptions.setIncludeUsage(true); @@ -74,7 +76,9 @@ void testBuilderWithAllFields() { @Test void testCopy() { - AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.builder() + .type(AzureOpenAiResponseFormat.Type.TEXT) + .build(); ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); streamOptions.setIncludeUsage(true); @@ -111,7 +115,9 @@ void testCopy() { @Test void testSetters() { - AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.builder() + .type(AzureOpenAiResponseFormat.Type.TEXT) + .build(); ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); streamOptions.setIncludeUsage(true); AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration();