15 changes: 10 additions & 5 deletions lmdeploy/pytorch/engine/guided_process.py
@@ -26,7 +26,8 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
processors = {}
for i, _format in enumerate(response_formats):
if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
if _format['type'] == 'json_schema':
schema_type = _format['type']
if schema_type == 'json_schema':
schema = _format['json_schema']
if isinstance(schema, Dict):
for key in ['json_schema', 'schema']:
@@ -37,15 +38,17 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
raise ValueError(f'Cannot parse schema {schema}. The schema must be '
'either a dictionary or a string that contains the'
' JSON Schema specification')
elif _format['type'] == 'regex_schema':
elif schema_type == 'regex_schema':
schema = _format.get('regex_schema', '')
elif schema_type == 'json_object':
schema = '{"type" : "object", "additionalProperties": true}'
else:
raise ValueError(f"unsupported format type: {_format['type']}")
raise ValueError(f'unsupported format type: {schema_type}')

session_id = session_ctx[i]['session_id']
seq_id = session_ctx[i]['seq_id']

processors[i] = self.get_processor(session_id, seq_id, schema, _format['type'])
processors[i] = self.get_processor(session_id, seq_id, schema, schema_type)

return processors

@@ -63,7 +66,9 @@ def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) ->
assert isinstance(schema, dict)
compiled = self.compiler.compile_json_schema(schema)
elif type == 'regex_schema':
compiled = self.compiler.compile_regex_grammar(schema)
compiled = self.compiler.compile_regex(schema)
elif type == 'json_object':
compiled = self.compiler.compile_json_schema(schema)
else:
assert False, f'Do not support schema type {type}'

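For reference, a minimal sketch of what the new json_object branch compiles to, assuming the compiler in guided_process.py is xgrammar's GrammarCompiler (as the compile_json_schema/compile_regex calls suggest); the tokenizer/model name below is an illustrative assumption, not part of this diff.

import xgrammar as xgr
from transformers import AutoTokenizer

# Hypothetical tokenizer choice; any HF tokenizer works for building TokenizerInfo.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
compiler = xgr.GrammarCompiler(tokenizer_info)

# 'json_object' carries no user-supplied schema, so it reuses the JSON-schema path
# with the most permissive object schema.
permissive = '{"type" : "object", "additionalProperties": true}'
compiled_grammar = compiler.compile_json_schema(permissive)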
2 changes: 1 addition & 1 deletion lmdeploy/serve/openai/api_server.py
@@ -590,7 +590,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
tool_calls = None
reasoning_content = None
if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
try: # TODO add json_schema guidance to turbomind
try:
tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
if isinstance(tool_calls, List) and len(tool_calls):
12 changes: 10 additions & 2 deletions lmdeploy/turbomind/turbomind.py
@@ -720,7 +720,12 @@ async def async_stream_infer(self,
try:
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size)
decode_grammar_type = gen_config.response_format['type']
decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
if decode_grammar_type == 'json_schema':
decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
elif decode_grammar_type == 'regex_schema':
decode_grammar = gen_config.response_format[decode_grammar_type]
elif decode_grammar_type == 'json_object':
decode_grammar = '{"type" : "object", "additionalProperties": true}'

compiler = _xgr.GrammarCompiler(tokenizer_info)

@@ -730,9 +735,12 @@
elif decode_grammar_type == 'regex_schema':
decode_grammar = str(decode_grammar)
grammar = compiler.compile_regex(decode_grammar)
elif decode_grammar_type == 'json_object':
decode_grammar = str(decode_grammar)
grammar = compiler.compile_json_schema(decode_grammar)
else:
assert False, f'Decode grammar type {decode_grammar_type} should be in ' \
'["json_schema", "regex_schema"]'
'["json_schema", "regex_schema", "json_object"]'

self.model_inst.set_grammar(grammar)
except ValueError as e:
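A usage sketch of the new mode from the client side (not part of the diff; the model name is an assumption): json_object needs no schema payload, so the request carries only the type field and the engine falls back to the permissive schema shown above.

from lmdeploy import pipeline, GenerationConfig

pipe = pipeline('internlm/internlm2_5-7b-chat')  # hypothetical model
# No schema accompanies 'json_object'; the engine supplies the permissive one.
gen_config = GenerationConfig(response_format={'type': 'json_object'})
resp = pipe(['Introduce yourself as a JSON object.'], gen_config=gen_config)
print(resp[0].text)  # expected to parse with json.loads()
pipe.close()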
84 changes: 51 additions & 33 deletions tests/test_lmdeploy/test_grammar.py
@@ -1,4 +1,5 @@
import json
import re

import pytest
from jsonschema import validate
@@ -16,64 +17,81 @@
('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
]

GUIDE_SCHEMA = {
    'type': 'object',
    'properties': {
        'name': {
            'type': 'string'
        },
        'skills': {
            'type': 'array',
            'items': {
                'type': 'string',
                'maxLength': 10
            },
            'minItems': 3,
            'maxItems': 10,
        },
        'work history': {
            'type': 'array',
            'items': {
                'type': 'object',
                'properties': {
                    'company': {
                        'type': 'string'
                    },
                    'duration': {
                        'type': 'string'
                    },
                },
                'required': ['company'],
            },
        },
    },
    'required': ['name', 'skills', 'work history'],
}
SCHEMA_MAP = {
    'json_schema': {
        'type': 'object',
        'properties': {
            'name': {
                'type': 'string'
            },
            'skills': {
                'type': 'array',
                'items': {
                    'type': 'string',
                    'maxLength': 10
                },
                'minItems': 3,
                'maxItems': 10,
            },
            'work history': {
                'type': 'array',
                'items': {
                    'type': 'object',
                    'properties': {
                        'company': {
                            'type': 'string'
                        },
                        'duration': {
                            'type': 'string'
                        },
                    },
                    'required': ['company'],
                },
            },
        },
        'required': ['name', 'skills', 'work history'],
    },
    'regex_schema': 'call me [A-Za-z]{1,10}',
    'json_object': None,
}
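
Spelled out, the three response_format payloads the rewritten test below builds from SCHEMA_MAP look like this (illustrative, derived from the test logic, not literal diff lines):

json_schema_format = {'type': 'json_schema', 'json_schema': {'name': 'test', 'schema': SCHEMA_MAP['json_schema']}}
regex_format = {'type': 'regex_schema', 'regex_schema': SCHEMA_MAP['regex_schema']}
json_object_format = {'type': 'json_object'}  # no schema payload needed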


@pytest.mark.parametrize('model_id', MODEL_IDS)
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
@pytest.mark.parametrize('enable_guide', [True, False])
def test_guided_matrix(model_id, backend_name, backend_factory, enable_guide):
@pytest.mark.parametrize('schema_type', list(SCHEMA_MAP.keys()) + [None])
def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
pipe = pipeline(
model_id,
backend_config=backend_factory(),
log_level='INFO',
)

if schema_type is None:
enable_guide = False
else:
enable_guide = True
response_format = {'type': schema_type}
schema = SCHEMA_MAP[schema_type]
if schema_type == 'json_schema':
response_format[schema_type] = dict(name='test', schema=schema)
elif schema_type == 'regex_schema':
response_format[schema_type] = schema

try:
if enable_guide:
gen_config = GenerationConfig(response_format=dict(
type='json_schema',
json_schema=dict(name='test', schema=GUIDE_SCHEMA),
), )
gen_config = GenerationConfig(response_format=response_format)
else:
gen_config = GenerationConfig()

response = pipe(['Make a self introduction please.'] * 3, gen_config=gen_config)
assert response and response[0].text

if enable_guide:
validate(instance=json.loads(response[0].text), schema=GUIDE_SCHEMA)
if schema_type == 'json_schema':
validate(instance=json.loads(response[0].text), schema=schema)
elif schema_type == 'json_object':
validate(instance=json.loads(response[0].text), schema={'type': 'object', 'additionalProperties': True})
elif schema_type == 'regex_schema':
assert re.fullmatch(schema, response[0].text)
finally:
pipe.close()
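
One note on the json_object assertion above: validating against the permissive schema only checks that the output parses as a JSON object, nothing more. A small self-contained illustration (not from the diff):

import json

from jsonschema import validate

permissive = {'type': 'object', 'additionalProperties': True}
validate(instance=json.loads('{"name": "test", "skills": ["a", "b", "c"]}'), schema=permissive)  # passes
# validate(instance=json.loads('[1, 2, 3]'), schema=permissive)  # would raise ValidationError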