diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java index 969379833b0..206bb60a816 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/MergeUtils.java @@ -168,6 +168,10 @@ public static ChatCompletions mergeChatCompletions(ChatCompletions left, ChatCom setField(instance, "usage", usage); + setField(instance, "model", right.getModel() == null ? left.getModel() : right.getModel()); + + setField(instance, "serviceTier", right.getServiceTier() == null ? left.getServiceTier() : right.getModel()); + return instance; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java index af945bb6a1f..d842b2c1fc0 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java @@ -32,6 +32,7 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.AssistantMessage; @@ -409,6 +410,56 @@ void testMaxTokensForNonReasoningModels() { } } + @Test + void testModelInStreamingResponse() { + String prompt = "List three colors of the rainbow."; + + // @formatter:off + Flux responseFlux = ChatClient.create(this.chatModel).prompt() + .options(AzureOpenAiChatOptions.builder() + .deploymentName("gpt-4o") + .build()) + .user(prompt) + .stream() + .chatResponse(); + // @formatter:on + + List responses = responseFlux.collectList().block(); + + assertThat(responses).isNotEmpty(); + + ChatResponse firstResponse = responses.get(0); + logger.info("First response model: {}", firstResponse.getMetadata().getModel()); + + ChatResponse lastResponse = responses.get(responses.size() - 1); + logger.info("Last response model: {}", lastResponse.getMetadata().getModel()); + + boolean modelFound = responses.stream() + .map(ChatResponse::getMetadata) + .filter(Objects::nonNull) + .map(metadata -> metadata.getModel()) + .anyMatch(Objects::nonNull); + + assertThat(modelFound).as("Model field should be present in streaming responses").isTrue(); + + if (lastResponse.getMetadata() != null && lastResponse.getMetadata().getModel() != null) { + String model = lastResponse.getMetadata().getModel(); + logger.info("Final merged response model: {}", model); + assertThat(model).isNotEmpty(); + // Azure OpenAI models typically contain "gpt" in their name + assertThat(model).containsIgnoringCase("gpt"); + } + String content = responses.stream() + .flatMap(r -> r.getResults().stream()) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .filter(Objects::nonNull) + .collect(Collectors.joining()); + + assertThat(content).isNotEmpty(); + logger.info("Generated content: {}", content); + } + record ActorsFilms(String actor, List movies) { }