|
1 | 1 | import json |
| 2 | +import re |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 | from jsonschema import validate |
|
8 | 9 |
|
# Model identifiers that test_guided_matrix is parametrized over; every
# entry is combined with every backend and every schema type.
MODEL_IDS = [
    'Qwen/Qwen3-0.6B',
    'OpenGVLab/InternVL3_5-4B',
]
13 | 14 |
|
# (backend_name, factory) pairs used to parametrize the test. Each factory is
# a zero-argument lambda, so the engine config is only constructed when the
# test calls it and every call returns a new config object.
BACKEND_FACTORIES = [
    ('tm', lambda: TurbomindEngineConfig(max_batch_size=2, session_len=1024)),
    ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
]
18 | 19 |
|
# Maps each ``response_format`` type exercised by test_guided_matrix to the
# constraint that is sent with it:
#   * 'json_schema'  -> a JSON Schema dict, also used to validate the output
#   * 'regex_schema' -> a regular expression the whole output must match
#   * 'json_object'  -> None; the test attaches no extra payload for this type
SCHEMA_MAP = {
    'json_schema': {
        'type': 'object',
        'properties': {
            'name': {
                'type': 'string'
            },
            'skills': {
                'type': 'array',
                'items': {
                    'type': 'string',
                    'maxLength': 10
                },
                'minItems': 3,
                'maxItems': 10,
            },
            'work history': {
                'type': 'array',
                'items': {
                    'type': 'object',
                    'properties': {
                        'company': {
                            'type': 'string'
                        },
                        'duration': {
                            'type': 'string'
                        },
                    },
                    'required': ['company'],
                },
            },
        },
        'required': ['name', 'skills', 'work history'],
    },
    'regex_schema': 'call me [A-Za-z]{1,10}',
    'json_object': None,
}
52 | 57 |
|
53 | 58 |
|
@pytest.mark.parametrize('model_id', MODEL_IDS)
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
@pytest.mark.parametrize('schema_type', list(SCHEMA_MAP) + [None])
def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
    """Run one generation per (model, backend, schema type) and check the output.

    Args:
        model_id: Model identifier passed to ``pipeline``.
        backend_name: Label of the backend; only used in the pytest case id.
        backend_factory: Zero-argument callable returning the backend config.
        schema_type: A key of ``SCHEMA_MAP`` selecting the ``response_format``
            type, or ``None`` to generate without any response format.

    The same prompt is submitted three times in one call; only the first
    response is validated against the selected constraint.
    """
    pipe = pipeline(
        model_id,
        backend_config=backend_factory(),
        log_level='INFO',
    )

    # Everything after pipeline creation lives inside the try so the pipeline
    # is closed no matter where the test fails.
    try:
        if schema_type is None:
            schema = None
            gen_config = GenerationConfig()
        else:
            schema = SCHEMA_MAP[schema_type]
            response_format = {'type': schema_type}
            if schema_type == 'json_schema':
                response_format[schema_type] = dict(name='test', schema=schema)
            elif schema_type == 'regex_schema':
                response_format[schema_type] = schema
            # 'json_object' is sent with the bare {'type': ...} entry only.
            gen_config = GenerationConfig(response_format=response_format)

        response = pipe(['Make a self introduction please.'] * 3, gen_config=gen_config)
        assert response and response[0].text

        text = response[0].text
        if schema_type == 'json_schema':
            validate(instance=json.loads(text), schema=schema)
        elif schema_type == 'json_object':
            # Any JSON object is acceptable; only the top-level type is checked.
            validate(
                instance=json.loads(text),
                schema={'type': 'object', 'additionalProperties': True},
            )
        elif schema_type == 'regex_schema':
            assert re.fullmatch(schema, text)
    finally:
        pipe.close()
0 commit comments