Skip to content

Commit c83865e

Browse files
committed
feat: add json_object support in response_format
1 parent f36aa71 commit c83865e

File tree

5 files changed: +74 additions, -43 deletions

lmdeploy/pytorch/engine/guided_process.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,8 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
2626
processors = {}
2727
for i, _format in enumerate(response_formats):
2828
if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
29-
if _format['type'] == 'json_schema':
29+
schema_type = _format['type']
30+
if schema_type == 'json_schema':
3031
schema = _format['json_schema']
3132
if isinstance(schema, Dict):
3233
for key in ['json_schema', 'schema']:
@@ -37,15 +38,17 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
3738
raise ValueError(f'Cannot parse schema {schema}. The schema must be '
3839
'either a dictionary or a string that contains the'
3940
' JSON Schema specification')
40-
elif _format['type'] == 'regex_schema':
41+
elif schema_type == 'regex_schema':
4142
schema = _format.get('regex_schema', '')
43+
elif schema_type == 'json_object':
44+
schema = ''
4245
else:
43-
raise ValueError(f"unsupported format type: {_format['type']}")
46+
raise ValueError(f'unsupported format type: {schema_type}')
4447

4548
session_id = session_ctx[i]['session_id']
4649
seq_id = session_ctx[i]['seq_id']
4750

48-
processors[i] = self.get_processor(session_id, seq_id, schema, _format['type'])
51+
processors[i] = self.get_processor(session_id, seq_id, schema, schema_type)
4952

5053
return processors
5154

@@ -63,7 +66,9 @@ def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) ->
6366
assert isinstance(schema, dict)
6467
compiled = self.compiler.compile_json_schema(schema)
6568
elif type == 'regex_schema':
66-
compiled = self.compiler.compile_regex_grammar(schema)
69+
compiled = self.compiler.compile_regex(schema)
70+
elif type == 'json_object':
71+
compiled = self.compiler.compile_builtin_json_grammar()
6772
else:
6873
assert False, f'Do not support schema type {type}'
6974

lmdeploy/serve/openai/api_server.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -590,7 +590,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
590590
tool_calls = None
591591
reasoning_content = None
592592
if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
593-
try: # TODO add json_schema guidance to turbomind
593+
try:
594594
tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
595595
text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
596596
if isinstance(tool_calls, List) and len(tool_calls):

lmdeploy/turbomind/turbomind.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -720,7 +720,10 @@ async def async_stream_infer(self,
720720
try:
721721
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size)
722722
decode_grammar_type = gen_config.response_format['type']
723-
decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
723+
if decode_grammar_type == 'json_schema':
724+
decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
725+
elif decode_grammar_type == 'regex_schema':
726+
decode_grammar = gen_config.response_format[decode_grammar_type]
724727

725728
compiler = _xgr.GrammarCompiler(tokenizer_info)
726729

@@ -730,9 +733,11 @@ async def async_stream_infer(self,
730733
elif decode_grammar_type == 'regex_schema':
731734
decode_grammar = str(decode_grammar)
732735
grammar = compiler.compile_regex(decode_grammar)
736+
elif decode_grammar_type == 'json_object':
737+
grammar = compiler.compile_builtin_json_grammar()
733738
else:
734739
assert False, f'Decode grammar type {decode_grammar_type} should be in ' \
735-
'["json_schema", "regex_schema"]'
740+
'["json_schema", "regex_schema", "json_object"]'
736741

737742
self.model_inst.set_grammar(grammar)
738743
except ValueError as e:

src/turbomind/python/xgrammar_bind.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -130,5 +130,8 @@ PYBIND11_MODULE(_xgrammar, m)
130130
.def("compile_regex",
131131
&GrammarCompiler::CompileRegex,
132132
py::call_guard<py::gil_scoped_release>(),
133-
py::arg("schema"));
133+
py::arg("schema"))
134+
.def("compile_builtin_json_grammar",
135+
&GrammarCompiler::CompileBuiltinJSONGrammar,
136+
py::call_guard<py::gil_scoped_release>());
134137
}
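(File path missing from this extract; the hunks below belong to a separate pytest test module, not to xgrammar_bind.cpp.)

Lines changed: 52 additions & 34 deletions
Original file line number | Diff line number | Diff line change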
@@ -1,4 +1,5 @@
11
import json
2+
import re
23

34
import pytest
45
from jsonschema import validate
@@ -8,72 +9,89 @@
89

910
MODEL_IDS = [
1011
'Qwen/Qwen3-0.6B',
11-
'OpenGVLab/InternVL3_5-1B',
12+
'OpenGVLab/InternVL3_5-4B',
1213
]
1314

1415
BACKEND_FACTORIES = [
1516
('tm', lambda: TurbomindEngineConfig(max_batch_size=2, session_len=1024)),
1617
('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
1718
]
1819

19-
GUIDE_SCHEMA = {
20-
'type': 'object',
21-
'properties': {
22-
'name': {
23-
'type': 'string'
24-
},
25-
'skills': {
26-
'type': 'array',
27-
'items': {
28-
'type': 'string',
29-
'maxLength': 10
20+
SCHEMA_MAP = {
21+
'json_schema': {
22+
'type': 'object',
23+
'properties': {
24+
'name': {
25+
'type': 'string'
3026
},
31-
'minItems': 3,
32-
'maxItems': 10,
33-
},
34-
'work history': {
35-
'type': 'array',
36-
'items': {
37-
'type': 'object',
38-
'properties': {
39-
'company': {
40-
'type': 'string'
41-
},
42-
'duration': {
43-
'type': 'string'
27+
'skills': {
28+
'type': 'array',
29+
'items': {
30+
'type': 'string',
31+
'maxLength': 10
32+
},
33+
'minItems': 3,
34+
'maxItems': 10,
35+
},
36+
'work history': {
37+
'type': 'array',
38+
'items': {
39+
'type': 'object',
40+
'properties': {
41+
'company': {
42+
'type': 'string'
43+
},
44+
'duration': {
45+
'type': 'string'
46+
},
4447
},
48+
'required': ['company'],
4549
},
46-
'required': ['company'],
4750
},
4851
},
52+
'required': ['name', 'skills', 'work history'],
4953
},
50-
'required': ['name', 'skills', 'work history'],
54+
'regex_schema': 'call me [A-Za-z]{1,10}',
55+
'json_object': None,
5156
}
5257

5358

5459
@pytest.mark.parametrize('model_id', MODEL_IDS)
5560
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
56-
@pytest.mark.parametrize('enable_guide', [True, False])
57-
def test_guided_matrix(model_id, backend_name, backend_factory, enable_guide):
61+
@pytest.mark.parametrize('schema_type', list(SCHEMA_MAP.keys()) + [None])
62+
def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
5863
pipe = pipeline(
5964
model_id,
6065
backend_config=backend_factory(),
6166
log_level='INFO',
6267
)
6368

69+
if schema_type is None:
70+
enable_guide = False
71+
else:
72+
enable_guide = True
73+
response_format = {'type': schema_type}
74+
schema = SCHEMA_MAP[schema_type]
75+
if schema_type == 'json_schema':
76+
response_format[schema_type] = dict(name='test', schema=schema)
77+
elif schema_type == 'regex_schema':
78+
response_format[schema_type] = schema
79+
6480
try:
6581
if enable_guide:
66-
gen_config = GenerationConfig(response_format=dict(
67-
type='json_schema',
68-
json_schema=dict(name='test', schema=GUIDE_SCHEMA),
69-
), )
82+
gen_config = GenerationConfig(response_format=response_format)
7083
else:
7184
gen_config = GenerationConfig()
7285

7386
response = pipe(['Make a self introduction please.'] * 3, gen_config=gen_config)
7487
assert response and response[0].text
7588

7689
if enable_guide:
77-
validate(instance=json.loads(response[0].text), schema=GUIDE_SCHEMA)
90+
if schema_type == 'json_schema':
91+
validate(instance=json.loads(response[0].text), schema=schema)
92+
elif schema_type == 'json_object':
93+
validate(instance=json.loads(response[0].text), schema={'type': 'object', 'additionalProperties': True})
94+
elif schema_type == 'regex_schema':
95+
assert re.fullmatch(schema, response[0].text)
7896
finally:
7997
pipe.close()

0 commit comments

Comments (0)