
Commit 14cee1e

Authored by Mandana Vaziri
Fixed logging bug, added chat template (#125)
* fixed logging model input, added template
* cleanup

Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 842994d commit 14cee1e

File tree

2 files changed (+22, -15 lines)

src/pdl/pdl_interpreter.py

Lines changed: 6 additions & 4 deletions
@@ -1051,7 +1051,9 @@ def get_transformed_inputs(kwargs):
         if "input" in litellm_params:
             append_log(state, "Model Input", litellm_params["input"])
         else:
-            append_log(state, "Model Input", model_input)
+            append_log(
+                state, "Model Input", messages_to_str(concrete_block.model, model_input)
+            )
         background: Messages = [msg]
         result = msg["content"]
         append_log(state, "Model Output", result)
@@ -1093,7 +1095,7 @@ def generate_client_response_streaming(
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, Message]:
     msg_stream: Generator[Message, Any, None]
-    model_input_str = messages_to_str(model_input)
+    model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
             msg_stream = BamModel.generate_text_stream(
@@ -1148,7 +1150,7 @@ def generate_client_response_single(
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, Message]:
     msg: Message
-    model_input_str = messages_to_str(model_input)
+    model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
             msg = BamModel.generate_text(
@@ -1178,7 +1180,7 @@ def generate_client_response_batching(  # pylint: disable=too-many-arguments
     # model: str,
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, Message]:
-    model_input_str = messages_to_str(model_input)
+    model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
             msg = yield ModelCallMessage(
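For context, a hedged sketch of what the logging fix changes at these call sites, using the append_log, messages_to_str, and concrete_block names from the diff above; the sample message is invented for illustration.

# Sketch only (not part of the commit): the "Model Input" log entry now holds a
# single rendered prompt string instead of the raw list of role/content dicts.
model_input = [{"role": "user", "content": "Hi"}]

# Before the fix, the raw Messages list was logged:
#   append_log(state, "Model Input", model_input)
# After the fix, the messages are rendered first, which requires the model id so
# that granite-3b/granite-8b models get their chat template applied:
#   append_log(
#       state, "Model Input", messages_to_str(concrete_block.model, model_input)
#   )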

src/pdl/pdl_utils.py

Lines changed: 16 additions & 11 deletions
@@ -32,15 +32,20 @@ def messages_concat(messages1: Messages, messages2: Messages) -> Messages:
     return messages1 + messages2
 
 
-def messages_to_str(messages: Messages) -> str:
-    # TODO
-    return "".join(
-        [
-            (
-                msg["content"]
-                if msg["role"] is None
-                else f"<|{msg['role']}|>{msg['content']}"
-            )
-            for msg in messages
-        ]
+def messages_to_str(model_id: str, messages: Messages) -> str:
+    if "granite-3b" not in model_id and "granite-8b" not in model_id:
+        return "".join([(msg["content"]) for msg in messages])
+    return (
+        "".join(
+            [
+                (
+                    msg["content"]
+                    if msg["role"] is None
+                    # else f"<|{msg['role']}|>{msg['content']}"
+                    else f"<|start_of_role|>{msg['role']}<|end_of_role|>{msg['content']}<|end_of_text|>\n"
+                )
+                for msg in messages
+            ]
+        )
+        + "<|start_of_role|>assistant<|end_of_role|>"
     )
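To make the new behavior concrete, here is a small usage sketch; the import path is assumed from the src/pdl/pdl_utils.py layout above, and the model ids and messages are illustrative only.

# Illustrative usage of the updated helper (model ids and messages are examples).
from pdl.pdl_utils import messages_to_str  # import path assumed from src/pdl/pdl_utils.py

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Granite 3b/8b model ids get the chat-template markers, plus a trailing
# assistant header that cues the model to respond:
print(messages_to_str("ibm/granite-8b-code-instruct", messages))
# <|start_of_role|>system<|end_of_role|>You are a helpful assistant.<|end_of_text|>
# <|start_of_role|>user<|end_of_role|>Hello!<|end_of_text|>
# <|start_of_role|>assistant<|end_of_role|>

# Any other model id falls back to plain concatenation of the message contents:
print(messages_to_str("gpt-4o", messages))
# You are a helpful assistant.Hello!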
