Skip to content

fix: fix process_mm_info for tool call#388

Open
NicholasCao wants to merge 1 commit into QwenLM:main from
NicholasCao:main
Open

fix: fix process_mm_info for tool call#388
NicholasCao wants to merge 1 commit into QwenLM:main from
NicholasCao:main

Conversation

@NicholasCao
Copy link

fix: fix process_mm_info for tool call

bug:

import json
import re
import torch
import copy
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info

# HF Hub model id for the Qwen3-Omni instruct checkpoint.
MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
# When True, the audio track embedded in video inputs is also fed to the model.
USE_AUDIO_IN_VIDEO = True

# ===== Model loading =====
# NOTE(review): requires flash-attn to be installed; device_map="auto" shards
# the model across available GPUs — confirm this matches the target environment.
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
# Processor bundles the tokenizer plus the audio/image/video feature extractors.
processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH)

# ===== 工具函数 =====
def get_current_weather(location: str, unit: str = "celsius") -> dict:
    """Mock weather lookup used as the tool-call backend.

    Args:
        location: City name to look up.
        unit: Requested temperature unit; only honored for unknown cities —
            known cities return their canned unit (celsius).

    Returns:
        dict with ``location``, ``temperature``, ``unit`` and ``description``.
    """
    known = {"北京": {"temperature": 22, "unit": "celsius", "description": "晴"}}
    fallback = {"temperature": 20, "unit": unit, "description": "Unknown"}
    info = known.get(location, fallback)
    return {
        "location": location,
        "temperature": info["temperature"],
        "unit": info["unit"],
        "description": info["description"],
    }

# Dispatch table: tool name (as emitted by the model) -> local Python callable.
TOOL_FUNCTIONS = {"get_current_weather": get_current_weather}
# OpenAI-style function schema advertised to the chat template so the model
# knows which tools it may call and with which arguments.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "获取指定城市的当前天气信息",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "城市名称"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "温度单位"}
                },
                "required": ["location"]
            }
        }
    }
]

def parse_tool_calls_from_output(output_text: str) -> list:
    """Extract tool calls from raw model output.

    Qwen3-Omni wraps each tool-call JSON payload in <tool_call>...</tool_call>
    tags; there may be zero or more such tags in one generation.

    Args:
        output_text: Decoded model output, special tokens included.

    Returns:
        List of parsed tool-call dicts (typically ``{"name": ..., "arguments": ...}``).
        Payloads that are not valid JSON are skipped with a warning instead of
        aborting the whole parse.
    """
    tool_calls = []
    # DOTALL lets the payload span newlines; non-greedy so adjacent tags don't merge.
    pattern = r"<tool_call>(.*?)</tool_call>"
    matches = re.findall(pattern, output_text, re.DOTALL)
    for match in matches:
        # Strip incidental whitespace/newlines around the JSON payload.
        clean_json = match.strip()
        try:
            call_dict = json.loads(clean_json)
        # Only JSON decoding can fail here — catch it specifically rather than
        # the original blanket `except Exception`, which could hide real bugs.
        except json.JSONDecodeError as e:
            print(f"⚠️ Failed to parse tool call: {match}, error: {e}")
            continue
        tool_calls.append(call_dict)
    return tool_calls

def run_multimodal_tool_call(conversation: list, max_rounds: int = 2):
    """Run an agentic generate→tool-call→tool-result loop for up to *max_rounds*.

    Each round: extract multimodal payloads, render the chat template with the
    TOOLS schema, generate, then either execute detected tool calls and append
    their results to the conversation, or return the final text answer.

    Args:
        conversation: List of chat messages (``{"role": ..., "content": ...}``).
            Content items may reference audio/image/video — assumed to be in the
            format expected by ``process_mm_info``; confirm against qwen_omni_utils.
        max_rounds: Maximum generate/tool cycles before giving up.

    Returns:
        The final answer text (special tokens after ``<|im_end|>`` stripped),
        or the last raw-ish output if *max_rounds* is exhausted.
    """
    # Shallow copy: the top-level message list is not mutated for the caller,
    # but nested message dicts are shared.
    messages = conversation.copy()
    
    for round_idx in range(max_rounds):
        print(f"\n🔄 Round {round_idx + 1}")
        
        # === Step 1: Extract multimodal payloads (audio/image/video) from messages ===
        audios, images, videos, video_kwargs = process_mm_info(
            messages, 
            use_audio_in_video=USE_AUDIO_IN_VIDEO, 
            return_video_kwargs=True
        )
        
        # === Step 2: Build the text prompt (no tokenization) ===
        text_prompt = processor.apply_chat_template(
            messages, 
            tools=TOOLS,
            add_generation_prompt=True, 
            tokenize=False
        )
        print("📝 Prompt sent to model:")
        print(repr(text_prompt))
        print("-" * 50)
        
        # === Step 3: Tokenize + pack multimodal inputs ===
        inputs = processor(
            text=text_prompt,
            audio=audios,
            images=images,
            videos=videos,
            videos_kwargs=video_kwargs,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=USE_AUDIO_IN_VIDEO,
        ).to(model.device)

        # === Step 4: Generate a response ===
        # return_audio=False, so the second return value is unused here.
        with torch.no_grad():
            text_ids, audio = model.generate(**inputs, 
                                    thinker_return_dict_in_generate=True,
                                    thinker_max_new_tokens=1024, 
                                    thinker_do_sample=False,
                                    speaker="Ethan", 
                                    use_audio_in_video=USE_AUDIO_IN_VIDEO,
                                    return_audio=False)

        # Slice off the prompt tokens to keep only the newly generated tail.
        generated_tokens = text_ids.sequences[0][inputs.input_ids.shape[1]:]
        # Keep special tokens: the <tool_call> markers are needed for parsing.
        response_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=False).strip()

        print(f"🤖 Raw Output: {repr(response_text)}")

        # === Step 5: Try to parse tool calls from the output ===
        tool_calls = parse_tool_calls_from_output(response_text)

        if tool_calls:
            print(f"🔧 Detected {len(tool_calls)} tool call(s): {tool_calls}")
            
            new_tool_calls_for_history = []
            tool_response_messages = []
            
            for i, tc in enumerate(tool_calls):
                func_name = tc.get("name")
                args = tc.get("arguments", {})
                # Synthetic id linking the assistant tool_call to its tool reply.
                call_id = f"call_{round_idx}_{i}"
                
                # Execute the tool function; surface failures as error payloads
                # instead of crashing the loop.
                if func_name in TOOL_FUNCTIONS:
                    try:
                        result = TOOL_FUNCTIONS[func_name](**args)
                    except Exception as e:
                        result = {"error": f"Tool execution failed: {str(e)}"}
                else:
                    result = {"error": "Unknown function"}
                
                # Tool-call record (goes into the assistant message).
                new_tool_calls_for_history.append({
                    "id": call_id,
                    "type": "function",
                    "function": {
                        "name": func_name,
                        "arguments": json.dumps(args, ensure_ascii=False)
                    }
                })
                
                # Tool response (role='tool').
                tool_response_messages.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": json.dumps(result, ensure_ascii=False)
                })
            
            # Update the conversation history so the next round sees the
            # assistant's tool calls followed by the tool results.
            # NOTE(review): this assistant message has no "content" key —
            # confirm the chat template accepts tool_calls-only messages.
            messages.append({
                "role": "assistant",
                "tool_calls": new_tool_calls_for_history
            })
            messages.extend(tool_response_messages)
            
        else:
            # No tool call detected: treat as the final answer.
            # Drop trailing special tokens such as <|im_end|>.
            final_answer = response_text.split("<|im_end|>")[0].strip()
            print(f"\n💬 Final Answer: {final_answer}")
            return final_answer

    print(f"\n⚠️ Max rounds reached. Last output: {response_text}")
    return response_text.split("<|im_end|>")[0].strip()


# ===== 测试 =====
if __name__ == "__main__":
    # Demo: a text-only conversation that should trigger the weather tool.
    demo_conversation = [
        {"role": "user", "content": "北京天气如何?"}
    ]
    run_multimodal_tool_call(demo_conversation)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant