
Commit 2636336

feat: add json_object support in response_format (#4080)
1 parent 60aa80e commit 2636336

4 files changed: 72 additions, 41 deletions
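
For orientation, these are the three response_format payload shapes the changed code paths below handle. The dict keys come straight from the diffs; the concrete schema and pattern values are illustrative placeholders.

```python
# Illustrative response_format payloads; keys mirror the diffs below,
# values are placeholders.
json_schema_format = {
    'type': 'json_schema',
    'json_schema': {
        'name': 'test',
        'schema': {'type': 'object', 'properties': {'name': {'type': 'string'}}},
    },
}
regex_format = {'type': 'regex_schema', 'regex_schema': 'call me [A-Za-z]{1,10}'}
# New in this commit: no schema attached, any well-formed JSON object is accepted.
json_object_format = {'type': 'json_object'}
```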


lmdeploy/pytorch/engine/guided_process.py

Lines changed: 10 additions & 5 deletions
@@ -26,7 +26,8 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
         processors = {}
         for i, _format in enumerate(response_formats):
             if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
-                if _format['type'] == 'json_schema':
+                schema_type = _format['type']
+                if schema_type == 'json_schema':
                     schema = _format['json_schema']
                     if isinstance(schema, Dict):
                         for key in ['json_schema', 'schema']:
@@ -37,15 +38,17 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
                         raise ValueError(f'Cannot parse schema {schema}. The schema must be '
                                          'either a dictionary or a string that contains the'
                                          ' JSON Schema specification')
-                elif _format['type'] == 'regex_schema':
+                elif schema_type == 'regex_schema':
                     schema = _format.get('regex_schema', '')
+                elif schema_type == 'json_object':
+                    schema = '{"type" : "object", "additionalProperties": true}'
                 else:
-                    raise ValueError(f"unsupported format type: {_format['type']}")
+                    raise ValueError(f'unsupported format type: {schema_type}')

                 session_id = session_ctx[i]['session_id']
                 seq_id = session_ctx[i]['seq_id']

-                processors[i] = self.get_processor(session_id, seq_id, schema, _format['type'])
+                processors[i] = self.get_processor(session_id, seq_id, schema, schema_type)

         return processors

@@ -63,7 +66,9 @@ def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) ->
             assert isinstance(schema, dict)
             compiled = self.compiler.compile_json_schema(schema)
         elif type == 'regex_schema':
-            compiled = self.compiler.compile_regex_grammar(schema)
+            compiled = self.compiler.compile_regex(schema)
+        elif type == 'json_object':
+            compiled = self.compiler.compile_json_schema(schema)
         else:
             assert False, f'Do not support schema type {type}'
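
The change above can be read as a small dispatch on the format type. A minimal standalone sketch follows; the helper name and the driver line are assumptions, while the branch logic mirrors the diff.

```python
from typing import Any, Dict


def resolve_schema(_format: Dict[str, Any]) -> Any:
    """Map an OpenAI-style response_format dict to the schema handed to the
    grammar compiler (sketch of the logic in get_processors above)."""
    schema_type = _format['type']
    if schema_type == 'json_schema':
        schema = _format['json_schema']
        # The actual schema may be nested under 'json_schema' or 'schema'.
        if isinstance(schema, dict):
            for key in ('json_schema', 'schema'):
                if key in schema:
                    schema = schema[key]
        return schema
    elif schema_type == 'regex_schema':
        return _format.get('regex_schema', '')
    elif schema_type == 'json_object':
        # json_object carries no user schema: fall back to a permissive object schema.
        return '{"type" : "object", "additionalProperties": true}'
    raise ValueError(f'unsupported format type: {schema_type}')


print(resolve_schema({'type': 'json_object'}))
```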

lmdeploy/serve/openai/api_server.py

Lines changed: 1 addition & 1 deletion
@@ -590,7 +590,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         tool_calls = None
         reasoning_content = None
         if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
-            try:  # TODO add json_schema guidance to turbomind
+            try:
                 tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
                 text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
                 if isinstance(tool_calls, List) and len(tool_calls):
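
With the server side in place, a client can request free-form JSON through the OpenAI-compatible endpoint. A hedged usage sketch; the base_url, port, and model name are placeholders for a locally running `lmdeploy serve api_server` deployment.

```python
from openai import OpenAI

# Placeholders: point base_url at your running api_server; the key is unused locally.
client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='dummy')

resp = client.chat.completions.create(
    model='internlm/internlm2_5-7b-chat',  # assumed model name; use whatever you serve
    messages=[{'role': 'user', 'content': 'Introduce yourself as a JSON object.'}],
    response_format={'type': 'json_object'},
)
print(resp.choices[0].message.content)
```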

lmdeploy/turbomind/turbomind.py

Lines changed: 10 additions & 2 deletions
@@ -720,7 +720,12 @@ async def async_stream_infer(self,
             try:
                 tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size)
                 decode_grammar_type = gen_config.response_format['type']
-                decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
+                if decode_grammar_type == 'json_schema':
+                    decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
+                elif decode_grammar_type == 'regex_schema':
+                    decode_grammar = gen_config.response_format[decode_grammar_type]
+                elif decode_grammar_type == 'json_object':
+                    decode_grammar = '{"type" : "object", "additionalProperties": true}'

                 compiler = _xgr.GrammarCompiler(tokenizer_info)

@@ -730,9 +735,12 @@ async def async_stream_infer(self,
                 elif decode_grammar_type == 'regex_schema':
                     decode_grammar = str(decode_grammar)
                     grammar = compiler.compile_regex(decode_grammar)
+                elif decode_grammar_type == 'json_object':
+                    decode_grammar = str(decode_grammar)
+                    grammar = compiler.compile_json_schema(decode_grammar)
                 else:
                     assert False, f'Decode grammar type {decode_grammar_type} should be in ' \
-                        '["json_schema", "regex_schema"]'
+                        '["json_schema", "regex_schema", "json_object"]'

                 self.model_inst.set_grammar(grammar)
             except ValueError as e:
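
Condensed, the TurboMind path now picks a grammar like this. Only the xgrammar compiler calls shown in the diff are taken from the commit; the standalone helper, its name, and the use of json.dumps for dict schemas are illustrative assumptions.

```python
import json

import xgrammar as xgr


def compile_decode_grammar(compiler: xgr.GrammarCompiler, response_format: dict):
    """Sketch of the grammar selection added for the TurboMind backend."""
    grammar_type = response_format['type']
    if grammar_type == 'json_schema':
        schema = response_format[grammar_type]['schema']
        # Serialize dict schemas to a JSON string before compiling (assumption).
        grammar = compiler.compile_json_schema(json.dumps(schema))
    elif grammar_type == 'regex_schema':
        grammar = compiler.compile_regex(str(response_format[grammar_type]))
    elif grammar_type == 'json_object':
        # json_object carries no user schema, so a permissive object schema is compiled instead.
        grammar = compiler.compile_json_schema('{"type" : "object", "additionalProperties": true}')
    else:
        raise ValueError(f'Decode grammar type {grammar_type} should be in '
                         '["json_schema", "regex_schema", "json_object"]')
    return grammar
```
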
Lines changed: 51 additions & 33 deletions
@@ -1,4 +1,5 @@
 import json
+import re

 import pytest
 from jsonschema import validate
@@ -16,64 +17,81 @@
     ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
 ]

-GUIDE_SCHEMA = {
-    'type': 'object',
-    'properties': {
-        'name': {
-            'type': 'string'
-        },
-        'skills': {
-            'type': 'array',
-            'items': {
-                'type': 'string',
-                'maxLength': 10
+SCHEMA_MAP = {
+    'json_schema': {
+        'type': 'object',
+        'properties': {
+            'name': {
+                'type': 'string'
             },
-            'minItems': 3,
-            'maxItems': 10,
-        },
-        'work history': {
-            'type': 'array',
-            'items': {
-                'type': 'object',
-                'properties': {
-                    'company': {
-                        'type': 'string'
-                    },
-                    'duration': {
-                        'type': 'string'
+            'skills': {
+                'type': 'array',
+                'items': {
+                    'type': 'string',
+                    'maxLength': 10
+                },
+                'minItems': 3,
+                'maxItems': 10,
+            },
+            'work history': {
+                'type': 'array',
+                'items': {
+                    'type': 'object',
+                    'properties': {
+                        'company': {
+                            'type': 'string'
+                        },
+                        'duration': {
+                            'type': 'string'
+                        },
                     },
+                    'required': ['company'],
                 },
-                'required': ['company'],
             },
         },
+        'required': ['name', 'skills', 'work history'],
     },
-    'required': ['name', 'skills', 'work history'],
+    'regex_schema': 'call me [A-Za-z]{1,10}',
+    'json_object': None,
 }


 @pytest.mark.parametrize('model_id', MODEL_IDS)
 @pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
-@pytest.mark.parametrize('enable_guide', [True, False])
-def test_guided_matrix(model_id, backend_name, backend_factory, enable_guide):
+@pytest.mark.parametrize('schema_type', list(SCHEMA_MAP.keys()) + [None])
+def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
     pipe = pipeline(
         model_id,
         backend_config=backend_factory(),
         log_level='INFO',
     )

+    if schema_type is None:
+        enable_guide = False
+    else:
+        enable_guide = True
+        response_format = {'type': schema_type}
+        schema = SCHEMA_MAP[schema_type]
+        if schema_type == 'json_schema':
+            response_format[schema_type] = dict(name='test', schema=schema)
+        elif schema_type == 'regex_schema':
+            response_format[schema_type] = schema
+
     try:
         if enable_guide:
-            gen_config = GenerationConfig(response_format=dict(
-                type='json_schema',
-                json_schema=dict(name='test', schema=GUIDE_SCHEMA),
-            ), )
+            gen_config = GenerationConfig(response_format=response_format)
         else:
             gen_config = GenerationConfig()

         response = pipe(['Make a self introduction please.'] * 3, gen_config=gen_config)
         assert response and response[0].text

         if enable_guide:
-            validate(instance=json.loads(response[0].text), schema=GUIDE_SCHEMA)
+            if schema_type == 'json_schema':
+                validate(instance=json.loads(response[0].text), schema=schema)
+            elif schema_type == 'json_object':
+                validate(instance=json.loads(response[0].text), schema={'type': 'object', 'additionalProperties': True})
+            elif schema_type == 'regex_schema':
+                assert re.fullmatch(schema, response[0].text)
     finally:
         pipe.close()
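
Outside the test harness, the same json_object mode can be exercised with the offline pipeline. A hedged sketch mirroring the new test case; the model id is a placeholder.

```python
import json

from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')  # placeholder model id
gen_config = GenerationConfig(response_format={'type': 'json_object'})

response = pipe(['Make a self introduction please.'], gen_config=gen_config)
# Guided decoding constrains the output to a JSON object, so this parse should succeed.
print(json.loads(response[0].text))

pipe.close()
```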
