-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathllms.py
More file actions
100 lines (73 loc) · 2.44 KB
/
llms.py
File metadata and controls
100 lines (73 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import re
import os
import json
import logging
from dotenv import load_dotenv
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_random_exponential, before_sleep_log
load_dotenv()

logger = logging.getLogger(__name__)

# Retry strategy parameters. Fall back to sane defaults when the .env entry
# is missing — int(os.getenv(...)) with no default raises TypeError at import
# time (int(None)), which made the whole module unimportable without a .env.
RETRY_TIMES = int(os.getenv('RETRY_TIMES', '3'))
WAIT_TIME_LOWER = int(os.getenv('WAIT_TIME_LOWER', '1'))
WAIT_TIME_UPPER = int(os.getenv('WAIT_TIME_UPPER', '60'))

# Endpoint / credentials / model selection. These may legitimately be None
# (the OpenAI client falls back to its own env handling).
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
MODEL = os.getenv('OPENAI_MODEL')

# Optional per-request parameters, forwarded to every chat.completions call.
# Only set when the corresponding env var is present and non-empty.
common_params = {}
if os.getenv('OPENAI_MAX_TOKENS'):
    common_params["max_tokens"] = int(os.getenv('OPENAI_MAX_TOKENS'))
if os.getenv('OPENAI_TEMPERATURE'):
    common_params["temperature"] = float(os.getenv('OPENAI_TEMPERATURE'))
if os.getenv('OPENAI_TIMEOUT'):
    common_params["timeout"] = int(os.getenv('OPENAI_TIMEOUT'))

# Shared client instance used by all request helpers below.
client = OpenAI(
    base_url=OPENAI_BASE_URL,
    api_key=OPENAI_API_KEY
)
@retry(
    wait=wait_random_exponential(min=WAIT_TIME_LOWER, max=WAIT_TIME_UPPER),
    stop=stop_after_attempt(RETRY_TIMES),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING)
)
def llm_request(prompt):
    """
    Send a single-turn chat request to the configured model.

    Retries with random exponential backoff on failure, re-raising the
    last error once the attempt budget is exhausted.

    Args:
        prompt (str): The user message to submit to the model.

    Returns:
        str: The text content of the model's first completion choice.
    """
    messages = [{'role': 'user', 'content': prompt}]
    completion = client.chat.completions.create(
        model=MODEL,
        messages=messages,
        **common_params
    )
    return completion.choices[0].message.content
def _extract_json(content):
    """
    Parse a JSON object out of raw model output.

    Prefers a ```json fenced block; falls back to treating the whole
    response (stripped of any generic code fence) as JSON, so models that
    return bare JSON without markdown fences no longer fail — previously
    such responses raised and needlessly consumed every retry attempt.

    Args:
        content (str): Raw text returned by the model.

    Returns:
        dict: The decoded JSON object.

    Raises:
        ValueError: If no parseable JSON can be found in the content.
    """
    match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL)
    if match:
        return json.loads(match.group(1).strip())
    # Fallback: bare JSON, optionally inside a plain ``` fence.
    candidate = re.sub(r"^```(?:json)?\s*|\s*```$", "", content.strip()).strip()
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as exc:
        raise ValueError(f"No JSON block found in model output: {content}") from exc


@retry(
    wait=wait_random_exponential(min=WAIT_TIME_LOWER, max=WAIT_TIME_UPPER),
    stop=stop_after_attempt(RETRY_TIMES),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING)
)
def llm_request_for_json(prompt):
    """
    Send a single-turn chat request and decode the reply as JSON.

    A ValueError from an unparseable reply propagates out of the decorated
    body, so the retry policy also re-queries the model for malformed JSON.

    Args:
        prompt (str): The user message to submit to the model.

    Returns:
        dict: The JSON object decoded from the model's reply.

    Raises:
        ValueError: If the reply contains no parseable JSON after retries.
    """
    response_obj = client.chat.completions.create(
        model=MODEL,
        messages=[{'role': 'user', 'content': prompt}],
        **common_params
    )
    content = response_obj.choices[0].message.content or ""
    return _extract_json(content)
if __name__ == '__main__':
    # Manual smoke test: request a JSON reply and show the decoded result.
    result = llm_request_for_json('hello')
    print(result)