
Commit 53d6abf

Merge pull request #98 from danmcp/removefastchatdep

Remove fastchat dependency

2 parents 893b6ec + ca129ab

File tree

5 files changed: +380 -11 lines


requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-FastChat
 GitPython>=3.1.42,<4.0.0
 shortuuid
 openai>=1.13.3,<2.0.0

src/instructlab/eval/mt_bench_answers.py

Lines changed: 2 additions & 3 deletions
@@ -6,8 +6,6 @@
 import time
 
 # Third Party
-# TODO need to look into this dependency
-from fastchat.model.model_adapter import get_conversation_template  # type: ignore
 import shortuuid
 import tqdm
 
@@ -20,6 +18,7 @@
     load_questions,
     temperature_config,
 )
+from .mt_bench_model_adapter import get_conversation_template  # type: ignore
 
 logger = setup_logger(__name__)
 
@@ -61,7 +60,7 @@ def get_answer(
 
     choices = []
     for i in range(num_choices):
-        conv = get_conversation_template(model)
+        conv = get_conversation_template(model, "granite")
 
         turns = []
         for j in range(len(question["turns"])):
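
Both call sites now pass a default model family as a second argument ("granite" here, "mixtral" for the judge further down). The get_conversation_template they import lives in the new mt_bench_model_adapter module, one of the five changed files but not rendered in this view. As a minimal sketch only — the resolution logic, the _DEFAULT_TEMPLATES mapping, and the template names it picks are assumptions, not the commit's actual code — it presumably matches the served model's name against known families and falls back to the caller-supplied default:

# Hypothetical sketch of mt_bench_model_adapter.get_conversation_template;
# the real module is part of this commit but not shown in this diff view.
from .mt_bench_conversation import Conversation, get_conv_template

# Assumed mapping from the default-family argument to a registered template
_DEFAULT_TEMPLATES = {"granite": "granite-chat", "mixtral": "mistral"}

def get_conversation_template(model_path: str, default_model: str) -> Conversation:
    """Resolve a served model name to a conversation template, with a fallback."""
    name = model_path.lower()
    if "granite" in name:
        return get_conv_template("granite-chat")
    if "mixtral" in name or "mistral" in name:
        return get_conv_template("mistral")
    if "labrador" in name:
        return get_conv_template("labrador-chat")
    # Unrecognized model: fall back to the caller-supplied default family
    return get_conv_template(_DEFAULT_TEMPLATES[default_model])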

src/instructlab/eval/mt_bench_common.py

Lines changed: 5 additions & 7 deletions
@@ -13,15 +13,15 @@
 import time
 
 # Third Party
-from fastchat import conversation
-from fastchat.model.model_adapter import get_conversation_template  # type: ignore
 import openai
 
 # First Party
 from instructlab.eval import exceptions
 
 # Local
 from .logger_config import setup_logger
+from .mt_bench_conversation import Conversation
+from .mt_bench_model_adapter import get_conversation_template
 
 logger = setup_logger(__name__)
 
@@ -158,7 +158,7 @@ def run_judge_single(
     rating = -1
 
     system_prompt = judge.prompt_template["system_prompt"]
-    conv = get_conversation_template(model)
+    conv = get_conversation_template(model, "mixtral")
     conv.set_system_message(system_prompt)
     conv.append_message(conv.roles[0], user_prompt)
     conv.append_message(conv.roles[1], None)
@@ -268,9 +268,7 @@ class Message(TypedDict):
     role: str
 
 
-def _get_messages(
-    conv: conversation.Conversation, merge_system_user_message: bool
-) -> list[Message]:
+def _get_messages(conv: Conversation, merge_system_user_message: bool) -> list[Message]:
     messages = conv.to_openai_api_messages()
     if (
         (merge_system_user_message or conv.name == "mistral")
@@ -285,7 +283,7 @@ def _get_messages(
 def chat_completion_openai(
     openai_client,
     model,
-    conv: conversation.Conversation,
+    conv: Conversation,
     temperature,
     max_tokens,
     merge_system_user_message: bool = False,
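
The condition visible in _get_messages guards a merge step: when merge_system_user_message is set, or the template is mistral (whose chat format has no separate system turn), the system message is folded into the first user message before the request is sent. The tail of the condition and the exact join are truncated in this diff, so the following is a sketch under that reading, not the commit's code:

# Hedged sketch of the merge _get_messages appears to perform; the full
# condition and the joining separator are truncated above, so both are
# assumptions here.
def merge_system_into_first_user(messages: list[dict]) -> list[dict]:
    if len(messages) >= 2 and messages[0]["role"] == "system":
        merged = {
            "role": "user",
            "content": messages[0]["content"] + "\n" + messages[1]["content"],
        }
        return [merged] + messages[2:]
    return messages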
src/instructlab/eval/mt_bench_conversation.py

Lines changed: 213 additions & 0 deletions

@@ -0,0 +1,213 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Conversation prompt templates.
+"""
+
+# Standard
+from enum import IntEnum, auto
+from typing import Dict, List, Tuple, Union
+import dataclasses
+
+
+class SeparatorStyle(IntEnum):
+    """Separator styles."""
+
+    ADD_COLON_SINGLE = auto()
+    ADD_COLON_TWO = auto()
+    ADD_COLON_SPACE_SINGLE = auto()
+    NO_COLON_SINGLE = auto()
+    NO_COLON_TWO = auto()
+    ADD_NEW_LINE_SINGLE = auto()
+    LLAMA2 = auto()
+    DEFAULT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    # pylint: disable=too-many-instance-attributes
+    """A class that manages prompt templates and keeps all conversation history."""
+
+    # The name of this template
+    name: str
+    # The template of the system prompt
+    system_template: str = "{system_message}"
+    # The system message
+    system_message: str = ""
+    # The names of two roles
+    roles: Tuple[str, str] = ("USER", "ASSISTANT")
+    # All messages. Each item is (role, message).
+    # Each message is either a string or a tuple of (string, List[image_url]).
+    messages: List[List[str | None]] = dataclasses.field(default_factory=list)
+    # The number of few shot examples
+    offset: int = 0
+    # The separator style and configurations
+    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+    sep: str | None = "\n"
+    sep2: str | None = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Union[str, List[str]] | None = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] | None = None
+
+    def set_system_message(self, system_message: str):
+        """Set the system message."""
+        self.system_message = system_message
+
+    def get_system_message(self):
+        """return the system message."""
+        return self.system_message
+
+    def append_message(self, role: str, message: str | None):
+        """Append a new message."""
+        self.messages.append([role, message])
+
+    def update_last_message(self, message: str):
+        """Update the last output.
+
+        The last message is typically set to be None when constructing the prompt,
+        so we need to update it in-place after getting the response from a model.
+        """
+        self.messages[-1][1] = message
+
+    def to_openai_api_messages(self):
+        """Convert the conversation to OpenAI chat completion format."""
+        if self.system_message == "":
+            ret = []
+        else:
+            ret = [{"role": "system", "content": self.system_message}]
+
+        for i, (_, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                ret.append({"role": "user", "content": msg})
+            else:
+                if msg is not None:
+                    ret.append({"role": "assistant", "content": msg})
+        return ret
+
+    def copy(self):
+        return Conversation(
+            name=self.name,
+            system_template=self.system_template,
+            system_message=self.system_message,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            stop_str=self.stop_str,
+            stop_token_ids=self.stop_token_ids,
+        )
+
+    def dict(self):
+        return {
+            "template_name": self.name,
+            "system_message": self.system_message,
+            "roles": self.roles,
+            "messages": self.extract_text_from_messages(),
+            "offset": self.offset,
+        }
+
+
+# A global registry for all conversation templates
+conv_templates: Dict[str, Conversation] = {}
+
+
+def register_conv_template(template: Conversation, override: bool = False):
+    """Register a new conversation template."""
+    if not override:
+        assert (
+            template.name not in conv_templates
+        ), f"{template.name} has been registered."
+
+    conv_templates[template.name] = template
+
+
+def get_conv_template(name: str) -> Conversation:
+    """Get a conversation template."""
+    return conv_templates[name].copy()
+
+
+# An empty template for raw conversation.
+register_conv_template(
+    Conversation(
+        name="raw",
+        system_message="",
+        roles=("", ""),
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        sep="",
+    )
+)
+
+
+# api-based default template
+register_conv_template(
+    Conversation(
+        name="api_based_default",
+        system_message="",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.DEFAULT,
+        sep=None,
+    )
+)
+
+
+# ChatGPT default template
+register_conv_template(
+    Conversation(
+        name="chatgpt",
+        system_message="You are a helpful assistant.",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.DEFAULT,
+        sep=None,
+    )
+)
+
+# Mistral template
+# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
+register_conv_template(
+    Conversation(
+        name="mistral",
+        system_template="[INST] {system_message}\n",
+        roles=("[INST]", "[/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2="</s>",
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="labrador-chat",
+        system_template="<|system|>\n{system_message}",
+        system_message="""You are Labrador, an AI language model developed by IBM DMF (Data Model Factory) Alignment Team. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior. You always respond to greetings (for example, hi, hello, g'day, morning, afternoon, evening, night, what's up, nice to meet you, sup, etc) with "Hello! I am Labrador, created by the IBM DMF Alignment Team. How can I help you today?". Please do not say anything else and do not start a conversation.""",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        sep="\n",
+        stop_str="<|endoftext|>",
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="ibm-generic",
+        system_template="<|system|>\n{system_message}",
+        system_message="""You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.""",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        sep="\n",
+        stop_str="<|endoftext|>",
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="granite-chat",
+        system_template="<|system|>\n{system_message}",
+        system_message="""You are Granite Chat, an AI language model developed by IBM. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.""",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        sep="\n",
+        stop_str="<|endoftext|>",
+    )
+)
