Skip to content

Commit bb763c2

Browse files
committed
Add string json check
1 parent a76e65e commit bb763c2

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

src/llm/io_processing/qwen3coder/qwen3coder_tool_parser.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,20 @@ static const ParametersTypeMap_t parseToolSchema(const std::string& functionName
127127
return result;
128128
}
129129

130-
// helper function to escape \n, "
130+
static std::string escapeQuotes(const std::string& input) {
131+
std::string output;
132+
output.reserve(input.size());
133+
for (char c : input) {
134+
switch (c) {
135+
case '"':
136+
output += "\\\"";
137+
break;
138+
default:
139+
output += c;
140+
}
141+
}
142+
return output;
143+
}
131144
static std::string escapeString(const std::string& input) {
132145
std::string output;
133146
output.reserve(input.size());
@@ -136,9 +149,6 @@ static std::string escapeString(const std::string& input) {
136149
case '\n':
137150
output += "\\n";
138151
break;
139-
case '"':
140-
output += "\\\"";
141-
break;
142152
default:
143153
output += c;
144154
}
@@ -153,7 +163,7 @@ static std::string setCorrectValueType(std::string& inputValue, const std::strin
153163
return inputValue;
154164
}
155165
if (paramIt->second == ParameterType::STRING) {
156-
inputValue = "\"" + inputValue + "\"";
166+
inputValue = "\"" + escapeQuotes(inputValue) + "\"";
157167
return inputValue;
158168
}
159169
if (paramIt->second == ParameterType::BOOLEAN) {
@@ -240,8 +250,8 @@ bool Qwen3CoderToolParserImpl::parseUntilStateChange(ToolCalls_t& toolCalls) {
240250
SPDLOG_DEBUG("Tool schema not found for tool: {}, leaving parameter: {} as string", this->currentFunction.name, this->currentParameterName);
241251
} else {
242252
// we don't want to escape entry/exit " for string parameters
243-
auto escaped = escapeString(parameterValue);
244-
parameterValue = setCorrectValueType(escaped, this->currentParameterName, paramIt->second);
253+
// auto escaped = escapeString(parameterValue);
254+
parameterValue = escapeString(setCorrectValueType(parameterValue, this->currentParameterName, paramIt->second));
245255
}
246256
auto res = this->currentFunction.parameters.try_emplace(this->currentParameterName, parameterValue);
247257
if (!res.second)
@@ -361,6 +371,7 @@ std::optional<rapidjson::Document> Qwen3CoderToolParser::sendFullDelta(std::opti
361371
// now we need to add string toolCall.arguments to argumentsWrapper under "arguments" key
362372
rapidjson::Value toolCallsString(rapidjson::kStringType);
363373
toolCallsString.SetString(toolCall.arguments.c_str(), allocator);
374+
SPDLOG_TRACE("Tool call arguments string: {}", toolCall.arguments);
364375
argumentsWrapper.AddMember("arguments", toolCallsString, allocator);
365376
auto currentDelta = wrapDelta(argumentsWrapper, this->toolCallIndex);
366377
SPDLOG_DEBUG("First delta doc: {}", documentToString(currentDelta));

src/test/llm/output_parsers/qwen3coder_output_parser_test.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,28 @@ if __name__ == "__main__":
691691
SPDLOG_ERROR("Expected:\n{}", expected);
692692
SPDLOG_ERROR("Got:\n{}", docStr);
693693
EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk;
694+
// now we do final check
695+
// we want to ensure that if we extract from expectedDelta["delta"]["tool_calls"][0]["function"]["arguments"] which as a string
696+
// and then try to read this string as a json
697+
if (expectedDelta.value().find("arguments") == std::string::npos) {
698+
SPDLOG_TRACE("No arguments to check for delta:\n{}", expectedDelta.value());
699+
continue; // no arguments to check
700+
}
701+
auto docJsonIt = doc->FindMember("delta");
702+
ASSERT_NE(docJsonIt, doc->MemberEnd());
703+
auto toolCallsIt = docJsonIt->value.FindMember("tool_calls");
704+
ASSERT_NE(toolCallsIt, docJsonIt->value.MemberEnd());
705+
for (auto& toolCall : toolCallsIt->value.GetArray()) {
706+
auto functionIt = toolCall.FindMember("function");
707+
ASSERT_NE(functionIt, toolCall.MemberEnd());
708+
auto argumentsIt = functionIt->value.FindMember("arguments");
709+
ASSERT_NE(argumentsIt, functionIt->value.MemberEnd());
710+
const std::string& argumentsStr = argumentsIt->value.GetString();
711+
rapidjson::Document argsDoc;
712+
argsDoc.Parse(argumentsStr.c_str()); // now check for errors
713+
EXPECT_FALSE(argsDoc.HasParseError()) << "Arguments is not valid JSON for chunk: " << chunk << "\nArguments string:\n"
714+
<< argumentsStr;
715+
}
694716
}
695717
} else {
696718
EXPECT_TRUE(false) << "Mismatch between expectedDelta and doc for id: " << i << " chunk:\n"

0 commit comments

Comments
 (0)