Skip to content

Commit 2fcd728

Browse files
authored
Use structured output for tool choice handling (#3762)
1 parent c348d6b commit 2fcd728

23 files changed: 452 additions and 640 deletions.

src/llm/io_processing/base_generation_config_builder.hpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,13 +31,15 @@ namespace ovms {
3131
class BaseGenerationConfigBuilder {
3232
protected:
3333
ov::genai::GenerationConfig config;
34+
const bool enableToolGuidedGeneration;
3435
void setStructuralTagsConfig(const ov::genai::StructuredOutputConfig::StructuralTag& structuralTag);
3536

3637
public:
3738
BaseGenerationConfigBuilder() = delete;
3839
// Initializes the builder with a base generation config read from model generation_config.json
39-
explicit BaseGenerationConfigBuilder(ov::genai::GenerationConfig& baseConfig) :
40-
config(baseConfig) {}
40+
explicit BaseGenerationConfigBuilder(ov::genai::GenerationConfig& baseConfig, bool enableToolGuidedGeneration) :
41+
config(baseConfig),
42+
enableToolGuidedGeneration(enableToolGuidedGeneration) {}
4143
virtual ~BaseGenerationConfigBuilder() = default;
4244

4345
ov::genai::GenerationConfig& getConfig() { return config; }

src/llm/io_processing/base_output_parser.cpp

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -72,12 +72,4 @@ rapidjson::Document BaseOutputParser::wrapDelta(const rapidjson::Document& delta
7272
return wrappedDelta;
7373
}
7474

75-
void BaseOutputParser::enableImmediateParsing() {
76-
immediateParsingEnabled = true;
77-
}
78-
79-
bool BaseOutputParser::isImmediateParsingEnabled() const {
80-
return immediateParsingEnabled;
81-
}
82-
8375
} // namespace ovms

src/llm/io_processing/base_output_parser.hpp

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -64,9 +64,6 @@ using ToolsParameterTypeMap_t = std::unordered_map<std::string, ParametersTypeMa
6464
class BaseOutputParser {
6565
protected:
6666
ov::genai::Tokenizer tokenizer;
67-
// Flag indicating whether parsing start tag has been injected into the prompt
68-
// if true, parser should assume start tag already appeared and start parsing immediately
69-
bool immediateParsingEnabled = false;
7067

7168
public:
7269
BaseOutputParser() = delete;
@@ -81,11 +78,6 @@ class BaseOutputParser {
8178
// {"tool_calls":[{"index":0,"function":<delta>}]}
8279
static rapidjson::Document wrapDelta(const rapidjson::Document& delta, int toolCallIndex);
8380

84-
// Calling this method should put parser into immediate parsing mode where it starts parsing immediately, without seeking the start tag.
85-
void enableImmediateParsing();
86-
87-
bool isImmediateParsingEnabled() const;
88-
8981
// --- Specialized output parsers interface ---
9082

9183
// Parse model output and extract relevant information to parsedOutput fields. Raw generated tokens are provided as an argument.

src/llm/io_processing/generation_config_builder.hpp

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -34,24 +34,21 @@ class GenerationConfigBuilder {
3434
public:
3535
GenerationConfigBuilder() = delete;
3636
// Using tool parser name to select appropriate builder implementation to avoid introducing additional parameters. Might be insufficient in the future.
37-
explicit GenerationConfigBuilder(ov::genai::GenerationConfig baseConfig, std::string toolParserName = "", bool enableToolGuidedGeneration = false) {
38-
if (!enableToolGuidedGeneration) {
39-
builder_impl = std::make_unique<BaseGenerationConfigBuilder>(baseConfig);
40-
return;
41-
}
42-
37+
explicit GenerationConfigBuilder(ov::genai::GenerationConfig baseConfig, bool enableToolGuidedGeneration, std::string toolParserName = "") {
4338
if (toolParserName == "llama3") {
44-
builder_impl = std::make_unique<Llama3GenerationConfigBuilder>(baseConfig);
39+
builder_impl = std::make_unique<Llama3GenerationConfigBuilder>(baseConfig, enableToolGuidedGeneration);
4540
} else if (toolParserName == "qwen3") {
4641
// Qwen3 and Hermes3 share the same mechanism for generating tool calls, so we can use Hermes3GenerationConfigBuilder
47-
builder_impl = std::make_unique<Hermes3GenerationConfigBuilder>(baseConfig);
42+
builder_impl = std::make_unique<Hermes3GenerationConfigBuilder>(baseConfig, enableToolGuidedGeneration);
4843
} else if (toolParserName == "hermes3") {
49-
builder_impl = std::make_unique<Hermes3GenerationConfigBuilder>(baseConfig);
44+
builder_impl = std::make_unique<Hermes3GenerationConfigBuilder>(baseConfig, enableToolGuidedGeneration);
5045
} else if (toolParserName == "phi4") {
51-
builder_impl = std::make_unique<Phi4GenerationConfigBuilder>(baseConfig);
46+
builder_impl = std::make_unique<Phi4GenerationConfigBuilder>(baseConfig, enableToolGuidedGeneration);
5247
} else {
53-
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Option enable_tool_guided_generation is set, but will not be effective since no valid tool parser has been provided.");
54-
builder_impl = std::make_unique<BaseGenerationConfigBuilder>(baseConfig);
48+
if (enableToolGuidedGeneration) {
49+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Option enable_tool_guided_generation is set, but will not be effective since no valid tool parser has been provided.");
50+
}
51+
builder_impl = std::make_unique<BaseGenerationConfigBuilder>(baseConfig, enableToolGuidedGeneration);
5552
}
5653
}
5754

src/llm/io_processing/hermes3/generation_config_builder.cpp

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -33,20 +33,25 @@ void Hermes3GenerationConfigBuilder::parseConfigFromRequest(const OpenAIChatComp
3333
return;
3434
}
3535

36-
// Set tool guided generation config specific to Hermes3 and Qwen3 models
37-
auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
38-
triggeredTags->triggers.push_back("<tool_call>");
39-
40-
for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) {
41-
const auto& toolSchema = toolSchemaWrapper.stringRepr;
42-
ov::genai::StructuredOutputConfig::Tag tagItem;
43-
tagItem.begin = "<tool_call>\n{\"name\": \"" + toolName + "\", \"arguments\": ";
44-
tagItem.end = "}\n</tool_call>";
45-
tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(toolSchema);
46-
triggeredTags->tags.push_back(tagItem);
36+
if (enableToolGuidedGeneration || request.toolChoice == "required") {
37+
// Set tool guided generation config specific to Hermes3 and Qwen3 models
38+
auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
39+
triggeredTags->triggers.push_back("<tool_call>");
40+
41+
for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) {
42+
const auto& toolSchema = toolSchemaWrapper.stringRepr;
43+
ov::genai::StructuredOutputConfig::Tag tagItem;
44+
tagItem.begin = "<tool_call>\n{\"name\": \"" + toolName + "\", \"arguments\": ";
45+
tagItem.end = "}\n</tool_call>";
46+
tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(toolSchema);
47+
triggeredTags->tags.push_back(tagItem);
48+
}
49+
if (request.toolChoice == "required") {
50+
triggeredTags->at_least_one = true;
51+
}
52+
ov::genai::StructuredOutputConfig::StructuralTag structuralTag = triggeredTags;
53+
setStructuralTagsConfig(structuralTag);
4754
}
48-
ov::genai::StructuredOutputConfig::StructuralTag structuralTag = triggeredTags;
49-
setStructuralTagsConfig(structuralTag);
5055
}
5156

5257
} // namespace ovms

src/llm/io_processing/hermes3/generation_config_builder.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ namespace ovms {
2525
class Hermes3GenerationConfigBuilder : public BaseGenerationConfigBuilder {
2626
public:
2727
Hermes3GenerationConfigBuilder() = delete;
28-
explicit Hermes3GenerationConfigBuilder(ov::genai::GenerationConfig& baseConfig) :
29-
BaseGenerationConfigBuilder(baseConfig) {}
28+
explicit Hermes3GenerationConfigBuilder(ov::genai::GenerationConfig& baseConfig, bool enableToolGuidedGeneration) :
29+
BaseGenerationConfigBuilder(baseConfig, enableToolGuidedGeneration) {}
3030

3131
void parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) override;
3232
};

src/llm/io_processing/hermes3/tool_parser.cpp

Lines changed: 2 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -123,27 +123,8 @@ void Hermes3ToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int6
123123
size_t pos = 0;
124124
size_t firstToolCallPos;
125125

126-
// If immediate parsing is enabled, we assume tool calls start from the beginning of the content.
127-
// Otherwise, we search for the first occurrence of the tool call start tag.
128-
if (!immediateParsingEnabled) {
129-
firstToolCallPos = parsedOutput.content.find(startTag, pos);
130-
} else {
131-
// Read first tool call without opening tag
132-
firstToolCallPos = 0;
133-
size_t end = parsedOutput.content.find(endTag, firstToolCallPos);
134-
std::string tool;
135-
if (end != std::string::npos) {
136-
tool = parsedOutput.content.substr(0, end);
137-
pos = end + endTag.length();
138-
} else {
139-
tool = parsedOutput.content;
140-
pos = parsedOutput.content.length();
141-
}
142-
if (!tool.empty()) {
143-
tools.push_back(tool);
144-
}
145-
}
146-
126+
// Save position of the first tool call start tag to properly clear content after parsing.
127+
firstToolCallPos = parsedOutput.content.find(startTag, pos);
147128
while (true) {
148129
size_t start = parsedOutput.content.find(startTag, pos);
149130
if (start == std::string::npos) {

src/llm/io_processing/llama3/generation_config_builder.cpp

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -33,20 +33,25 @@ void Llama3GenerationConfigBuilder::parseConfigFromRequest(const OpenAIChatCompl
3333
return;
3434
}
3535

36-
// Set tool guided generation config specific to Llama-3 model
37-
auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
38-
triggeredTags->triggers.push_back("{\"name\":");
39-
40-
for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) {
41-
const auto& toolSchema = toolSchemaWrapper.stringRepr;
42-
ov::genai::StructuredOutputConfig::Tag tagItem;
43-
tagItem.begin = "{\"name\": \"" + toolName + "\", \"parameters\": ";
44-
tagItem.end = "}";
45-
tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(toolSchema);
46-
triggeredTags->tags.push_back(tagItem);
36+
if (enableToolGuidedGeneration || request.toolChoice == "required") {
37+
// Set tool guided generation config specific to Llama-3 model
38+
auto triggeredTags = std::make_shared<ov::genai::StructuredOutputConfig::TriggeredTags>();
39+
triggeredTags->triggers.push_back("{\"name\":");
40+
41+
for (const auto& [toolName, toolSchemaWrapper] : request.toolNameSchemaMap) {
42+
const auto& toolSchema = toolSchemaWrapper.stringRepr;
43+
ov::genai::StructuredOutputConfig::Tag tagItem;
44+
tagItem.begin = "{\"name\": \"" + toolName + "\", \"parameters\": ";
45+
tagItem.end = "}";
46+
tagItem.content = ov::genai::StructuredOutputConfig::JSONSchema(toolSchema);
47+
triggeredTags->tags.push_back(tagItem);
48+
}
49+
if (request.toolChoice == "required") {
50+
triggeredTags->at_least_one = true;
51+
}
52+
ov::genai::StructuredOutputConfig::StructuralTag structuralTag = triggeredTags;
53+
setStructuralTagsConfig(structuralTag);
4754
}
48-
ov::genai::StructuredOutputConfig::StructuralTag structuralTag = triggeredTags;
49-
setStructuralTagsConfig(structuralTag);
5055
}
5156

5257
} // namespace ovms

src/llm/io_processing/llama3/generation_config_builder.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ namespace ovms {
2525
class Llama3GenerationConfigBuilder : public BaseGenerationConfigBuilder {
2626
public:
2727
Llama3GenerationConfigBuilder() = delete;
28-
explicit Llama3GenerationConfigBuilder(ov::genai::GenerationConfig& baseConfig) :
29-
BaseGenerationConfigBuilder(baseConfig) {}
28+
explicit Llama3GenerationConfigBuilder(ov::genai::GenerationConfig& baseConfig, bool enableToolGuidedGeneration) :
29+
BaseGenerationConfigBuilder(baseConfig, enableToolGuidedGeneration) {}
3030

3131
void parseConfigFromRequest(const OpenAIChatCompletionsRequest& request) override;
3232
};

src/llm/io_processing/llama3/tool_parser.cpp

Lines changed: 16 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -31,30 +31,25 @@ void Llama3ToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64
3131
// TODO: check if we can rely on decoded <|python_tag|> token to be present in the content, so we can drop multiple detokenizations and copies
3232
// and just extract substrings from the content and modify content in-place
3333

34-
// If immediate trigger parsing is enabled, we assume botTokenId has been injected into the prompt and whole output are tool calls,
35-
// otherwise we search for botTokenId in the generatedTokens to find tool calls start or check if the content starts with "{" (llama3 sometimes does not generate botTokenId)
34+
// We search for botTokenId in the generatedTokens to find tool calls start or check if the content starts with "{" (llama3 sometimes does not generate botTokenId)
3635
auto toolCallsStartPosition = generatedTokens.begin();
37-
if (!immediateParsingEnabled) {
38-
toolCallsStartPosition = generatedTokens.end();
39-
// Find botTokenId in generated_ids
40-
auto botTokenIt = std::find(generatedTokens.begin(), generatedTokens.end(), botTokenId);
36+
toolCallsStartPosition = generatedTokens.end();
37+
// Find botTokenId in generated_ids
38+
auto botTokenIt = std::find(generatedTokens.begin(), generatedTokens.end(), botTokenId);
4139

42-
if (botTokenIt != generatedTokens.end()) {
43-
// Decode the content before botTokenId
44-
std::vector<int64_t> contentTokens(generatedTokens.begin(), botTokenIt);
45-
parsedOutput.content = tokenizer.decode(contentTokens);
46-
// Tokens after botTokenId will be treated as tool calls
47-
toolCallsStartPosition = botTokenIt + 1;
48-
} else {
49-
// If botTokenId is not found, check if model output starts with "{" and if so, assume it's a tool call"
50-
if (!parsedOutput.content.empty() && parsedOutput.content[0] == '{') {
51-
// If model output starts with "{", treat it as a tool call
52-
toolCallsStartPosition = generatedTokens.begin();
53-
parsedOutput.content.clear();
54-
}
55-
}
40+
if (botTokenIt != generatedTokens.end()) {
41+
// Decode the content before botTokenId
42+
std::vector<int64_t> contentTokens(generatedTokens.begin(), botTokenIt);
43+
parsedOutput.content = tokenizer.decode(contentTokens);
44+
// Tokens after botTokenId will be treated as tool calls
45+
toolCallsStartPosition = botTokenIt + 1;
5646
} else {
57-
parsedOutput.content.clear();
47+
// If botTokenId is not found, check if model output starts with "{" and if so, assume it's a tool call"
48+
if (!parsedOutput.content.empty() && parsedOutput.content[0] == '{') {
49+
// If model output starts with "{", treat it as a tool call
50+
toolCallsStartPosition = generatedTokens.begin();
51+
parsedOutput.content.clear();
52+
}
5853
}
5954

6055
if (toolCallsStartPosition != generatedTokens.end()) {

0 commit comments

Comments (0)