Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ authors = [
]
description = "Toolkit for Chat API"
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.9"
classifiers = [
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down
2 changes: 1 addition & 1 deletion src/chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '4.2.0'
__version__ = '4.3.0'


from chattool.core import (
Expand Down
91 changes: 86 additions & 5 deletions src/chattool/core/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import os
import logging
import hashlib
from typing import List, Dict, Union, Optional, AsyncGenerator, Any
from chattool.utils import setup_logger
from .request import HTTPClient
from filelock import FileLock
import tempfile
from typing import Awaitable, Callable, List, Dict, Union, Optional, AsyncGenerator, Any
from batch_executor import batch_executor, batch_async_executor, batch_hybrid_executor
from pathlib import Path

from chattool.core.response import ChatResponse
from chattool.utils import valid_models
from chattool.utils.urltool import curl_cmd_of_chat_completion
from chattool.utils import valid_models, setup_logger, curl_cmd_of_chat_completion
from chattool.const import (
OPENAI_API_BASE, OPENAI_API_KEY, OPENAI_API_MODEL,
AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_MODEL, AZURE_OPENAI_API_VERSION, AZURE_OPENAI_ENDPOINT
)
from .request import HTTPClient

class Chat(HTTPClient):
def __init__(self,
Expand Down Expand Up @@ -424,6 +427,21 @@ def copy(self) -> 'Chat':
messages = [msg.copy() for msg in self._chat_log]
return Chat(messages=messages)

@classmethod
def load_chats(cls, path: Union[str, Path]) -> list[Union['Chat', None]]:
    """Load several chat histories from a line-delimited JSON file.

    Every non-blank line is parsed as one JSON object that must carry an
    ``index`` key (the slot in the returned list) and a ``chat_log`` key
    (passed to the constructor as ``messages``).

    Args:
        path: File to read. A missing file is treated as "nothing saved".

    Returns:
        A list of length ``max(index) + 1`` in which slot ``i`` holds the
        chat built from the record whose ``index`` is ``i``, or ``None``
        when no record uses that index. An empty list is returned when the
        file does not exist or contains no records.
    """
    path = Path(path)
    if not path.exists():
        return []
    # Skip blank lines (e.g. a trailing or accidental empty line) instead of
    # letting json.loads('') abort the whole load.
    records = [
        json.loads(line)
        for line in path.read_text().splitlines()
        if line.strip()
    ]
    if not records:
        return []
    chats = [None] * (max(record['index'] for record in records) + 1)
    # Later records overwrite earlier ones that share the same index.
    for record in records:
        chats[record['index']] = cls(messages=record['chat_log'])
    return chats

# === 显示和调试 ===
def print_log(self, sep: str = "\n" + "-" * 50 + "\n"):
"""打印对话历史"""
Expand Down Expand Up @@ -468,6 +486,69 @@ def print_curl(self, **options):
)
print(curl_cmd)

# === 并发执行 ===
@classmethod
def batch_process_chat(cls,
                       messages: list[str],
                       async_func: Callable[[str], Awaitable['Chat']],
                       nworker: int = 1,
                       pool_size: int = 1,
                       checkpoint: Optional[Union[str, Path]] = None,
                       overwrite: bool = False
                       ) -> list['Chat']:
    """Process a batch of messages, optionally resuming from a checkpoint.

    Args:
        messages: Messages to process; a message's position is its index.
        async_func: Coroutine function that takes one message string and
            returns the resulting Chat object.
        nworker: Forwarded to the batch executor as ``nworker``.
        pool_size: When greater than 1, ``batch_hybrid_executor`` is used
            (and receives this value as ``pool_size``); otherwise
            ``batch_async_executor`` is used.
        checkpoint: Optional file path. Each finished chat is written with
            ``chat.save(checkpoint, index=<message index>)``, and an
            existing file is reloaded through ``load_chats`` so that
            already-processed indices are skipped.
        overwrite: If true, an existing checkpoint file is deleted instead
            of being used to resume.

    Returns:
        Chat objects ordered by message index. When nothing was loaded from
        a checkpoint this is the executor's result as returned.
    """
    chats = []  # results restored from the checkpoint, if any
    pending = list(enumerate(messages))  # (index, message) pairs still to run

    if checkpoint is not None:
        checkpoint = Path(checkpoint)
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        if checkpoint.exists():
            if overwrite:
                checkpoint.unlink()
            else:
                chats = cls.load_chats(checkpoint)
                # Keep only indices with no restored result.
                pending = [
                    (idt, msg) for idt, msg in pending
                    if idt >= len(chats) or chats[idt] is None
                ]
    if not pending:
        return chats

    with tempfile.TemporaryDirectory() as temp_dir:
        lock = None
        if checkpoint is not None:  # TODO: per-pool lock/destination files for multi-process mode
            # Use only the file name: embedding the full (possibly absolute)
            # checkpoint path would point the lock at a nested directory
            # that does not exist inside temp_dir.
            lock = str(Path(temp_dir) / f'{checkpoint.name}.lock')

        async def process_message(idt_msg: tuple[int, str]) -> 'Chat':
            idt, msg = idt_msg
            chat = await async_func(msg)
            if lock is not None:
                # Serialize writers appending to the shared checkpoint file.
                with FileLock(lock):
                    chat.save(checkpoint, index=idt)
            return chat

        if pool_size <= 1:
            newchats = batch_async_executor(pending, process_message, nworker=nworker)
        else:
            newchats = batch_hybrid_executor(pending, process_message, nworker=nworker, pool_size=pool_size)

    if not chats:
        return newchats
    # Merge by message index rather than by position. This assumes, as the
    # positional merge it replaces did, that the executor returns results in
    # input order. Gaps in the checkpoint beyond len(messages) stay None
    # instead of raising IndexError.
    merged = list(chats)
    merged.extend([None] * (len(messages) - len(merged)))
    for (idt, _), chat in zip(pending, newchats):
        merged[idt] = chat
    return merged

# === 属性访问 ===
@property
def chat_log(self) -> List[Dict]:
Expand Down
26 changes: 26 additions & 0 deletions tests/core/test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import asyncio
from chattool import Chat, AzureChat
from chattool.const import CHATTOOL_REPO_DIR
test_dir = CHATTOOL_REPO_DIR / 'tests' / 'testfiles'

async def get_res(msg):
    """Build a chat from *msg* with a canned assistant reply after a short delay."""
    chat = Chat(msg)
    chat.assistant('hi')
    # The sleep must be awaited: a bare asyncio.sleep(...) call only creates a
    # coroutine object that never runs and triggers a "never awaited" warning.
    await asyncio.sleep(0.1)
    return chat

def test_batch():
    """Run batch processing without a checkpoint, with one, and resuming from one."""
    # No checkpoint: every message is processed.
    results = Chat.batch_process_chat(
        [f'hello {i}' for i in range(10)], get_res)
    assert len(results) == 10

    checkpoint = test_dir/'hello'
    # A checkpoint left behind by an earlier failed run would change which
    # messages get processed, so start from a clean state.
    checkpoint.unlink(missing_ok=True)
    try:
        # First checkpointed run stores two results.
        results = Chat.batch_process_chat(
            [f'hello {i}' for i in range(2)], get_res, checkpoint=checkpoint)
        assert len(results) == 2
        assert checkpoint.exists()

        # Second run resumes from the checkpoint and processes the remainder.
        results = Chat.batch_process_chat(
            [f'hello {i}' for i in range(10)], get_res, checkpoint=checkpoint)
        assert len(results) == 10
    finally:
        # Always clean up, even when an assertion above fails.
        checkpoint.unlink(missing_ok=True)

# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_batch()
Loading