-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·135 lines (108 loc) · 3.95 KB
/
utils.py
File metadata and controls
executable file
·135 lines (108 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# -*- encoding: utf-8 -*-
import json
import logging
import logging.config
import os
import re
from typing import Any, Optional, Tuple, List, Dict
import yaml
def setup_logging(
    output_directory,
    config_path='config/logging.yaml',
    log_filename='tableeval.log',
):
    """
    Configure the logging system from a YAML ``dictConfig`` file.

    The handler named ``file`` in the loaded configuration is redirected to
    ``<output_directory>/logs/<log_filename>``; the ``logs`` directory is
    created first if it does not already exist. The resulting configuration
    is then applied with ``logging.config.dictConfig``.

    Args:
        output_directory: Directory under which the ``logs`` folder is created.
        config_path: Path to the YAML logging configuration.
        log_filename: File name used by the ``file`` handler inside ``logs``.

    Raises:
        FileNotFoundError, ValueError: from ``load_yaml_safe`` when the
            configuration file is missing, empty, or not valid YAML.
        KeyError: if the loaded configuration has no ``handlers.file`` entry.
    """
    config = load_yaml_safe(config_path)
    log_dir = os.path.join(output_directory, "logs")
    os.makedirs(log_dir, exist_ok=True)
    # Point the file handler at the per-run output directory instead of
    # whatever path the static YAML file contains.
    config['handlers']['file']['filename'] = os.path.join(log_dir, log_filename)
    logging.config.dictConfig(config)
def load_yaml_safe(file_path):
    """
    Parse the YAML file at ``file_path`` with ``yaml.safe_load`` and return the result.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        ValueError: if the YAML cannot be parsed, or if it parses to a falsy
            value (an empty document, ``{}``, ``[]``, ...).
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            parsed = yaml.safe_load(handle)
    except FileNotFoundError as missing:
        raise FileNotFoundError(f"Config file {file_path} not found.") from missing
    except yaml.YAMLError as parse_error:
        raise ValueError(f"Failed to parse YAML file: {parse_error}") from parse_error

    # An empty file yields None; treat any falsy document as unusable config.
    if not parsed:
        raise ValueError(f"Config file {file_path} is empty or invalid.")
    return parsed
def load_api_configuration(model_name: str, config_file: str) -> Dict[str, Any]:
    """
    Return the API configuration entry for ``model_name`` from a YAML file.

    The top-level YAML mapping is looked up by ``model_name`` and the matching
    entry is returned unchanged, so callers read settings such as
    ``'api_key'`` and ``'base_url'`` from the returned mapping.

    Args:
        model_name: Key of the model's entry in the YAML file.
        config_file: Path to the YAML API configuration file.

    Returns:
        The per-model configuration mapping.

    Raises:
        ValueError: if ``model_name`` is empty, no (non-empty) entry exists
            for it, or the entry has no non-empty ``'api_key'``. Errors from
            ``load_yaml_safe`` (missing/empty/invalid file) propagate as-is.
    """
    # Validate the argument before touching the file; an empty name could
    # never match a meaningful entry.
    if not model_name:
        raise ValueError("Missing 'model_name' in API configuration.")

    config_data = load_yaml_safe(config_file)
    api_config = config_data.get(model_name)
    if not api_config:
        raise ValueError(f"No configuration found for model '{model_name}'.")

    if not api_config.get('api_key'):
        raise ValueError("Missing 'api_key' in API configuration.")
    return api_config
def load_json(file_path):
    """
    Load JSON data from ``file_path``.

    ``.jsonl`` files are parsed as one JSON document per line and returned as
    a list (blank lines, such as a trailing empty line, are skipped).
    ``.json`` files are parsed as a single JSON document.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        ValueError: if the extension is neither ``.json`` nor ``.jsonl``, or
            if reading/decoding the file fails (the underlying error is
            chained as ``__cause__``).
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"文件 {file_path} 不存在")

    # Decide the format up front so unsupported files are rejected without
    # being opened.
    if file_path.endswith('.jsonl'):
        is_jsonl = True
    elif file_path.endswith('.json'):
        is_jsonl = False
    else:
        raise ValueError("Only support .json and .jsonl")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            if is_jsonl:
                # Parse line by line (JSONL); whitespace-only lines carry no record.
                return [json.loads(line) for line in f if line.strip()]
            # Parse a standard JSON document.
            return json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        raise ValueError(f"JSON file load failed: {file_path}") from e
def save_json(data: Any, file_path: str, msg: Optional[str] = None) -> None:
    """
    Write ``data`` to ``file_path`` as UTF-8 JSON (4-space indent, non-ASCII kept).

    This is best-effort: any failure is logged (with traceback) and not
    raised. ``data`` is serialized before the file is opened, so
    non-serializable data leaves an existing file at ``file_path`` intact
    instead of truncated and half-written.

    Args:
        data: JSON-serializable object to save.
        file_path: Destination path; overwritten on success.
        msg: Optional description; when given, an INFO line is logged after
            a successful save.
    """
    try:
        # Serialize first: json.dump streams into an already-truncated file,
        # so a failure midway would leave invalid JSON behind.
        payload = json.dumps(data, ensure_ascii=False, indent=4)
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(payload)
        if msg:
            logging.info(f"Save {msg} to {file_path}")
    except Exception as e:
        # Errors are logged rather than raised (existing best-effort contract).
        logging.exception(f"Error saving file: {e}")
def load_txt(file_path: str) -> str:
    """
    Return the entire contents of ``file_path`` decoded as UTF-8 text.

    Errors from opening or decoding the file (e.g. ``FileNotFoundError``,
    ``UnicodeDecodeError``) are not caught here.
    """
    with open(file_path, mode="r", encoding="utf-8") as source:
        text = source.read()
    return text
def sanitize_filename(filename: str) -> str:
    """
    Return ``filename`` with each path separator, wildcard, colon, double
    quote, angle bracket and pipe character replaced by an underscore.
    """
    forbidden = re.compile(r'[\\/*?:"<>|]')
    return forbidden.sub("_", filename)
def generate_prediction_path(
    output_dir: str,
    model_name: str,
    reasoning_type: str,
    context_type: str,
    specific_tasks: Optional[List[str]] = None,
    specific_ids: bool = False
) -> str:
    """
    Build the JSON prediction file path for a run, creating ``output_dir`` if needed.

    The file name has the form
    ``<model>_<reasoning>_<context>[_<task>...][_specific_ids].json``.
    Every component goes through ``sanitize_filename``, so characters such as
    ``/`` or ``:`` in any component cannot turn part of the name into a
    sub-directory or an invalid path.

    Args:
        output_dir: Directory that will contain the file (created if missing).
        model_name: Model identifier.
        reasoning_type: Reasoning mode label.
        context_type: Context format label.
        specific_tasks: Optional task names appended in order.
        specific_ids: When true, append the ``specific_ids`` marker.

    Returns:
        The path ``os.path.join(output_dir, <file name>)``.
    """
    os.makedirs(output_dir, exist_ok=True)

    parts = [model_name, reasoning_type, context_type]
    if specific_tasks:
        parts.extend(specific_tasks)
    if specific_ids:
        parts.append("specific_ids")

    # f-string formatting keeps accepting non-str labels (e.g. enums), as the
    # name was previously built by interpolation.
    file_name = "_".join(sanitize_filename(f"{part}") for part in parts) + ".json"
    return os.path.join(output_dir, file_name)