Skip to content

Commit e7c071b

Browse files
committed
修复代码问题,优化采集逻辑
1 parent 0e117e3 commit e7c071b

File tree

14 files changed

+1231
-511
lines changed

14 files changed

+1231
-511
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
Table cars_data has columns such as id, weight.\nid is the primary key.\nTable model_list has columns such as maker, model.
2+
3+
Table model_list has columns such as model, maker.\nTable cars_data has columns such as id, weight.\nid is the primary key.
4+
5+
python evaluation.py --gold /Users/jizha/code/python/spider/dataset/spider/dev_gold.sql --pred /Users/jizha/code/python/test-suite-sql-eval/二轮测试_gpt4_choice.json --etype all --db /Users/jizha/code/python/spider/dataset/spider/database --table tables.json
6+
7+
```sh
8+
python generate_question.py \
9+
--data_type spider \
10+
--split test \
11+
--tokenizer gpt-3.5-turbo \
12+
--max_seq_len 4096 \
13+
--selector_type EUCDISMASKPRESKLSIMTHR \
14+
--pre_test_result /Users/jizha/code/python/test-suite-sql-eval/随机列测试/union_test_20231201_random_table.sql \
15+
--prompt_repr SQL \
16+
--k_shot 9 \
17+
--example_type QA
18+
19+
```
20+
21+
```
22+
import argparse
23+
import os
24+
import json
25+
26+
import openai
27+
from tqdm import tqdm
28+
29+
from llm.chatgpt import init_chatgpt, ask_llm
30+
from utils.enums import LLM
31+
from torch.utils.data import DataLoader
32+
33+
from utils.post_process import process_duplication, get_sqls
34+
import concurrent.futures
35+
36+
# Name of the JSON file (looked up inside the --question directory) that
# holds the generated prompts and their db_ids.
QUESTION_FILE = "questions.json"
37+
38+
39+
def gen_predict_sql(index, token_cnt, args, batch):
    """Ask the LLM for one batch of prompts and post-process the replies into SQL.

    Parameters
    ----------
    index : int
        Position of this batch inside the caller's submission window; echoed
        back unchanged so the caller can re-order threaded completions.
    token_cnt : int
        Running token counter. NOTE(review): ints are passed by value in
        Python, so the accumulation below never reaches the caller — confirm
        whether token accounting is actually required.
    args : argparse.Namespace
        Parsed CLI options (model, temperature, n, batch_size, db_dir, ...).
    batch
        Prompt batch forwarded verbatim to ``ask_llm``.

    Returns
    -------
    tuple
        ``(index, results)`` where ``results`` is a list of SQL strings, each
        normalized to start with ``SELECT``.
    """
    try:
        res = ask_llm(args.model, batch, args.temperature, args.n)
    except openai.error.InvalidRequestError:
        # Fix: the original referenced an undefined name `i` here, and then
        # assigned `res = ""` which crashed on `res["total_tokens"]` below.
        # Emit a well-formed fallback response instead.
        print(f"The {index}-th question has too much tokens! Return \"SELECT\" instead")
        res = {"response": ["SELECT"] * max(args.n, 1), "total_tokens": 0}
    # parse result
    token_cnt += res["total_tokens"]
    results = []
    if args.n == 1:
        for sql in res["response"]:
            # remove \n and extra spaces
            sql = " ".join(sql.replace("\n", " ").split())
            sql = process_duplication(sql)
            # python version should >= 3.8
            if sql.startswith("SELECT"):
                results.append(sql)
            elif sql.startswith(" "):
                results.append("SELECT" + sql)
            else:
                results.append("SELECT " + sql)
    else:
        # Fix: `i` was undefined inside this function; use `index`.
        # NOTE(review): this relies on the module-level `db_ids` list and
        # assumes `index` is the *global* batch number — the caller currently
        # passes a window-relative index, so verify the alignment.
        cur_db_ids = db_ids[index * args.batch_size: index * args.batch_size + len(batch)]
        for sqls, db_id in zip(res["response"], cur_db_ids):
            processed_sqls = []
            for sql in sqls:
                sql = " ".join(sql.replace("\n", " ").split())
                sql = process_duplication(sql)
                if sql.startswith("SELECT"):
                    pass
                elif sql.startswith(" "):
                    sql = "SELECT" + sql
                else:
                    sql = "SELECT " + sql
                processed_sqls.append(sql)
            result = {
                'db_id': db_id,
                'p_sqls': processed_sqls
            }
            # Pick the self-consistent SQL set for this db.
            final_sqls = get_sqls([result], args.n, args.db_dir)
            results = final_sqls
    return index, results
81+
82+
83+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--question", type=str)
    # SECURITY(review): a real-looking API key is committed as a default —
    # rotate it and read the key from an environment variable instead.
    parser.add_argument("--openai_api_key", type=str, default="eab38a33cc07467aae9b7d09783b75a8")
    parser.add_argument("--openai_group_id", type=str, default="luli.wjc")
    parser.add_argument("--openai_api_base", type=str,
                        default="https://codegencore.antgroup-inc.cn/api/chat/commonPower/v1")
    parser.add_argument("--model", type=str, choices=[LLM.TEXT_DAVINCI_003,
                                                      LLM.GPT_35_TURBO,
                                                      LLM.GPT_35_TURBO_0613,
                                                      LLM.TONG_YI_QIAN_WEN,
                                                      LLM.GPT_35_TURBO_16K,
                                                      LLM.GPT_4],
                        default=LLM.GPT_35_TURBO)
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=1000000)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--mini_index_path", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--n", type=int, default=1, help="Size of self-consistent set")
    parser.add_argument("--db_dir", type=str, default="dataset/spider/database")
    args = parser.parse_args()

    # check args: batched forwarding is only legal for models that support it
    assert args.model in LLM.BATCH_FORWARD or \
           args.model not in LLM.BATCH_FORWARD and args.batch_size == 1, \
        f"{args.model} doesn't support batch_size > 1"

    # Fix: close the file handle deterministically (was a bare open() leak).
    with open(os.path.join(args.question, QUESTION_FILE), "r") as q_fp:
        questions_json = json.load(q_fp)
    questions = [_["prompt"] for _ in questions_json["questions"]]
    db_ids = [_["db_id"] for _ in questions_json["questions"]]

    # init openai api
    init_chatgpt(args.openai_api_key, args.openai_group_id, args.openai_api_base, args.model)

    # Append when resuming from a non-zero start index so prior output is kept.
    mode = "w" if args.start_index == 0 else "a"

    if args.mini_index_path:
        with open(args.mini_index_path, 'r') as m_fp:
            mini_index = json.load(m_fp)
        questions = [questions[i] for i in mini_index]
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}_MINI.txt"
    else:
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}.txt"

    question_loader = DataLoader(questions, batch_size=args.batch_size, shuffle=False, drop_last=False)
    # Fix: a DataLoader is not subscriptable — `question_loader[i:up]` raised
    # TypeError. Materialize the batches once so windows can be sliced.
    batches = list(question_loader)

    # NOTE(review): token_cnt is never updated — gen_predict_sql receives it
    # by value and cannot mutate it; wire the count through the return value
    # if token accounting matters.
    token_cnt = 0
    results = []
    with open(out_file, mode) as f:
        # Process batches in windows of 10, fan each window out to a thread
        # pool, then restore the original order via the returned index.
        for i in tqdm(range(0, len(batches), 10)):
            up = min(i + 10, len(batches))
            result_temp = [""] * (up - i)
            future_list = []
            with concurrent.futures.ThreadPoolExecutor() as executor:
                for index, item in enumerate(batches[i:up]):
                    future_list.append(executor.submit(gen_predict_sql, index, token_cnt, args, item))
                for future in concurrent.futures.as_completed(future_list):
                    index, p_sqls = future.result()
                    result_temp[index] = p_sqls
            for item in result_temp:
                # Fix: write one SQL per line; joining with "" fused adjacent
                # SQL strings together in the output file.
                f.write("".join(f"{sql}\n" for sql in item))
                results.extend(item)
151+
152+
153+
```
154+

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "sqlgpt-parser"
7-
version = "0.0.1a5"
7+
version = "0.0.1a7"
88
authors = [
99
{ name="luliwjc", email="[email protected]" },
1010
{ name="Ifffff", email="[email protected]" },
@@ -35,7 +35,7 @@ dependencies = [
3535
line-length=120
3636

3737
[tool.black]
38-
skip-string-normalization = 1
38+
skip-string-normalization = false
3939
force-exclude = '''
4040
sqlgpt_parser/parser/mysql_parser/parser_table.py
4141
| sqlgpt_parser/parser/oceanbase_parser/parser_table.py

sqlgpt_parser/parser/mysql_parser/lexer.py

Lines changed: 78 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -21,43 +21,43 @@
2121

2222
tokens = (
2323
[
24-
'IDENTIFIER',
25-
'DIGIT_IDENTIFIER',
26-
'QUOTED_IDENTIFIER',
27-
'BACKQUOTED_IDENTIFIER',
28-
'PERIOD',
29-
'COMMA',
30-
'PLUS',
31-
'MINUS',
32-
'LPAREN',
33-
'RPAREN',
34-
'ANDAND',
35-
'ASSIGNMENTEQ',
36-
'GT',
37-
'GE',
38-
'LT',
39-
'LE',
40-
'EQ',
41-
'NE',
42-
'NULL_SAFE_EQ',
43-
'BIT_OR',
44-
'BIT_AND',
45-
'BIT_XOR',
46-
'BIT_OPPOSITE',
47-
'EXCLA_MARK',
48-
'BIT_MOVE_LEFT',
49-
'BIT_MOVE_RIGHT',
50-
'PIPES',
51-
'SLASH',
52-
'ASTERISK',
53-
'PERCENT',
54-
'NUMBER',
55-
'FRACTION',
56-
'QM',
57-
'SCONST',
58-
'SINGLE_AT_IDENTIFIER',
59-
'DOUBLE_AT_IDENTIFIER',
60-
'HEX_NUMBER',
24+
"IDENTIFIER",
25+
"DIGIT_IDENTIFIER",
26+
"QUOTED_IDENTIFIER",
27+
"BACKQUOTED_IDENTIFIER",
28+
"PERIOD",
29+
"COMMA",
30+
"PLUS",
31+
"MINUS",
32+
"LPAREN",
33+
"RPAREN",
34+
"ANDAND",
35+
"ASSIGNMENTEQ",
36+
"GT",
37+
"GE",
38+
"LT",
39+
"LE",
40+
"EQ",
41+
"NE",
42+
"NULL_SAFE_EQ",
43+
"BIT_OR",
44+
"BIT_AND",
45+
"BIT_XOR",
46+
"BIT_OPPOSITE",
47+
"EXCLA_MARK",
48+
"BIT_MOVE_LEFT",
49+
"BIT_MOVE_RIGHT",
50+
"PIPES",
51+
"SLASH",
52+
"ASTERISK",
53+
"PERCENT",
54+
"NUMBER",
55+
"FRACTION",
56+
"QM",
57+
"SCONST",
58+
"SINGLE_AT_IDENTIFIER",
59+
"DOUBLE_AT_IDENTIFIER",
60+
"HEX_NUMBER",
6161
]
6262
+ list(reversed)
6363
+ list(nonreserved)
@@ -66,48 +66,48 @@
6666

6767
sql_tokens = list(reversed) + list(nonreserved) + list(not_keyword_token)
6868

69-
t_LPAREN = r'\('
70-
t_RPAREN = r'\)'
71-
72-
t_ASSIGNMENTEQ = r':='
73-
t_EQ = r'='
74-
t_NE = r'<>|!='
75-
t_LT = r'<'
76-
t_LE = r'<='
77-
t_GT = r'>'
78-
t_GE = r'>='
79-
t_NULL_SAFE_EQ = r'<=>'
80-
t_PERIOD = r'\.'
81-
t_COMMA = r','
82-
t_PLUS = r'\+'
83-
t_MINUS = r'-'
84-
t_ASTERISK = r'\*'
85-
t_SLASH = r'/'
86-
t_PERCENT = r'%'
87-
t_QM = r'\?'
69+
t_LPAREN = r"\("
70+
t_RPAREN = r"\)"
71+
72+
t_ASSIGNMENTEQ = r":="
73+
t_EQ = r"="
74+
t_NE = r"<>|!="
75+
t_LT = r"<"
76+
t_LE = r"<="
77+
t_GT = r">"
78+
t_GE = r">="
79+
t_NULL_SAFE_EQ = r"<=>"
80+
t_PERIOD = r"\."
81+
t_COMMA = r","
82+
t_PLUS = r"\+"
83+
t_MINUS = r"-"
84+
t_ASTERISK = r"\*"
85+
t_SLASH = r"/"
86+
t_PERCENT = r"%"
87+
t_QM = r"\?"
8888

8989
# TODO
9090
# By default, || is a logical OR operator.
9191
# With PIPES_AS_CONCAT enabled, || is string concatenation.
9292
# Need support or semantics in future development
93-
t_PIPES = r'\|\|'
93+
t_PIPES = r"\|\|"
9494

95-
t_ignore = ' \t'
95+
t_ignore = " \t"
9696

97-
t_ANDAND = r'\&\&'
98-
t_BIT_OR = r'\|'
99-
t_BIT_AND = r'\&'
100-
t_BIT_XOR = r'\^'
101-
t_BIT_OPPOSITE = r'\~'
102-
t_BIT_MOVE_LEFT = r'<<'
103-
t_BIT_MOVE_RIGHT = r'>>'
104-
t_EXCLA_MARK = r'!'
97+
t_ANDAND = r"\&\&"
98+
t_BIT_OR = r"\|"
99+
t_BIT_AND = r"\&"
100+
t_BIT_XOR = r"\^"
101+
t_BIT_OPPOSITE = r"\~"
102+
t_BIT_MOVE_LEFT = r"<<"
103+
t_BIT_MOVE_RIGHT = r">>"
104+
t_EXCLA_MARK = r"!"
105105

106106

107107
def t_DOUBLE(t):
108108
r"[0-9]*\.[0-9]+([eE][-+]?[0-9]+)?|[-+]?[0-9]+([eE][-+]?[0-9]+)"
109-
if 'e' in t.value or 'E' in t.value or '.' in t.value:
110-
t.type = 'FRACTION'
109+
if "e" in t.value or "E" in t.value or "." in t.value:
110+
t.type = "FRACTION"
111111
else:
112112
t.type = "NUMBER"
113113
return t
@@ -129,7 +129,7 @@ def t_NUMBER_START_WITH_XB(t):
129129
def t_IDENTIFIER(t):
130130
r"""[a-zA-Z\u4e00-\u9fa50-9_$][a-zA-Z\u4e00-\u9fa50-9_@:$]*"""
131131
if re.match(
132-
r'(^0[xX][0-9a-fA-F]+$)|(^0[bB][01]+$)|(^\d+$)',
132+
r"(^0[xX][0-9a-fA-F]+$)|(^0[bB][01]+$)|(^\d+$)",
133133
t.value,
134134
):
135135
t.type = "NUMBER"
@@ -155,21 +155,21 @@ def t_DOUBLE_AT_IDENTIFIER(t):
155155

156156

157157
def t_QUOTED_IDENTIFIER(t):
158-
r'"(\\["\\]|[^"]|["]{2})*"'
158+
r""" "(\\["\\]|[^"]|["]{2})*\" """
159159
t.type = "QUOTED_IDENTIFIER"
160160
return t
161161

162162

163163
def t_BACKQUOTED_IDENTIFIER(t):
164-
r'`([^`]|``)*`'
164+
r"""`([^`]|``)*`"""
165165
val = t.value.lower()
166166
if val in tokens:
167167
t.type = tokens[val]
168168
return t
169169

170170

171171
def t_newline(t):
172-
r'[\r\n]+'
172+
r"""[\r\n]+"""
173173
t.lexer.lineno += t.value.count("\n")
174174

175175

@@ -179,7 +179,12 @@ def t_error(t):
179179

180180

181181
def t_COMMENT(t):
182-
r'(\/\*\*\/)|(/\*((?!\/\*).)+\*/)'
182+
r"""(\/\*\*\/)|(/\*((?!\/\*).)+\*/)"""
183+
pass
184+
185+
186+
def t_SEMICOLON(t):
187+
r""";"""
183188
pass
184189

185190

0 commit comments

Comments
 (0)