Table cars_data has columns such as id, weight.
id is the primary key.
Table model_list has columns such as maker, model.

Table model_list has columns such as model, maker.
Table cars_data has columns such as id, weight.
id is the primary key.
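These two prompts describe the same schema with the table and column order shuffled. As a minimal sketch of how such a prompt might be serialized from table metadata (the `serialize_schema` helper and its input format are assumptions for illustration, not this repo's API):

```python
# Hypothetical helper: render a schema in the "Table ... has columns such as ..." prompt style.
def serialize_schema(tables, primary_keys):
    lines = []
    for table, columns in tables.items():
        lines.append(f"Table {table} has columns such as {', '.join(columns)}.")
        for pk in primary_keys.get(table, []):
            lines.append(f"{pk} is the primary key.")
    return "\n".join(lines)


print(serialize_schema(
    {"cars_data": ["id", "weight"], "model_list": ["maker", "model"]},
    {"cars_data": ["id"]},
))
```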

```sh
python evaluation.py \
  --gold /Users/jizha/code/python/spider/dataset/spider/dev_gold.sql \
  --pred /Users/jizha/code/python/test-suite-sql-eval/二轮测试_gpt4_choice.json \
  --etype all \
  --db /Users/jizha/code/python/spider/dataset/spider/database \
  --table tables.json
```

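With `--etype all` the test-suite evaluator reports both exact-set match and execution accuracy. The core idea of the execution half can be sketched as follows (a simplification for illustration, not the actual `evaluation.py` logic):

```python
import sqlite3


def same_execution_result(db_path, gold_sql, pred_sql):
    """Return True if both queries yield the same rows, order-insensitive."""
    conn = sqlite3.connect(db_path)
    try:
        gold_rows = conn.execute(gold_sql).fetchall()
        try:
            pred_rows = conn.execute(pred_sql).fetchall()
        except sqlite3.Error:
            return False  # an unexecutable prediction counts as wrong
        return sorted(map(repr, gold_rows)) == sorted(map(repr, pred_rows))
    finally:
        conn.close()
```
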
```sh
python generate_question.py \
  --data_type spider \
  --split test \
  --tokenizer gpt-3.5-turbo \
  --max_seq_len 4096 \
  --selector_type EUCDISMASKPRESKLSIMTHR \
  --pre_test_result /Users/jizha/code/python/test-suite-sql-eval/随机列测试/union_test_20231201_random_table.sql \
  --prompt_repr SQL \
  --k_shot 9 \
  --example_type QA
```
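
Judging by its name, the `EUCDISMASKPRESKLSIMTHR` selector combines Euclidean distance over masked-question embeddings with a similarity threshold on pre-predicted SQL skeletons (the `--pre_test_result` file above) to pick the `k_shot` examples. A minimal sketch of the distance part alone (the embeddings and candidate pool are placeholders, not the repo's actual selector):

```python
import numpy as np


def select_k_shot(question_vec, example_vecs, k=9):
    """Indices of the k pool examples nearest to the question in Euclidean distance."""
    dists = np.linalg.norm(example_vecs - question_vec, axis=1)
    return np.argsort(dists)[:k].tolist()
```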

```python
import argparse
import os
import json
import concurrent.futures

import openai
from tqdm import tqdm
from torch.utils.data import DataLoader

from llm.chatgpt import init_chatgpt, ask_llm
from utils.enums import LLM
from utils.post_process import process_duplication, get_sqls

QUESTION_FILE = "questions.json"


def gen_predict_sql(index, global_idx, args, batch):
    try:
        res = ask_llm(args.model, batch, args.temperature, args.n)
    except openai.error.InvalidRequestError:
        print(f"The {global_idx}-th question has too many tokens! Returning \"SELECT\" instead")
        # fall back to empty responses so the batch still yields "SELECT" placeholders
        empty = "" if args.n == 1 else [""] * args.n
        res = {"response": [empty] * len(batch), "total_tokens": 0}
    # parse result
    results = []
    if args.n == 1:
        for sql in res["response"]:
            # remove \n and extra spaces
            sql = " ".join(sql.replace("\n", " ").split())
            sql = process_duplication(sql)
            # requires Python >= 3.8
            if sql.startswith("SELECT"):
                results.append(sql)
            elif sql.startswith(" "):
                results.append("SELECT" + sql)
            else:
                results.append("SELECT " + sql)
    else:
        # align this batch's questions with their database ids
        cur_db_ids = db_ids[global_idx * args.batch_size: global_idx * args.batch_size + len(batch)]
        for sqls, db_id in zip(res["response"], cur_db_ids):
            processed_sqls = []
            for sql in sqls:
                sql = " ".join(sql.replace("\n", " ").split())
                sql = process_duplication(sql)
                if sql.startswith("SELECT"):
                    pass
                elif sql.startswith(" "):
                    sql = "SELECT" + sql
                else:
                    sql = "SELECT " + sql
                processed_sqls.append(sql)
            result = {
                'db_id': db_id,
                'p_sqls': processed_sqls
            }
            # reduce the self-consistent candidate set to one SQL per question
            results.extend(get_sqls([result], args.n, args.db_dir))
    return index, results, res["total_tokens"]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--question", type=str)
    parser.add_argument("--openai_api_key", type=str, default="eab38a33cc07467aae9b7d09783b75a8")
    parser.add_argument("--openai_group_id", type=str, default="luli.wjc")
    parser.add_argument("--openai_api_base", type=str,
                        default="https://codegencore.antgroup-inc.cn/api/chat/commonPower/v1")
    parser.add_argument("--model", type=str, choices=[LLM.TEXT_DAVINCI_003,
                                                      LLM.GPT_35_TURBO,
                                                      LLM.GPT_35_TURBO_0613,
                                                      LLM.TONG_YI_QIAN_WEN,
                                                      LLM.GPT_35_TURBO_16K,
                                                      LLM.GPT_4],
                        default=LLM.GPT_35_TURBO)
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=1000000)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--mini_index_path", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--n", type=int, default=1, help="Size of self-consistent set")
    parser.add_argument("--db_dir", type=str, default="dataset/spider/database")
    args = parser.parse_args()

    # check args
    assert args.model in LLM.BATCH_FORWARD or args.batch_size == 1, \
        f"{args.model} doesn't support batch_size > 1"

    questions_json = json.load(open(os.path.join(args.question, QUESTION_FILE), "r"))
    questions = [_["prompt"] for _ in questions_json["questions"]]
    db_ids = [_["db_id"] for _ in questions_json["questions"]]

    # init openai api
    init_chatgpt(args.openai_api_key, args.openai_group_id, args.openai_api_base, args.model)

    # overwrite on a fresh run, append when resuming from a nonzero start_index
    if args.start_index == 0:
        mode = "w"
    else:
        mode = "a"

    if args.mini_index_path:
        mini_index = json.load(open(args.mini_index_path, 'r'))
        questions = [questions[i] for i in mini_index]
        # keep db_ids aligned with the filtered questions
        db_ids = [db_ids[i] for i in mini_index]
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}_MINI.txt"
    else:
        out_file = f"{args.question}/RESULTS_MODEL-{args.model}.txt"

    question_loader = DataLoader(questions, batch_size=args.batch_size, shuffle=False, drop_last=False)
    # DataLoader is not sliceable, so materialize the batches up front
    batches = list(question_loader)

    token_cnt = 0
    results = []
    with open(out_file, mode) as f:
        # fan each chunk of 10 batches out to a thread pool, then write back in order
        for i in tqdm(range(0, len(batches), 10)):
            up = min(i + 10, len(batches))
            result_temp = [None] * (up - i)
            future_list = []
            with concurrent.futures.ThreadPoolExecutor() as executor:
                for index, item in enumerate(batches[i:up]):
                    # index places the batch within this chunk; i + index is its
                    # global position, used to slice the matching db_ids
                    future_list.append(executor.submit(gen_predict_sql, index, i + index, args, item))
                for future in concurrent.futures.as_completed(future_list):
                    index, p_sqls, tokens = future.result()
                    result_temp[index] = p_sqls
                    token_cnt += tokens
            for item in result_temp:
                f.write("\n".join(item) + "\n")
                results.extend(item)
```
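
When `--n` is greater than 1, several candidate SQLs are sampled per question and `get_sqls` reduces them to one, presumably by comparing execution results in the spirit of self-consistency. A rough sketch of that idea (an illustration of the technique, not the repo's `get_sqls` implementation):

```python
import sqlite3
from collections import Counter


def vote_by_execution(db_path, candidate_sqls):
    """Pick the candidate whose execution result is most common among the samples."""
    conn = sqlite3.connect(db_path)
    outcomes = {}
    for sql in candidate_sqls:
        try:
            rows = conn.execute(sql).fetchall()
        except sqlite3.Error:
            continue  # skip candidates that fail to execute
        outcomes[sql] = repr(sorted(map(repr, rows)))
    conn.close()
    if not outcomes:
        return candidate_sqls[0]  # nothing executed; fall back to the first sample
    majority = Counter(outcomes.values()).most_common(1)[0][0]
    return next(sql for sql, out in outcomes.items() if out == majority)
```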