diff --git a/examples/KWS_VAD_Whisper_LLM/KWS_VAD_Whisper_LLM.ino b/examples/KWS_VAD_Whisper_LLM/KWS_VAD_Whisper_LLM.ino new file mode 100644 index 0000000..c2e8aa2 --- /dev/null +++ b/examples/KWS_VAD_Whisper_LLM/KWS_VAD_Whisper_LLM.ino @@ -0,0 +1,120 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include +#include +#include + +M5ModuleLLM module_llm; + +/* Must be capitalized */ +String wake_up_keyword = "HELLO"; +// String wake_up_keyword = "你好你好"; +String kws_work_id; +String vad_work_id; +String whisper_work_id; +String llm_work_id; +String language; + +void setup() +{ + M5.begin(); + M5.Display.setTextSize(2); + M5.Display.setTextScroll(true); + M5.Display.setFont(&fonts::efontCN_12); // Support Chinese display + // M5.Display.setFont(&fonts::efontJA_12); // Support Japanese display + + language = "en_US"; + // language = "zh_CN"; + + /* Init module serial port */ + // int rxd = 16, txd = 17; // Basic + // int rxd = 13, txd = 14; // Core2 + // int rxd = 18, txd = 17; // CoreS3 + int rxd = M5.getPin(m5::pin_name_t::port_c_rxd); + int txd = M5.getPin(m5::pin_name_t::port_c_txd); + Serial2.begin(115200, SERIAL_8N1, rxd, txd); + + /* Init module */ + module_llm.begin(&Serial2); + + /* Make sure module is connected */ + M5.Display.printf(">> Check ModuleLLM connection..\n"); + while (1) { + if (module_llm.checkConnection()) { + break; + } + } + + /* Reset ModuleLLM */ + M5.Display.printf(">> Reset ModuleLLM..\n"); + module_llm.sys.reset(); + + /* Setup Audio module */ + M5.Display.printf(">> Setup audio..\n"); + module_llm.audio.setup(); + + /* Setup KWS module and save returned work id */ + M5.Display.printf(">> Setup kws..\n"); + m5_module_llm::ApiKwsSetupConfig_t kws_config; + kws_config.kws = wake_up_keyword; + kws_work_id = module_llm.kws.setup(kws_config, "kws_setup", language); + + /* Setup VAD module and save returned work id */ + M5.Display.printf(">> Setup vad..\n"); + m5_module_llm::ApiVadSetupConfig_t vad_config; + vad_config.input = {"sys.pcm", kws_work_id}; + vad_work_id = module_llm.vad.setup(vad_config, "vad_setup"); + + /* Setup Whisper module and save returned work id */ + M5.Display.printf(">> Setup whisper..\n"); + m5_module_llm::ApiWhisperSetupConfig_t whisper_config; + whisper_config.input = {"sys.pcm", kws_work_id, vad_work_id}; + whisper_config.language = "en"; + // whisper_config.language = "zh"; + // whisper_config.language = "ja"; + whisper_work_id = module_llm.whisper.setup(whisper_config, "whisper_setup"); + + M5.Display.printf(">> Setup llm..\n"); + llm_work_id = module_llm.llm.setup(); + + M5.Display.printf(">> Setup ok\n>> Say \"%s\" to wakeup\n", wake_up_keyword.c_str()); +} + +void loop() +{ + /* Update ModuleLLM */ + module_llm.update(); + + /* Handle module response messages */ + for (auto& msg : module_llm.msg.responseMsgList) { + /* If KWS module message */ + if (msg.work_id == kws_work_id) { + M5.Display.setTextColor(TFT_GREENYELLOW); + M5.Display.printf(">> Keyword detected\n"); + } + + /* If ASR module message */ + if (msg.work_id == whisper_work_id) { + /* Check message object type */ + if (msg.object == "asr.utf-8") { + /* Parse message json and get ASR result */ + JsonDocument doc; + deserializeJson(doc, msg.raw_msg); + String asr_result = doc["data"].as(); + + M5.Display.setTextColor(TFT_YELLOW); + M5.Display.printf(">> %s\n", asr_result.c_str()); + module_llm.llm.inferenceAndWaitResult(llm_work_id, asr_result.c_str(), [](String& result) { + /* Show result on screen */ + M5.Display.printf("%s", result.c_str()); + }); + } + } + } + + /* Clear handled messages */ + module_llm.msg.responseMsgList.clear(); +} \ No newline at end of file diff --git a/examples/KWS_VAD_Whisper_LLM_TTS/KWS_VAD_Whisper_LLM_TTS.ino b/examples/KWS_VAD_Whisper_LLM_TTS/KWS_VAD_Whisper_LLM_TTS.ino new file mode 100644 index 0000000..6f6f033 --- /dev/null +++ b/examples/KWS_VAD_Whisper_LLM_TTS/KWS_VAD_Whisper_LLM_TTS.ino @@ -0,0 +1,136 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include +#include +#include + +M5ModuleLLM module_llm; + +/* Must be capitalized */ +String wake_up_keyword = "HELLO"; +// String wake_up_keyword = "你好你好"; +String kws_work_id; +String vad_work_id; +String whisper_work_id; +String llm_work_id; +String melotts_work_id; +String language; + +void setup() +{ + M5.begin(); + M5.Display.setTextSize(2); + M5.Display.setTextScroll(true); + // M5.Display.setFont(&fonts::efontCN_12); // Support Chinese display + // M5.Display.setFont(&fonts::efontJA_12); // Support Japanese display + + language = "en_US"; + // language = "zh_CN"; + + /* Init module serial port */ + // int rxd = 16, txd = 17; // Basic + // int rxd = 13, txd = 14; // Core2 + // int rxd = 18, txd = 17; // CoreS3 + int rxd = M5.getPin(m5::pin_name_t::port_c_rxd); + int txd = M5.getPin(m5::pin_name_t::port_c_txd); + Serial2.begin(115200, SERIAL_8N1, rxd, txd); + + /* Init module */ + module_llm.begin(&Serial2); + + /* Make sure module is connected */ + M5.Display.printf(">> Check ModuleLLM connection..\n"); + while (1) { + if (module_llm.checkConnection()) { + break; + } + } + + /* Reset ModuleLLM */ + M5.Display.printf(">> Reset ModuleLLM..\n"); + module_llm.sys.reset(); + + /* Setup Audio module */ + M5.Display.printf(">> Setup audio..\n"); + module_llm.audio.setup(); + + /* Setup KWS module and save returned work id */ + M5.Display.printf(">> Setup kws..\n"); + m5_module_llm::ApiKwsSetupConfig_t kws_config; + kws_config.kws = wake_up_keyword; + kws_work_id = module_llm.kws.setup(kws_config, "kws_setup", language); + + /* Setup VAD module and save returned work id */ + M5.Display.printf(">> Setup vad..\n"); + m5_module_llm::ApiVadSetupConfig_t vad_config; + vad_config.input = {"sys.pcm", kws_work_id}; + vad_work_id = module_llm.vad.setup(vad_config, "vad_setup"); + + /* Setup Whisper module and save returned work id */ + M5.Display.printf(">> Setup whisper..\n"); + m5_module_llm::ApiWhisperSetupConfig_t whisper_config; + whisper_config.input = {"sys.pcm", kws_work_id, vad_work_id}; + whisper_config.language = "en"; + // whisper_config.language = "zh"; + // whisper_config.language = "ja"; + whisper_work_id = module_llm.whisper.setup(whisper_config, "whisper_setup"); + + M5.Display.printf(">> Setup llm..\n"); + llm_work_id = module_llm.llm.setup(); + + M5.Display.printf(">> Setup melotts..\n\n"); + m5_module_llm::ApiMelottsSetupConfig_t melotts_config; + melotts_config.input = {"tts.utf-8.stream", llm_work_id}; + melotts_work_id = module_llm.melotts.setup(melotts_config, "melotts_setup", language); + + M5.Display.printf(">> Setup ok\n>> Say \"%s\" to wakeup\n", wake_up_keyword.c_str()); +} + +void loop() +{ + /* Update ModuleLLM */ + module_llm.update(); + + /* Handle module response messages */ + for (auto& msg : module_llm.msg.responseMsgList) { + /* If KWS module message */ + if (msg.work_id == kws_work_id) { + M5.Display.setTextColor(TFT_GREENYELLOW); + M5.Display.printf(">> Keyword detected\n"); + } + + if (msg.work_id == vad_work_id) { + M5.Display.setTextColor(TFT_GREENYELLOW); + M5.Display.printf(">> vad detected\n"); + } + /* If ASR module message */ + if (msg.work_id == whisper_work_id) { + /* Check message object type */ + if (msg.object == "asr.utf-8") { + /* Parse message json and get ASR result */ + JsonDocument doc; + deserializeJson(doc, msg.raw_msg); + String asr_result = doc["data"].as(); + + M5.Display.setTextColor(TFT_YELLOW); + M5.Display.printf(">> %s\n", asr_result.c_str()); + + module_llm.llm.inferenceAndWaitResult(llm_work_id, asr_result.c_str(), [](String& result) { + /* Show result on screen */ + handleLLMResult(result); + }); + } + } + } + + /* Clear handled messages */ + module_llm.msg.responseMsgList.clear(); +} + +void handleLLMResult(String& result) +{ + M5.Display.printf("%s", result.c_str()); +} diff --git a/examples/MeloTTS/MeloTTS.ino b/examples/MeloTTS/MeloTTS.ino new file mode 100644 index 0000000..b63a685 --- /dev/null +++ b/examples/MeloTTS/MeloTTS.ino @@ -0,0 +1,74 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include +#include +#include + +M5ModuleLLM module_llm; +String melotts_work_id; +String language; + +void setup() +{ + M5.begin(); + M5.Display.setTextSize(2); + M5.Display.setTextScroll(true); + // M5.Display.setFont(&fonts::efontCN_12); // Support Chinese display + // M5.Display.setFont(&fonts::efontJA_12); // Support Japanese display + + language = "en_US"; + // language = "zh_CN"; + // language = "ja_JP"; + + /* Init module serial port */ + // int rxd = 16, txd = 17; // Basic + // int rxd = 13, txd = 14; // Core2 + // int rxd = 18, txd = 17; // CoreS3 + int rxd = M5.getPin(m5::pin_name_t::port_c_rxd); + int txd = M5.getPin(m5::pin_name_t::port_c_txd); + Serial2.begin(115200, SERIAL_8N1, rxd, txd); + + /* Init module */ + module_llm.begin(&Serial2); + + /* Make sure module is connected */ + M5.Display.printf(">> Check ModuleLLM connection..\n"); + while (1) { + if (module_llm.checkConnection()) { + break; + } + } + + /* Reset ModuleLLM */ + M5.Display.printf(">> Reset ModuleLLM..\n"); + module_llm.sys.reset(); + + /* Setup Audio module */ + M5.Display.printf(">> Setup audio..\n"); + module_llm.audio.setup(); + + /* Setup MeloTTS module and save returned work id */ + M5.Display.printf(">> Setup melotts..\n\n"); + m5_module_llm::ApiMelottsSetupConfig_t melotts_config; + melotts_work_id = module_llm.melotts.setup(melotts_config, "melotts_setup", language); +} + +void loop() +{ + /* Make a text for speech: {i} plus {i} equals to {i + i} */ + static int i = 0; + i++; + std::string text = std::to_string(i) + " plus " + std::to_string(i) + " equals " + std::to_string(i + i) + "."; + // std::string text = std::to_string(i) + " 加 " + std::to_string(i) + " 等于 " + std::to_string(i + i) + "."; + + M5.Display.setTextColor(TFT_GREEN); + M5.Display.printf("<< %s\n\n", text.c_str()); + + /* Push text to TTS module and wait inference result */ + module_llm.tts.inference(melotts_work_id, text.c_str(), 10000); + + delay(500); +} \ No newline at end of file diff --git a/examples/VLM/VLM.ino b/examples/VLM/VLM.ino new file mode 100644 index 0000000..566f5b9 --- /dev/null +++ b/examples/VLM/VLM.ino @@ -0,0 +1,98 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include +#include +#include + +#include "M5CoreS3.h" + +M5ModuleLLM module_llm; +String vlm_work_id; + +void setup() +{ + M5.begin(); + M5.Display.setTextSize(2); + M5.Display.setTextScroll(true); + + /* Init M5CoreS3 Camera */ + CoreS3.Camera.begin(); + CoreS3.Camera.sensor->set_framesize(CoreS3.Camera.sensor, FRAMESIZE_QVGA); + + /* Init module serial port */ + int rxd = M5.getPin(m5::pin_name_t::port_c_rxd); + int txd = M5.getPin(m5::pin_name_t::port_c_txd); + Serial2.begin(115200, SERIAL_8N1, rxd, txd); + /* Init module */ + module_llm.begin(&Serial2); + + /* Make sure module is connected */ + M5.Display.printf(">> Check ModuleLLM connection..\n"); + while (1) { + if (module_llm.checkConnection()) { + break; + } + } + + /* Reset ModuleLLM */ + M5.Display.printf(">> Reset ModuleLLM..\n"); + module_llm.sys.reset(); + + /* Setup VLM module and save returned work id */ + M5.Display.printf(">> Setup vlm..\n"); + vlm_work_id = module_llm.vlm.setup(); +} + +void loop() +{ + String question = "Describe the content of the image"; + + M5.update(); + + auto count = M5.Touch.getCount(); + + static m5::touch_state_t prev_state; + auto t = M5.Touch.getDetail(); + + static int vlm_inference; + + if (t.wasClicked()) { + static unsigned long lastClickTime = 0; + unsigned long currentMillis = millis(); + if (currentMillis - lastClickTime < 800) { + vlm_inference = 2; + } + lastClickTime = currentMillis; + } + + if (t.wasFlicked()) { + vlm_inference--; + } + + if (CoreS3.Camera.get()) { + if (vlm_inference == 2) { + uint8_t* out_jpg = NULL; + size_t out_jpg_len = 0; + frame2jpg(CoreS3.Camera.fb, 50, &out_jpg, &out_jpg_len); + module_llm.vlm.inference(vlm_work_id, out_jpg, out_jpg_len); + free(out_jpg); + delay(10); + M5.Lcd.setCursor(0, 0); + /* Push question to LLM module and wait inference result */ + module_llm.vlm.inferenceAndWaitResult(vlm_work_id, question.c_str(), [](String& result) { + /* Show result on screen */ + M5.Display.printf("%s", result.c_str()); + }); + vlm_inference--; + } else if (vlm_inference == 1) { + delay(10); + } else { + CoreS3.Display.pushImage(0, 0, CoreS3.Display.width(), CoreS3.Display.height(), + (uint16_t*)CoreS3.Camera.fb->buf); + } + CoreS3.Camera.free(); + } +} \ No newline at end of file diff --git a/examples/YOLO/YOLO_CoreS3.ino b/examples/YOLO/YOLO_CoreS3.ino new file mode 100644 index 0000000..7214353 --- /dev/null +++ b/examples/YOLO/YOLO_CoreS3.ino @@ -0,0 +1,114 @@ +/* + * SPDX-FileCopyrightText: 2024 M5Stack Technology CO LTD + * + * SPDX-License-Identifier: MIT + */ +#include +#include +#include +#include + +#include "M5CoreS3.h" + +M5ModuleLLM module_llm; +String yolo_work_id; + +struct DetectionResult { + String class_name; + float confidence; + int x1; + int y1; + int x2; + int y2; +}; + +M5Canvas canvas(&M5.Display); + +void setup() +{ + M5.begin(); + M5.Display.setTextSize(2); + M5.Display.setTextScroll(true); + + canvas.createSprite(M5.Display.width(), M5.Display.height()); + + /* Init M5CoreS3 Camera */ + CoreS3.Camera.begin(); + CoreS3.Camera.sensor->set_framesize(CoreS3.Camera.sensor, FRAMESIZE_QVGA); + + /* Init module serial port */ + int rxd = M5.getPin(m5::pin_name_t::port_c_rxd); + int txd = M5.getPin(m5::pin_name_t::port_c_txd); + Serial2.begin(115200, SERIAL_8N1, rxd, txd); + /* Init module */ + module_llm.begin(&Serial2); + + /* Make sure module is connected */ + M5.Display.printf(">> Check ModuleLLM connection..\n"); + while (1) { + if (module_llm.checkConnection()) { + break; + } + } + + /* Reset ModuleLLM */ + M5.Display.printf(">> Reset ModuleLLM..\n"); + module_llm.sys.reset(); + + /* Set ModuleLLM baud rate */ + M5.Display.printf(">> ModuleLLM connected, set baud rate to 1500000\n"); + module_llm.setBaudRate(1500000); + + Serial2.begin(1500000, SERIAL_8N1, rxd, txd); + module_llm.begin(&Serial2); + + /* Setup YOLO module and save returned work id */ + M5.Display.printf(">> Setup yolo..\n"); + yolo_work_id = module_llm.yolo.setup(); + canvas.setFont(&fonts::FreeSerifBold12pt7b); +} + +DetectionResult parseDetection(String& jsonStr) +{ + DetectionResult detection; + JsonDocument doc; + deserializeJson(doc, jsonStr); + JsonObject obj = doc.as(); + USBSerial.println(jsonStr); + if (obj["bbox"].is() && obj["class"].is() && obj["confidence"].is()) { + detection.class_name = obj["class"].as(); + detection.confidence = atof(obj["confidence"].as()); + JsonArray bbox = obj["bbox"].as(); + if (bbox.size() == 4) { + detection.x1 = (int)atof(bbox[0].as()); + detection.y1 = (int)atof(bbox[1].as()); + detection.x2 = (int)atof(bbox[2].as()); + detection.y2 = (int)atof(bbox[3].as()); + } + } + return detection; +} + +void loop() +{ + if (CoreS3.Camera.get()) { + uint8_t* out_jpg = NULL; + size_t out_jpg_len = 0; + frame2jpg(CoreS3.Camera.fb, 50, &out_jpg, &out_jpg_len); + canvas.pushImage(0, 0, CoreS3.Display.width(), CoreS3.Display.height(), (uint16_t*)CoreS3.Camera.fb->buf); + module_llm.yolo.inferenceAndWaitResult( + yolo_work_id, out_jpg, out_jpg_len, + [](String& result) { + DetectionResult detection = parseDetection(result); + int y1_pos = detection.y1 - 40; + if (y1_pos < 24) y1_pos = 24; + String combinedResult = detection.class_name + " " + String(detection.confidence, 2); + canvas.drawString(combinedResult, detection.x1, y1_pos); + canvas.drawRect(detection.x1, detection.y1 - 40, detection.x2, detection.y2 - 40, ORANGE); + }, + 10); + canvas.pushSprite(0, 0); + free(out_jpg); + } + CoreS3.Camera.free(); +} \ No newline at end of file diff --git a/examples/YOLO/YOLO.ino b/examples/YOLO/YOLO_UVC.ino similarity index 90% rename from examples/YOLO/YOLO.ino rename to examples/YOLO/YOLO_UVC.ino index 408fd28..978a553 100644 --- a/examples/YOLO/YOLO.ino +++ b/examples/YOLO/YOLO_UVC.ino @@ -91,13 +91,12 @@ void loop() /* Parse message json and get YOLO result */ JsonDocument doc; deserializeJson(doc, msg.raw_msg); - JsonArray delta = doc["data"]["delta"].as(); + JsonObject delta = doc["data"]["delta"].as(); - if (delta.size() > 0) { - JsonObject result = delta[0].as(); - String class_name = result["class"].as(); - float confidence = result["confidence"].as(); - JsonArray bboxArray = result["bbox"].as(); + if (delta.containsKey("bbox") && delta.containsKey("class") && delta.containsKey("confidence")) { + String class_name = delta["class"].as(); + float confidence = delta["confidence"].as(); + JsonArray bboxArray = delta["bbox"].as(); if (bboxArray.size() == 4) { int x1 = bboxArray[0].as(); @@ -125,5 +124,5 @@ void loop() module_llm.msg.clearMsg("yolo_setup"); module_llm.msg.responseMsgList.clear(); - usleep(500000); + // usleep(500000); } \ No newline at end of file diff --git a/src/M5ModuleLLM.cpp b/src/M5ModuleLLM.cpp index d4c63f1..557c808 100644 --- a/src/M5ModuleLLM.cpp +++ b/src/M5ModuleLLM.cpp @@ -29,10 +29,15 @@ bool M5ModuleLLM::begin(Stream* serialPort) bool M5ModuleLLM::checkConnection() { const bool result = (sys.ping() == MODULE_LLM_OK); - llm_version = (sys.version() == MODULE_LLM_OK); + llm_version = sys.version(); return result; } +bool M5ModuleLLM::setBaudRate(uint32_t baudRate) +{ + return sys.setBaudRate(baudRate) == MODULE_LLM_OK; +} + void M5ModuleLLM::update() { msg.update(); diff --git a/src/M5ModuleLLM.h b/src/M5ModuleLLM.h index a1846fb..39f42ac 100644 --- a/src/M5ModuleLLM.h +++ b/src/M5ModuleLLM.h @@ -41,6 +41,14 @@ class M5ModuleLLM { */ bool checkConnection(); + /** + * @brief Set module baud rate + * + * @return true + * @return false + */ + bool setBaudRate(uint32_t baudRate); + /** * @brief Update module * diff --git a/src/api/api_asr.cpp b/src/api/api_asr.cpp index 2766fa6..2b3d328 100644 --- a/src/api/api_asr.cpp +++ b/src/api/api_asr.cpp @@ -29,7 +29,7 @@ String ApiAsr::setup(ApiAsrSetupConfig_t config, String request_id, String langu doc["data"]["rule1"] = config.rule1; doc["data"]["rule2"] = config.rule2; doc["data"]["rule3"] = config.rule3; - if (!llm_version) { + if (llm_version == "v1.0") { doc["data"]["input"] = config.input[0]; } else { JsonArray inputArray = doc["data"]["input"].to(); diff --git a/src/api/api_depth_anything.h b/src/api/api_depth_anything.h index 36cb450..06f0b99 100644 --- a/src/api/api_depth_anything.h +++ b/src/api/api_depth_anything.h @@ -39,11 +39,11 @@ class ApiDepthAnything { String exit(String work_id, String request_id = "depth_anything_exit"); /** - * @brief Inference input data by module LLM + * @brief Inference image data by module LLM * - * @param raw_len * @param work_id * @param input + * @param raw_len * @param request_id * @return int */ diff --git a/src/api/api_kws.cpp b/src/api/api_kws.cpp index a4c3103..1fab882 100644 --- a/src/api/api_kws.cpp +++ b/src/api/api_kws.cpp @@ -26,7 +26,7 @@ String ApiKws::setup(ApiKwsSetupConfig_t config, String request_id, String langu doc["data"]["response_format"] = config.response_format; doc["data"]["enoutput"] = config.enoutput; doc["data"]["kws"] = config.kws; - if (!llm_version) { + if (llm_version == "v1.0") { doc["data"]["input"] = config.input[0]; } else { JsonArray inputArray = doc["data"]["input"].to(); diff --git a/src/api/api_llm.cpp b/src/api/api_llm.cpp index 44ca3aa..d88f497 100644 --- a/src/api/api_llm.cpp +++ b/src/api/api_llm.cpp @@ -28,7 +28,7 @@ String ApiLlm::setup(ApiLlmSetupConfig_t config, String request_id) doc["data"]["enkws"] = config.enkws; doc["data"]["max_token_len"] = config.max_token_len; doc["data"]["prompt"] = config.prompt; - if (!llm_version) { + if (llm_version == "v1.0") { doc["data"]["model"] = "qwen2.5-0.5b"; doc["data"]["input"] = config.input[0]; } else { diff --git a/src/api/api_melotts.cpp b/src/api/api_melotts.cpp index adcadbf..17601c8 100644 --- a/src/api/api_melotts.cpp +++ b/src/api/api_melotts.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: MIT */ #include "api_melotts.h" +#include "api_version.h" using namespace m5_module_llm; @@ -27,7 +28,20 @@ String ApiMelotts::setup(ApiMelottsSetupConfig_t config, String request_id, Stri for (const String& str : config.input) { inputArray.add(str); } - if (language == "zh_CN") doc["data"]["model"] = "melotts_zh-cn"; + float version = llm_version.substring(1).toFloat(); + if (version >= 1.6) { + if (language == "zh_CN") { + doc["data"]["model"] = "melotts-zh-cn"; + } else if (language == "ja_JP") { + doc["data"]["model"] = "melotts-ja-jp"; + } else { + doc["data"]["model"] = "melotts-en-default"; + } + } else { + if (language == "zh_CN") { + doc["data"]["model"] = "melotts_zh-cn"; + } + } doc["data"]["enoutput"] = config.enoutput; doc["data"]["enaudio"] = config.enaudio; serializeJson(doc, cmd); diff --git a/src/api/api_sys.cpp b/src/api/api_sys.cpp index 2baaa68..e6adc3f 100644 --- a/src/api/api_sys.cpp +++ b/src/api/api_sys.cpp @@ -29,12 +29,22 @@ int ApiSys::ping() return ret; } -int ApiSys::version() +String ApiSys::version() { + String version_str; int ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; _module_msg->sendCmdAndWaitToTakeMsg( - _cmd_version, "sys_version", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 2000); - return ret; + _cmd_version, "sys_version", + [&version_str, &ret](ResponseMsg_t& msg) { + ret = msg.error.code; + if (ret == MODULE_LLM_OK) { + JsonDocument doc; + deserializeJson(doc, msg.raw_msg); + version_str = doc["data"].as(); + } + }, + 2000); + return version_str; } int ApiSys::reset(bool waitResetFinish) @@ -65,3 +75,17 @@ int ApiSys::reboot() _cmd_reboot, "sys_reboot", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 2000); return ret; } + +int ApiSys::setBaudRate(uint32_t baudRate) +{ + int ret = MODULE_LLM_WAIT_RESPONSE_TIMEOUT; + String cmd = + "{\"request_id\":\"1\",\"work_id\":\"sys\",\"action\":\"uartsetup\",\"object\":\"sys.uartsetup\",\"data\":{" + "\"baud\":"; + cmd += baudRate; + cmd += ",\"data_bits\":8,\"stop_bits\":1,\"parity\":\"n\"}}"; + + _module_msg->sendCmdAndWaitToTakeMsg( + cmd.c_str(), "sys_set_baudrate", [&ret](ResponseMsg_t& msg) { ret = msg.error.code; }, 2000); + return ret; +} \ No newline at end of file diff --git a/src/api/api_sys.h b/src/api/api_sys.h index 796038b..2541854 100644 --- a/src/api/api_sys.h +++ b/src/api/api_sys.h @@ -27,13 +27,13 @@ class ApiSys { * @return int */ - int version(); + String version(); /** * @brief Check version * * @param waitCheckFinish - * @return int + * @return string */ int reset(bool waitResetFinish = true); @@ -45,6 +45,14 @@ class ApiSys { */ int reboot(); + /** + * @brief Set module baud rate + * + * @return true + * @return false + */ + int setBaudRate(uint32_t baudRate); + private: ModuleMsg* _module_msg = nullptr; }; diff --git a/src/api/api_tts.cpp b/src/api/api_tts.cpp index 7aee3c8..84823da 100644 --- a/src/api/api_tts.cpp +++ b/src/api/api_tts.cpp @@ -27,7 +27,7 @@ String ApiTts::setup(ApiTtsSetupConfig_t config, String request_id, String langu doc["data"]["enoutput"] = config.enoutput; doc["data"]["enkws"] = config.enkws; doc["data"]["enaudio"] = config.enaudio; - if (!llm_version) { + if (llm_version == "v1.0") { doc["data"]["response_format"] = "tts.base64.wav"; doc["data"]["input"] = config.input[0]; doc["data"]["enoutput"] = true; @@ -37,7 +37,18 @@ String ApiTts::setup(ApiTtsSetupConfig_t config, String request_id, String langu inputArray.add(str); } } - if (language == "zh_CN") doc["data"]["model"] = "single_speaker_fast"; + float version = llm_version.substring(1).toFloat(); + if (version >= 1.6) { + if (language == "zh_CN") { + doc["data"]["model"] = "single-speaker-fast"; + } else { + doc["data"]["model"] = "single-speaker-english-fast"; + } + } else { + if (language == "zh_CN") { + doc["data"]["model"] = "single_speaker_fast"; + } + } serializeJson(doc, cmd); } diff --git a/src/api/api_version.cpp b/src/api/api_version.cpp index 8cf6244..ed19204 100644 --- a/src/api/api_version.cpp +++ b/src/api/api_version.cpp @@ -5,4 +5,4 @@ */ #include "api_version.h" -int llm_version = 0; +String llm_version = "v1.3"; diff --git a/src/api/api_version.h b/src/api/api_version.h index 64f5758..b431ad8 100644 --- a/src/api/api_version.h +++ b/src/api/api_version.h @@ -4,5 +4,6 @@ * SPDX-License-Identifier: MIT */ #pragma once +#include "../utils/msg.h" -extern int llm_version; +extern String llm_version; diff --git a/src/api/api_vlm.cpp b/src/api/api_vlm.cpp index 651b986..c021150 100644 --- a/src/api/api_vlm.cpp +++ b/src/api/api_vlm.cpp @@ -28,7 +28,7 @@ String ApiVlm::setup(ApiVlmSetupConfig_t config, String request_id) doc["data"]["enkws"] = config.enkws; doc["data"]["max_token_len"] = config.max_token_len; doc["data"]["prompt"] = config.prompt; - if (!llm_version) { + if (llm_version == "v1.0") { doc["data"]["model"] = "qwen2.5-0.5b"; doc["data"]["input"] = config.input[0]; } else { @@ -37,6 +37,8 @@ String ApiVlm::setup(ApiVlmSetupConfig_t config, String request_id) inputArray.add(str); } } + float version = llm_version.substring(1).toFloat(); + if (version >= 1.6) doc["data"]["model"] = "internvl2.5-1B-364-ax630c"; serializeJson(doc, cmd); } @@ -47,7 +49,7 @@ String ApiVlm::setup(ApiVlmSetupConfig_t config, String request_id) // Copy work id llm_work_id = msg.work_id; }, - 20000); + 30000); return llm_work_id; } @@ -91,6 +93,24 @@ int ApiVlm::inference(String work_id, String input, String request_id) return MODULE_LLM_OK; } +int ApiVlm::inference(String& work_id, uint8_t* input, size_t& raw_len, String request_id) +{ + String cmd; + { + JsonDocument doc; + doc["RAW"] = raw_len; + doc["request_id"] = request_id; + doc["work_id"] = work_id; + doc["action"] = "inference"; + doc["object"] = "cv.jpeg.base64"; + serializeJson(doc, cmd); + } + + _module_msg->sendCmd(cmd.c_str()); + _module_msg->sendRaw(input, raw_len); + return MODULE_LLM_OK; +} + int ApiVlm::inferenceAndWaitResult(String work_id, String input, std::function onResult, uint32_t timeout, String request_id) { diff --git a/src/api/api_vlm.h b/src/api/api_vlm.h index cad9e30..95d8932 100644 --- a/src/api/api_vlm.h +++ b/src/api/api_vlm.h @@ -52,6 +52,17 @@ class ApiVlm { */ int inference(String work_id, String input, String request_id = "vlm_inference"); + /** + * @brief Inference image data by module LLM + * + * @param work_id + * @param input + * @param raw_len + * @param request_id + * @return int + */ + int inference(String& work_id, uint8_t* input, size_t& raw_len, String request_id = "vlm_inference"); + /** * @brief Inference input data by module VLLM, and wait inference result * diff --git a/src/presets/voice_assistant.cpp b/src/presets/voice_assistant.cpp index 9ce852a..b774192 100644 --- a/src/presets/voice_assistant.cpp +++ b/src/presets/voice_assistant.cpp @@ -60,7 +60,7 @@ int M5ModuleLLM_VoiceAssistant::begin(String wakeUpKeyword, String prompt, Strin _debug("setup module tts.."); { - if (!llm_version) { + if (llm_version == "v1.0") { ApiTtsSetupConfig_t config; config.input = {_work_id.llm, _work_id.kws}; _work_id.tts = _m5_module_llm->tts.setup(config, "tts_setup", language); diff --git a/src/utils/comm.cpp b/src/utils/comm.cpp index 21d6c72..0f53c9a 100644 --- a/src/utils/comm.cpp +++ b/src/utils/comm.cpp @@ -36,44 +36,35 @@ void ModuleComm::sendRaw(const uint8_t* data, size_t& raw_len) ModuleComm::Respond_t ModuleComm::getResponse(uint32_t timeout) { Respond_t ret; + ret.time_out = false; String buffer; + uint32_t startTime = millis(); + int openBraces = 0; + bool started = false; - uint32_t time_out_count = millis(); - bool get_msg = false; - uint32_t get_msg_count = millis(); - while (1) { - // Check input - if (_serial->available()) { - get_msg = true; - while (_serial->available()) { - char c = (char)_serial->read(); - buffer += c; - - if (c == '\n') { + while (millis() - startTime < timeout) { + while (_serial->available()) { + char c = (char)_serial->read(); + buffer += c; + if (c == '{') { + started = true; + openBraces++; + } else if (c == '}') { + openBraces--; + if (started && openBraces == 0) { ret.msg = buffer; return ret; } } - get_msg_count = millis(); - time_out_count = millis(); - } - // Check package finish, if more than 50ms no input, treat it as a package - else if (get_msg) { - if (millis() - get_msg_count > 50) { - break; + if (c == '\n' && !started) { + ret.msg = buffer; + return ret; } } - - // Check timeout - if (millis() - time_out_count > timeout) { - ret.time_out = true; - break; - } - - delay(5); + delay(1); } - + ret.time_out = true; return ret; } diff --git a/src/utils/msg.cpp b/src/utils/msg.cpp index 6eb3f87..390cc39 100644 --- a/src/utils/msg.cpp +++ b/src/utils/msg.cpp @@ -14,7 +14,6 @@ void ModuleMsg::init(ModuleComm* ModuleMsg) void ModuleMsg::update() { - // 拉取串口响应 auto reponse = _module_comm->getResponse(50); if (reponse.time_out) { return; @@ -24,13 +23,11 @@ void ModuleMsg::update() void ModuleMsg::addMsgFromResponse(const char* response) { - // 尝试解析 JsonDocument doc; if (deserializeJson(doc, response) != DeserializationError::Ok) { return; } - // 压进消息列表 ResponseMsg_t new_msg; new_msg.raw_msg = response; new_msg.request_id = doc["request_id"].as(); @@ -39,7 +36,6 @@ void ModuleMsg::addMsgFromResponse(const char* response) new_msg.error.code = doc["error"]["code"]; new_msg.error.message = doc["error"]["message"].as(); responseMsgList.push_back(new_msg); - // printf("[\033[1;34mdebug\033[0m] get msg:\n%s\n", response); } void ModuleMsg::clearMsg(String request_id) @@ -56,17 +52,12 @@ void ModuleMsg::clearMsg(String request_id) bool ModuleMsg::takeMsg(String request_id, std::function onMsg) { bool ret = false; - // 遍历消息列表 for (auto iter = responseMsgList.begin(); iter != responseMsgList.end();) { - // 匹对 id - // printf("%s %s %d\n", workId.c_str(), iter->workId.c_str(), iter->workId == workId); if (iter->request_id == request_id) { ret = true; - // 触发回调 if (onMsg) { onMsg(*iter); } - // 移出列表 iter = responseMsgList.erase(iter); } else { iter++; @@ -77,24 +68,17 @@ bool ModuleMsg::takeMsg(String request_id, std::function onMsg, uint32_t timeout) { - uint32_t time_out_count = millis(); - bool is_time_out = false; - while (1) { - // 刷新消息拉取 + uint32_t startTime = millis(); + while (millis() - startTime < timeout) { update(); - // 匹配 if (takeMsg(request_id, onMsg)) { - break; + return true; } - // 超时判断 - if (millis() - time_out_count > timeout) { - is_time_out = true; - break; - } + delay(10); } - return !is_time_out; + return false; } bool ModuleMsg::sendCmdAndWaitToTakeMsg(const char* cmd, String request_id,