diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index c5f3f6b..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "java.configuration.updateBuildConfiguration": "interactive" -} \ No newline at end of file diff --git a/app/build.gradle b/app/build.gradle index 33e8a66..1de6cb8 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -5,7 +5,7 @@ plugins { android { namespace 'com.digitalperson' - compileSdk 33 + compileSdk 34 buildFeatures { buildConfig true @@ -24,6 +24,11 @@ android { } } + // 不压缩大文件,避免内存不足错误 + aaptOptions { + noCompress 'rknn', 'rkllm', 'onnx', 'model', 'bin', 'json' + } + defaultConfig { applicationId "com.digitalperson" minSdk 21 @@ -40,7 +45,7 @@ android { buildConfigField "String", "LLM_API_URL", "\"${(project.findProperty('LLM_API_URL') ?: 'https://ark.cn-beijing.volces.com/api/v3/chat/completions').toString()}\"" buildConfigField "String", "LLM_API_KEY", "\"${(project.findProperty('LLM_API_KEY') ?: '').toString()}\"" buildConfigField "String", "LLM_MODEL", "\"${(project.findProperty('LLM_MODEL') ?: 'doubao-1-5-pro-32k-character-250228').toString()}\"" - buildConfigField "boolean", "USE_LIVE2D", "${(project.findProperty('USE_LIVE2D') ?: 'false').toString()}" + buildConfigField "boolean", "USE_LIVE2D", "${(project.findProperty('USE_LIVE2D') ?: 'true').toString()}" ndk { abiFilters "arm64-v8a" @@ -63,6 +68,8 @@ android { } dependencies { + + implementation 'androidx.core:core-ktx:1.7.0' implementation 'androidx.appcompat:appcompat:1.6.1' implementation 'com.google.android.material:material:1.9.0' @@ -73,6 +80,15 @@ dependencies { androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' // ExoPlayer for video playback (used to show silent / speaking videos) implementation 'com.google.android.exoplayer:exoplayer:2.18.6' + implementation 'androidx.camera:camera-core:1.3.4' + implementation 'androidx.camera:camera-camera2:1.3.4' + implementation 
'androidx.camera:camera-lifecycle:1.3.4' + implementation 'androidx.camera:camera-view:1.3.4' implementation project(':framework') implementation files('../Live2DFramework/Core/android/Live2DCubismCore.aar') + + // Tencent Cloud TTS SDK + implementation files('libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar') + implementation 'com.google.code.gson:gson:2.8.9' + implementation 'com.squareup.okhttp3:okhttp:4.9.3' } diff --git a/app/libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar b/app/libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar new file mode 100644 index 0000000..59afdf8 Binary files /dev/null and b/app/libs/realtime_tts-release-v2.0.16-20260128-d80cafe.aar differ diff --git a/app/note/design_doc b/app/note/design_doc index b5cd071..f971813 100644 --- a/app/note/design_doc +++ b/app/note/design_doc @@ -71,6 +71,7 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime - haru_g_m25 - 扁嘴 - haru_g_m24 - 低头斜看地板,收手到背后 - haru_g_m05 扁嘴,张开双手 +- haru_g_m16 双手捧腮,思考 ### 😠 愤怒类情绪 - haru_g_m11 双手交叉,摇头,扁嘴 @@ -90,8 +91,6 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime - haru_g_m12 摆手,摇头 ### 😕 困惑类情绪 -- haru_g_m20 手指点腮,思考,皱眉 -- haru_g_m16 双手捧腮,思考 - haru_g_m14 身体前倾,皱眉 - haru_g_m13 身体前倾,双手分开 @@ -99,6 +98,22 @@ TTS Sherpa-ONNX VITS .onnx ❌ 否 CPU ONNX Runtime - haru_g_m19 脸红微笑 +### 担心 +- haru_g_m20 手指点腮,思考,皱眉 + ### ❤️ 关心类情绪 - haru_g_m17 靠近侧脸 +6. 其实可以抄一下讯飞的超脑平台的功能: +https://aiui-doc.xf-yun.com/project-2/doc-397/ + +7. 人脸检测使用的是RKNN zoo里的 retinaface模型,转成了rknn格式,并且使用了wider_face的数据集(验证集)进行了校准,下载地址: +https://www.modelscope.cn/datasets/shaoxuan/WIDER_FACE/files + +8. 人脸识别模型是insightface的r18模型,转成了rknn格式,并且使用了 lfw 的数据集进行了校准,下载地址: +https://tianchi.aliyun.com/dataset/93864 + +9. 
+ + + diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index a7fa46d..1b87526 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -2,13 +2,17 @@ + + + + android:theme="@style/Theme.DigitalPerson" + android:usesCleartextTraffic="true"> +#include +#include +#include + +#define LOG_TAG "ArcFaceRKNN" +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) +#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__) +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) + +ArcFaceEngineRKNN::ArcFaceEngineRKNN() = default; + +ArcFaceEngineRKNN::~ArcFaceEngineRKNN() { + release(); +} + +int ArcFaceEngineRKNN::init(const char* modelPath) { + release(); + int ret = rknn_init(&ctx_, (void*)modelPath, 0, 0, nullptr); + if (ret != RKNN_SUCC) { + LOGE("rknn_init failed: %d model=%s", ret, modelPath); + return ret; + } + + std::memset(&ioNum_, 0, sizeof(ioNum_)); + ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &ioNum_, sizeof(ioNum_)); + if (ret != RKNN_SUCC || ioNum_.n_input < 1 || ioNum_.n_output < 1) { + LOGE("query io num failed: ret=%d in=%u out=%u", ret, ioNum_.n_input, ioNum_.n_output); + release(); + return ret != RKNN_SUCC ? 
ret : -1; + } + + std::memset(&inputAttr_, 0, sizeof(inputAttr_)); + inputAttr_.index = 0; + ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &inputAttr_, sizeof(inputAttr_)); + if (ret != RKNN_SUCC) { + LOGE("query input attr failed: %d", ret); + release(); + return ret; + } + if (inputAttr_.n_dims == 4) { + if (inputAttr_.fmt == RKNN_TENSOR_NHWC) { + inputH_ = static_cast(inputAttr_.dims[1]); + inputW_ = static_cast(inputAttr_.dims[2]); + } else { + inputH_ = static_cast(inputAttr_.dims[2]); + inputW_ = static_cast(inputAttr_.dims[3]); + } + } else if (inputAttr_.n_dims == 3) { + inputH_ = static_cast(inputAttr_.dims[1]); + inputW_ = static_cast(inputAttr_.dims[2]); + } + + std::memset(&outputAttr_, 0, sizeof(outputAttr_)); + outputAttr_.index = 0; + ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &outputAttr_, sizeof(outputAttr_)); + if (ret != RKNN_SUCC) { + LOGW("query output attr failed: %d", ret); + } + + initialized_ = true; + LOGI("ArcFace initialized input=%dx%d", inputW_, inputH_); + return 0; +} + +std::vector ArcFaceEngineRKNN::extractEmbedding( + const uint32_t* argbPixels, + int width, + int height, + int strideBytes, + float left, + float top, + float right, + float bottom) { + LOGI("extractEmbedding called: initialized=%d, ctx=%p, pixels=%p, width=%d, height=%d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f", + initialized_, ctx_, argbPixels, width, height, left, top, right, bottom); + + std::vector empty; + if (!initialized_ || ctx_ == 0 || argbPixels == nullptr || width <= 0 || height <= 0) { + LOGW("extractEmbedding failed: invalid parameters"); + return empty; + } + + const float faceW = right - left; + const float faceH = bottom - top; + LOGI("Face size: width=%.2f, height=%.2f", faceW, faceH); + + if (faceW < 4.0f || faceH < 4.0f) { + LOGW("extractEmbedding failed: face too small"); + return empty; + } + + const float pad = 0.15f; + float x1f = std::max(0.0f, left - faceW * pad); + float y1f = std::max(0.0f, top - faceH * pad); + float x2f = 
std::min(static_cast(width), right + faceW * pad); + float y2f = std::min(static_cast(height), bottom + faceH * pad); + int x1 = static_cast(std::floor(x1f)); + int y1 = static_cast(std::floor(y1f)); + int x2 = static_cast(std::ceil(x2f)); + int y2 = static_cast(std::ceil(y2f)); + if (x2 <= x1 || y2 <= y1) { + return empty; + } + + const int cropW = x2 - x1; + const int cropH = y2 - y1; + const int srcStridePx = strideBytes / 4; + std::vector rgb(inputW_ * inputH_ * 3); + for (int y = 0; y < inputH_; ++y) { + const int sy = y1 + (y * cropH / inputH_); + const uint32_t* srcRow = argbPixels + sy * srcStridePx; + uint8_t* dst = rgb.data() + y * inputW_ * 3; + for (int x = 0; x < inputW_; ++x) { + const int sx = x1 + (x * cropW / inputW_); + const uint32_t pixel = srcRow[sx]; + dst[3 * x + 0] = (pixel >> 16) & 0xFF; + dst[3 * x + 1] = (pixel >> 8) & 0xFF; + dst[3 * x + 2] = pixel & 0xFF; + } + } + + rknn_input input{}; + input.index = 0; + input.type = RKNN_TENSOR_UINT8; + input.size = rgb.size(); + input.buf = rgb.data(); + input.pass_through = 0; + input.fmt = (inputAttr_.fmt == RKNN_TENSOR_NCHW) ? 
RKNN_TENSOR_NCHW : RKNN_TENSOR_NHWC; + + std::vector nchw; + if (input.fmt == RKNN_TENSOR_NCHW) { + nchw.resize(rgb.size()); + const int hw = inputW_ * inputH_; + for (int i = 0; i < hw; ++i) { + nchw[i] = rgb[3 * i + 0]; + nchw[hw + i] = rgb[3 * i + 1]; + nchw[2 * hw + i] = rgb[3 * i + 2]; + } + input.buf = nchw.data(); + } + + int ret = rknn_inputs_set(ctx_, 1, &input); + if (ret != RKNN_SUCC) { + LOGW("rknn_inputs_set failed: %d", ret); + return empty; + } + ret = rknn_run(ctx_, nullptr); + if (ret != RKNN_SUCC) { + LOGW("rknn_run failed: %d", ret); + return empty; + } + LOGI("rknn_run succeeded"); + + std::vector outputs(ioNum_.n_output); + for (uint32_t i = 0; i < ioNum_.n_output; ++i) { + std::memset(&outputs[i], 0, sizeof(rknn_output)); + outputs[i].want_float = 1; + } + ret = rknn_outputs_get(ctx_, ioNum_.n_output, outputs.data(), nullptr); + if (ret != RKNN_SUCC) { + LOGW("rknn_outputs_get failed: %d", ret); + return empty; + } + LOGI("rknn_outputs_get succeeded: n_output=%u", ioNum_.n_output); + + if (outputs[0].buf == nullptr) { + LOGW("Output buffer is null"); + rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data()); + return empty; + } + + size_t elems = outputAttr_.n_elems > 0 ? static_cast(outputAttr_.n_elems) : 0; + if (elems == 0) { + elems = 1; + for (uint32_t d = 0; d < outputAttr_.n_dims; ++d) { + if (outputAttr_.dims[d] > 0) { + elems *= static_cast(outputAttr_.dims[d]); + } + } + } + if (elems == 0) { + rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data()); + return empty; + } + + const float* ptr = reinterpret_cast(outputs[0].buf); + std::vector embedding(ptr, ptr + elems); + rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data()); + + // L2 normalize for cosine similarity. 
+ float sum = 0.0f; + for (float v : embedding) sum += v * v; + const float norm = std::sqrt(std::max(sum, 1e-12f)); + for (float& v : embedding) v /= norm; + return embedding; +} + +void ArcFaceEngineRKNN::release() { + if (ctx_ != 0) { + rknn_destroy(ctx_); + ctx_ = 0; + } + std::memset(&ioNum_, 0, sizeof(ioNum_)); + std::memset(&inputAttr_, 0, sizeof(inputAttr_)); + std::memset(&outputAttr_, 0, sizeof(outputAttr_)); + initialized_ = false; +} diff --git a/app/src/main/cpp/ArcFaceEngineRKNN.h b/app/src/main/cpp/ArcFaceEngineRKNN.h new file mode 100644 index 0000000..896986d --- /dev/null +++ b/app/src/main/cpp/ArcFaceEngineRKNN.h @@ -0,0 +1,36 @@ +#ifndef DIGITAL_PERSON_ARCFACE_ENGINE_RKNN_H +#define DIGITAL_PERSON_ARCFACE_ENGINE_RKNN_H + +#include +#include + +#include "rknn_api.h" + +class ArcFaceEngineRKNN { +public: + ArcFaceEngineRKNN(); + ~ArcFaceEngineRKNN(); + + int init(const char* modelPath); + std::vector extractEmbedding( + const uint32_t* argbPixels, + int width, + int height, + int strideBytes, + float left, + float top, + float right, + float bottom); + void release(); + +private: + rknn_context ctx_ = 0; + bool initialized_ = false; + rknn_input_output_num ioNum_{}; + rknn_tensor_attr inputAttr_{}; + rknn_tensor_attr outputAttr_{}; + int inputW_ = 112; + int inputH_ = 112; +}; + +#endif diff --git a/app/src/main/cpp/ArcFaceEngineRKNNJNI.cpp b/app/src/main/cpp/ArcFaceEngineRKNNJNI.cpp new file mode 100644 index 0000000..98f2d32 --- /dev/null +++ b/app/src/main/cpp/ArcFaceEngineRKNNJNI.cpp @@ -0,0 +1,109 @@ +#include +#include +#include + +#include "ArcFaceEngineRKNN.h" + +#define LOG_TAG "ArcFaceJNI" +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) +#define LOGE(...) 
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) + +extern "C" { + +JNIEXPORT jlong JNICALL +Java_com_digitalperson_engine_ArcFaceEngineRKNN_createEngineNative(JNIEnv* env, jobject thiz) { + auto* engine = new ArcFaceEngineRKNN(); + if (engine == nullptr) return 0; + return reinterpret_cast(engine); +} + +JNIEXPORT jint JNICALL +Java_com_digitalperson_engine_ArcFaceEngineRKNN_initNative( + JNIEnv* env, + jobject thiz, + jlong ptr, + jstring modelPath) { + auto* engine = reinterpret_cast(ptr); + if (engine == nullptr || modelPath == nullptr) return -1; + const char* model = env->GetStringUTFChars(modelPath, nullptr); + if (model == nullptr) return -1; + int ret = engine->init(model); + env->ReleaseStringUTFChars(modelPath, model); + return ret; +} + +JNIEXPORT jfloatArray JNICALL +Java_com_digitalperson_engine_ArcFaceEngineRKNN_extractEmbeddingNative( + JNIEnv* env, + jobject thiz, + jlong ptr, + jobject bitmapObj, + jfloat left, + jfloat top, + jfloat right, + jfloat bottom) { + LOGI("extractEmbeddingNative called: ptr=%ld, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f", ptr, left, top, right, bottom); + + auto* engine = reinterpret_cast(ptr); + if (engine == nullptr || bitmapObj == nullptr) { + LOGE("Engine or bitmap is null: engine=%p, bitmap=%p", engine, bitmapObj); + return env->NewFloatArray(0); + } + + AndroidBitmapInfo info{}; + if (AndroidBitmap_getInfo(env, bitmapObj, &info) < 0) { + LOGE("AndroidBitmap_getInfo failed"); + return env->NewFloatArray(0); + } + LOGI("Bitmap info: width=%d, height=%d, stride=%d, format=%d", info.width, info.height, info.stride, info.format); + + if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) { + LOGE("Unsupported bitmap format: %d", info.format); + return env->NewFloatArray(0); + } + + void* pixels = nullptr; + if (AndroidBitmap_lockPixels(env, bitmapObj, &pixels) < 0 || pixels == nullptr) { + LOGE("AndroidBitmap_lockPixels failed"); + return env->NewFloatArray(0); + } + LOGI("Bitmap pixels locked 
successfully"); + + std::vector emb = engine->extractEmbedding( + reinterpret_cast(pixels), + static_cast(info.width), + static_cast(info.height), + static_cast(info.stride), + left, + top, + right, + bottom); + + LOGI("Engine extractEmbedding returned: size=%zu", emb.size()); + + AndroidBitmap_unlockPixels(env, bitmapObj); + + jfloatArray out = env->NewFloatArray(static_cast(emb.size())); + if (out == nullptr) { + LOGE("Failed to create float array"); + return env->NewFloatArray(0); + } + if (!emb.empty()) { + env->SetFloatArrayRegion(out, 0, static_cast(emb.size()), emb.data()); + } + return out; +} + +JNIEXPORT void JNICALL +Java_com_digitalperson_engine_ArcFaceEngineRKNN_releaseNative( + JNIEnv* env, + jobject thiz, + jlong ptr) { + auto* engine = reinterpret_cast(ptr); + if (engine != nullptr) { + engine->release(); + delete engine; + } +} + +} // extern "C" diff --git a/app/src/main/cpp/CMakeLists.txt b/app/src/main/cpp/CMakeLists.txt index 17b5465..96c510f 100644 --- a/app/src/main/cpp/CMakeLists.txt +++ b/app/src/main/cpp/CMakeLists.txt @@ -14,6 +14,11 @@ if (ANDROID) set_target_properties(sentencepiece PROPERTIES IMPORTED_LOCATION ${JNI_LIBS_DIR}/libsentencepiece.so) + # 导入 rkllm 库 + add_library(rkllmrt SHARED IMPORTED) + set_target_properties(rkllmrt PROPERTIES IMPORTED_LOCATION + ${JNI_LIBS_DIR}/librkllmrt.so) + # Imported static libs add_library(kaldi_native_fbank STATIC IMPORTED) set_target_properties(kaldi_native_fbank PROPERTIES IMPORTED_LOCATION @@ -26,6 +31,12 @@ if (ANDROID) add_library(sensevoiceEngine SHARED SenseVoiceEngineRKNN.cpp SenseVoiceEngineRKNNJNI.cpp + RetinaFaceEngineRKNN.cpp + RetinaFaceEngineRKNNJNI.cpp + ArcFaceEngineRKNN.cpp + ArcFaceEngineRKNNJNI.cpp + RKLLMEngine.cpp + RKLLMEngineJNI.cpp utils/audio_utils.c ) @@ -40,9 +51,11 @@ if (ANDROID) target_link_libraries(sensevoiceEngine rknnrt + rkllmrt kaldi_native_fbank sndfile sentencepiece + jnigraphics log ) endif() diff --git a/app/src/main/cpp/RKLLMEngine.cpp 
b/app/src/main/cpp/RKLLMEngine.cpp new file mode 100644 index 0000000..9393a6a --- /dev/null +++ b/app/src/main/cpp/RKLLMEngine.cpp @@ -0,0 +1,118 @@ +#include +#include +#include "zipformer_headers/rkllm.h" + +#define TAG "RKLLMEngine" +#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__) +#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) + +namespace { +// Keep these conservative first; too-large values may fail during init on RK3588. +constexpr int kDefaultMaxContextLen = 1024; +constexpr int kDefaultMaxNewTokens = 128; +} + +struct LLmJniEnv { + JNIEnv *env; + jobject thiz; + jclass clazz; +}; + +void callbackToJava(const char *text, int state, LLmJniEnv *jenv) { + jmethodID method = jenv->env->GetMethodID(jenv->clazz, "callbackFromNative", "(Ljava/lang/String;I)V"); + jstring jText = text ? 
jenv->env->NewStringUTF(text) : jenv->env->NewStringUTF(""); + jenv->env->CallVoidMethod(jenv->thiz, method, jText, state); +} + +int callback(RKLLMResult *result, void *userdata, LLMCallState state) { + auto jenv = (LLmJniEnv *)userdata; + + if (state == RKLLM_RUN_FINISH) { + LOGI(""); + callbackToJava(nullptr, 0, jenv); + delete jenv; + } else if (state == RKLLM_RUN_ERROR) { + LOGE(""); + callbackToJava(nullptr, -1, jenv); + delete jenv; + } else if (state == RKLLM_RUN_NORMAL) { + //LOGD("NM: [%d] %s", result->token_id, result->text); + callbackToJava(result->text, 1, jenv); + } + return 0; // 返回 0 表示正常继续执行 +} + +// JNI 方法实现 +extern "C" { + jlong initLLM(JNIEnv *env, jobject thiz, jstring model_path) { + const char* modelPath = env->GetStringUTFChars(model_path, nullptr); + LLMHandle llmHandle = nullptr; + + //设置参数及初始化 + RKLLMParam param = rkllm_createDefaultParam(); + param.model_path = modelPath; + + //设置采样参数 + param.top_k = 1; + param.top_p = 0.95; + param.temperature = 0.8; + param.repeat_penalty = 1.1; + param.frequency_penalty = 0.0; + param.presence_penalty = 0.0; + + param.max_new_tokens = kDefaultMaxNewTokens; + param.max_context_len = kDefaultMaxContextLen; + param.skip_special_token = true; + param.extend_param.base_domain_id = 0; + + LOGI("rkllm init with module: %s", modelPath); + LOGI("rkllm init params: max_context_len=%d, max_new_tokens=%d, top_k=%d, top_p=%f, temp=%f", + param.max_context_len, param.max_new_tokens, param.top_k, param.top_p, param.temperature); + int ret = rkllm_init(&llmHandle, ¶m, callback); + if (ret == 0){ + LOGI("rkllm init success, handle=%p", llmHandle); + } else { + LOGE("rkllm init failed, ret=%d", ret); + llmHandle = nullptr; + } + + env->ReleaseStringUTFChars(model_path, modelPath); + + return (jlong)llmHandle; + } + + void deinitLLM(JNIEnv *env, jobject thiz, jlong handle) { + rkllm_destroy((LLMHandle)handle); + } + + void infer(JNIEnv *env, jobject thiz, jlong handle, jstring text) { + if (handle == 0) { + LOGE("rkllm 
infer called with null handle"); + jclass clazz = env->GetObjectClass(thiz); + jmethodID method = env->GetMethodID(clazz, "callbackFromNative", "(Ljava/lang/String;I)V"); + jstring jText = env->NewStringUTF("RKLLM handle is null"); + env->CallVoidMethod(thiz, method, jText, -1); + env->DeleteLocalRef(jText); + return; + } + + auto *jnienv = new LLmJniEnv { + .env = env, + .thiz = thiz, + .clazz = env->GetObjectClass(thiz), + }; + + RKLLMInput rkllm_input = {}; + RKLLMInferParam rkllm_infer_params = {}; + const char* sText = env->GetStringUTFChars(text, nullptr); + + rkllm_infer_params.mode = RKLLM_INFER_GENERATE; + rkllm_input.input_type = RKLLM_INPUT_PROMPT; + rkllm_input.prompt_input = (char *)sText; + + rkllm_run((LLMHandle)handle, &rkllm_input, &rkllm_infer_params, jnienv); + env->ReleaseStringUTFChars(text, sText); + } +} \ No newline at end of file diff --git a/app/src/main/cpp/RKLLMEngineJNI.cpp b/app/src/main/cpp/RKLLMEngineJNI.cpp new file mode 100644 index 0000000..e074529 --- /dev/null +++ b/app/src/main/cpp/RKLLMEngineJNI.cpp @@ -0,0 +1,37 @@ +#include + +// JNI 方法声明 +extern "C" { + JNIEXPORT jlong JNICALL + Java_com_digitalperson_llm_RKLLM_initLLM(JNIEnv *env, jobject thiz, jstring model_path); + + JNIEXPORT void JNICALL + Java_com_digitalperson_llm_RKLLM_deinitLLM(JNIEnv *env, jobject thiz, jlong handle); + + JNIEXPORT void JNICALL + Java_com_digitalperson_llm_RKLLM_infer(JNIEnv *env, jobject thiz, jlong handle, jstring text); +} + +// 方法实现 +extern "C" { + jlong initLLM(JNIEnv *env, jobject thiz, jstring model_path); + void deinitLLM(JNIEnv *env, jobject thiz, jlong handle); + void infer(JNIEnv *env, jobject thiz, jlong handle, jstring text); +} + +extern "C" { + JNIEXPORT jlong JNICALL + Java_com_digitalperson_llm_RKLLM_initLLM(JNIEnv *env, jobject thiz, jstring model_path) { + return initLLM(env, thiz, model_path); + } + + JNIEXPORT void JNICALL + Java_com_digitalperson_llm_RKLLM_deinitLLM(JNIEnv *env, jobject thiz, jlong handle) { + deinitLLM(env, 
thiz, handle); + } + + JNIEXPORT void JNICALL + Java_com_digitalperson_llm_RKLLM_infer(JNIEnv *env, jobject thiz, jlong handle, jstring text) { + infer(env, thiz, handle, text); + } +} \ No newline at end of file diff --git a/app/src/main/cpp/RetinaFaceEngineRKNN.cpp b/app/src/main/cpp/RetinaFaceEngineRKNN.cpp new file mode 100644 index 0000000..ce58f9b --- /dev/null +++ b/app/src/main/cpp/RetinaFaceEngineRKNN.cpp @@ -0,0 +1,435 @@ +#include "RetinaFaceEngineRKNN.h" + +#include +#include +#include +#include + +#define LOG_TAG "RetinaFaceRKNN" +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) +#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__) +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) + +namespace { +constexpr float kVariance0 = 0.1f; +constexpr float kVariance1 = 0.2f; +} // namespace + +RetinaFaceEngineRKNN::RetinaFaceEngineRKNN() = default; + +RetinaFaceEngineRKNN::~RetinaFaceEngineRKNN() { + release(); +} + +size_t RetinaFaceEngineRKNN::tensorElemCount(const rknn_tensor_attr& attr) { + if (attr.n_elems > 0) { + return static_cast(attr.n_elems); + } + if (attr.n_dims <= 0) { + return 0; + } + size_t n = 1; + for (uint32_t i = 0; i < attr.n_dims; ++i) { + if (attr.dims[i] == 0) continue; + n *= static_cast(attr.dims[i]); + } + return n; +} + +int RetinaFaceEngineRKNN::init( + const char* modelPath, + int inputSize, + float scoreThreshold, + float nmsThreshold) { + release(); + inputSize_ = inputSize; + scoreThreshold_ = scoreThreshold; + nmsThreshold_ = nmsThreshold; + + int ret = rknn_init(&ctx_, (void*)modelPath, 0, 0, nullptr); + if (ret != RKNN_SUCC) { + LOGE("rknn_init failed: ret=%d, model=%s", ret, modelPath); + return ret; + } + + std::memset(&ioNum_, 0, sizeof(ioNum_)); + ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &ioNum_, sizeof(ioNum_)); + if (ret != RKNN_SUCC) { + LOGE("rknn_query(RKNN_QUERY_IN_OUT_NUM) failed: %d", ret); + release(); + return ret; + } + if 
(ioNum_.n_input < 1 || ioNum_.n_output < 1) { + LOGE("invalid io num: input=%u output=%u", ioNum_.n_input, ioNum_.n_output); + release(); + return -1; + } + + std::memset(&inputAttr_, 0, sizeof(inputAttr_)); + inputAttr_.index = 0; + ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &inputAttr_, sizeof(inputAttr_)); + if (ret != RKNN_SUCC) { + LOGE("rknn_query input attr failed: %d", ret); + release(); + return ret; + } + + outputAttrs_.clear(); + outputAttrs_.resize(ioNum_.n_output); + for (uint32_t i = 0; i < ioNum_.n_output; ++i) { + std::memset(&outputAttrs_[i], 0, sizeof(rknn_tensor_attr)); + outputAttrs_[i].index = i; + int qret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &outputAttrs_[i], sizeof(rknn_tensor_attr)); + if (qret != RKNN_SUCC) { + LOGW("query output attr[%u] failed: %d", i, qret); + } + LOGI("output[%u] n_elems=%u n_dims=%u type=%d qnt=%d", + i, outputAttrs_[i].n_elems, outputAttrs_[i].n_dims, outputAttrs_[i].type, outputAttrs_[i].qnt_type); + for (uint32_t d = 0; d < outputAttrs_[i].n_dims; ++d) { + LOGI(" output[%u] dim[%u]=%u", i, d, outputAttrs_[i].dims[d]); + } + } + + initialized_ = true; + LOGI("RetinaFace initialized, input_size=%d, outputs=%u", inputSize_, ioNum_.n_output); + return 0; +} + +std::vector RetinaFaceEngineRKNN::buildPriors() const { + std::vector priors; + const int steps[3] = {8, 16, 32}; + const int minSizes[3][2] = {{16, 32}, {64, 128}, {256, 512}}; + for (int s = 0; s < 3; ++s) { + const int step = steps[s]; + const int featW = static_cast(std::ceil(static_cast(inputSize_) / step)); + const int featH = static_cast(std::ceil(static_cast(inputSize_) / step)); + for (int y = 0; y < featH; ++y) { + for (int x = 0; x < featW; ++x) { + for (int k = 0; k < 2; ++k) { + const float minSize = static_cast(minSizes[s][k]); + PriorBox p; + p.cx = (x + 0.5f) * step / inputSize_; + p.cy = (y + 0.5f) * step / inputSize_; + p.w = minSize / inputSize_; + p.h = minSize / inputSize_; + priors.push_back(p); + } + } + } + } + return priors; +} + 
+bool RetinaFaceEngineRKNN::parseRetinaOutputs( + rknn_output* outputs, + std::vector* locOut, + std::vector* scoreOut) const { + std::vector> locCandidates; + std::vector> confCandidates2; + std::vector> scoreCandidates1; + + const int anchors8 = (inputSize_ / 8) * (inputSize_ / 8) * 2; + const int anchors16 = (inputSize_ / 16) * (inputSize_ / 16) * 2; + const int anchors32 = (inputSize_ / 32) * (inputSize_ / 32) * 2; + const int totalAnchors = anchors8 + anchors16 + anchors32; + const int expectedLoc8 = anchors8 * 4; + const int expectedLoc16 = anchors16 * 4; + const int expectedLoc32 = anchors32 * 4; + const int expectedConf8_2 = anchors8 * 2; + const int expectedConf16_2 = anchors16 * 2; + const int expectedConf32_2 = anchors32 * 2; + const int expectedConf8_1 = anchors8; + const int expectedConf16_1 = anchors16; + const int expectedConf32_1 = anchors32; + + for (uint32_t i = 0; i < ioNum_.n_output; ++i) { + const size_t elems = tensorElemCount(outputAttrs_[i]); + if (elems == 0 || outputs[i].buf == nullptr) continue; + + const float* ptr = reinterpret_cast(outputs[i].buf); + std::vector data(ptr, ptr + elems); + const int e = static_cast(elems); + + if (e == expectedLoc8 || e == expectedLoc16 || e == expectedLoc32 || e == totalAnchors * 4) { + locCandidates.push_back(std::move(data)); + continue; + } + if (e == expectedConf8_2 || e == expectedConf16_2 || e == expectedConf32_2 || e == totalAnchors * 2) { + confCandidates2.push_back(std::move(data)); + continue; + } + if (e == expectedConf8_1 || e == expectedConf16_1 || e == expectedConf32_1 || e == totalAnchors) { + scoreCandidates1.push_back(std::move(data)); + continue; + } + } + + locOut->clear(); + scoreOut->clear(); + + auto sortBySize = [](const std::vector& a, const std::vector& b) { + return a.size() > b.size(); + }; + std::sort(locCandidates.begin(), locCandidates.end(), sortBySize); + std::sort(confCandidates2.begin(), confCandidates2.end(), sortBySize); + std::sort(scoreCandidates1.begin(), 
scoreCandidates1.end(), sortBySize); + + auto mergeLoc = [&]() -> bool { + if (locCandidates.empty()) return false; + if (locCandidates.size() >= 3 && + static_cast(locCandidates[0].size()) == anchors8 * 4 && + static_cast(locCandidates[1].size()) == anchors16 * 4 && + static_cast(locCandidates[2].size()) == anchors32 * 4) { + locOut->reserve(static_cast(totalAnchors) * 4); + locOut->insert(locOut->end(), locCandidates[0].begin(), locCandidates[0].end()); + locOut->insert(locOut->end(), locCandidates[1].begin(), locCandidates[1].end()); + locOut->insert(locOut->end(), locCandidates[2].begin(), locCandidates[2].end()); + return true; + } + for (const auto& c : locCandidates) { + if (static_cast(c.size()) == totalAnchors * 4) { + *locOut = c; + return true; + } + } + return false; + }; + + auto mergeScoreFrom2Class = [&]() -> bool { + if (confCandidates2.empty()) return false; + std::vector merged2; + if (confCandidates2.size() >= 3 && + static_cast(confCandidates2[0].size()) == expectedConf8_2 && + static_cast(confCandidates2[1].size()) == expectedConf16_2 && + static_cast(confCandidates2[2].size()) == expectedConf32_2) { + merged2.reserve(static_cast(totalAnchors) * 2); + merged2.insert(merged2.end(), confCandidates2[0].begin(), confCandidates2[0].end()); + merged2.insert(merged2.end(), confCandidates2[1].begin(), confCandidates2[1].end()); + merged2.insert(merged2.end(), confCandidates2[2].begin(), confCandidates2[2].end()); + } else { + bool found = false; + for (const auto& c : confCandidates2) { + if (static_cast(c.size()) == totalAnchors * 2) { + merged2 = c; + found = true; + break; + } + } + if (!found) return false; + } + + scoreOut->reserve(totalAnchors); + for (int i = 0; i < totalAnchors; ++i) { + scoreOut->push_back(merged2[i * 2 + 1]); + } + return true; + }; + + auto mergeScoreFrom1Class = [&]() -> bool { + if (scoreCandidates1.empty()) return false; + if (scoreCandidates1.size() >= 3 && + static_cast(scoreCandidates1[0].size()) == expectedConf8_1 && 
+ static_cast(scoreCandidates1[1].size()) == expectedConf16_1 && + static_cast(scoreCandidates1[2].size()) == expectedConf32_1) { + scoreOut->reserve(totalAnchors); + scoreOut->insert(scoreOut->end(), scoreCandidates1[0].begin(), scoreCandidates1[0].end()); + scoreOut->insert(scoreOut->end(), scoreCandidates1[1].begin(), scoreCandidates1[1].end()); + scoreOut->insert(scoreOut->end(), scoreCandidates1[2].begin(), scoreCandidates1[2].end()); + return true; + } + for (const auto& c : scoreCandidates1) { + if (static_cast(c.size()) == totalAnchors) { + *scoreOut = c; + return true; + } + } + return false; + }; + + const bool locOk = mergeLoc(); + bool scoreOk = mergeScoreFrom2Class(); + if (!scoreOk) { + scoreOk = mergeScoreFrom1Class(); + } + if (!locOk || !scoreOk) { + LOGW("Unable to parse retina outputs, loc_candidates=%zu, conf2_candidates=%zu, conf1_candidates=%zu", + locCandidates.size(), confCandidates2.size(), scoreCandidates1.size()); + return false; + } + return true; +} + +float RetinaFaceEngineRKNN::iou(const FaceCandidate& a, const FaceCandidate& b) { + const float left = std::max(a.left, b.left); + const float top = std::max(a.top, b.top); + const float right = std::min(a.right, b.right); + const float bottom = std::min(a.bottom, b.bottom); + const float w = std::max(0.0f, right - left); + const float h = std::max(0.0f, bottom - top); + const float inter = w * h; + const float areaA = std::max(0.0f, a.right - a.left) * std::max(0.0f, a.bottom - a.top); + const float areaB = std::max(0.0f, b.right - b.left) * std::max(0.0f, b.bottom - b.top); + const float uni = areaA + areaB - inter; + return uni > 0.0f ? 
(inter / uni) : 0.0f; +} + +std::vector RetinaFaceEngineRKNN::nms( + const std::vector& boxes, + float threshold) { + std::vector sorted = boxes; + std::sort(sorted.begin(), sorted.end(), [](const FaceCandidate& a, const FaceCandidate& b) { + return a.score > b.score; + }); + + std::vector keep; + std::vector removed(sorted.size(), 0); + for (size_t i = 0; i < sorted.size(); ++i) { + if (removed[i]) continue; + keep.push_back(sorted[i]); + for (size_t j = i + 1; j < sorted.size(); ++j) { + if (removed[j]) continue; + if (iou(sorted[i], sorted[j]) > threshold) { + removed[j] = 1; + } + } + } + return keep; +} + +std::vector RetinaFaceEngineRKNN::detect( + const uint32_t* argbPixels, + int width, + int height, + int strideBytes) { + std::vector empty; + if (!initialized_ || ctx_ == 0 || argbPixels == nullptr || width <= 0 || height <= 0) { + return empty; + } + + std::vector rgb(inputSize_ * inputSize_ * 3); + const int srcStridePx = strideBytes / 4; + for (int y = 0; y < inputSize_; ++y) { + const int sy = y * height / inputSize_; + const uint32_t* srcRow = argbPixels + sy * srcStridePx; + uint8_t* dst = rgb.data() + y * inputSize_ * 3; + for (int x = 0; x < inputSize_; ++x) { + const int sx = x * width / inputSize_; + const uint32_t pixel = srcRow[sx]; + const uint8_t r = (pixel >> 16) & 0xFF; + const uint8_t g = (pixel >> 8) & 0xFF; + const uint8_t b = pixel & 0xFF; + dst[3 * x + 0] = r; + dst[3 * x + 1] = g; + dst[3 * x + 2] = b; + } + } + + rknn_input input{}; + input.index = 0; + input.type = RKNN_TENSOR_UINT8; + input.size = rgb.size(); + input.buf = rgb.data(); + input.pass_through = 0; + input.fmt = (inputAttr_.fmt == RKNN_TENSOR_NCHW) ? 
RKNN_TENSOR_NCHW : RKNN_TENSOR_NHWC; + + std::vector nchw; + if (input.fmt == RKNN_TENSOR_NCHW) { + nchw.resize(rgb.size()); + const int hw = inputSize_ * inputSize_; + for (int i = 0; i < hw; ++i) { + nchw[i] = rgb[3 * i + 0]; + nchw[hw + i] = rgb[3 * i + 1]; + nchw[2 * hw + i] = rgb[3 * i + 2]; + } + input.buf = nchw.data(); + } + + int ret = rknn_inputs_set(ctx_, 1, &input); + if (ret != RKNN_SUCC) { + LOGW("rknn_inputs_set failed: %d", ret); + return empty; + } + + ret = rknn_run(ctx_, nullptr); + if (ret != RKNN_SUCC) { + LOGW("rknn_run failed: %d", ret); + return empty; + } + + std::vector outputs(ioNum_.n_output); + for (uint32_t i = 0; i < ioNum_.n_output; ++i) { + std::memset(&outputs[i], 0, sizeof(rknn_output)); + outputs[i].want_float = 1; + } + ret = rknn_outputs_get(ctx_, ioNum_.n_output, outputs.data(), nullptr); + if (ret != RKNN_SUCC) { + LOGW("rknn_outputs_get failed: %d", ret); + return empty; + } + + std::vector loc; + std::vector scores; + if (!parseRetinaOutputs(outputs.data(), &loc, &scores)) { + rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data()); + return empty; + } + + const std::vector priors = buildPriors(); + const size_t anchorCount = priors.size(); + if (loc.size() < anchorCount * 4 || scores.size() < anchorCount) { + LOGW("Output size mismatch: priors=%zu loc=%zu scores=%zu", anchorCount, loc.size(), scores.size()); + rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data()); + return empty; + } + + std::vector candidates; + candidates.reserve(anchorCount / 8); + for (size_t i = 0; i < anchorCount; ++i) { + const float score = scores[i]; + if (score < scoreThreshold_) continue; + + const PriorBox& p = priors[i]; + const float dx = loc[i * 4 + 0]; + const float dy = loc[i * 4 + 1]; + const float dw = loc[i * 4 + 2]; + const float dh = loc[i * 4 + 3]; + + const float cx = p.cx + dx * kVariance0 * p.w; + const float cy = p.cy + dy * kVariance0 * p.h; + const float w = p.w * std::exp(dw * kVariance1); + const float h = p.h * 
std::exp(dh * kVariance1); + + FaceCandidate box; + box.left = std::max(0.0f, (cx - w * 0.5f) * width); + box.top = std::max(0.0f, (cy - h * 0.5f) * height); + box.right = std::min(static_cast<float>(width), (cx + w * 0.5f) * width); + box.bottom = std::min(static_cast<float>(height), (cy + h * 0.5f) * height); + box.score = score; + candidates.push_back(box); + } + + rknn_outputs_release(ctx_, ioNum_.n_output, outputs.data()); + + std::vector<FaceCandidate> filtered = nms(candidates, nmsThreshold_); + std::vector<float> result; + result.reserve(filtered.size() * 5); + for (const auto& f : filtered) { + result.push_back(f.left); + result.push_back(f.top); + result.push_back(f.right); + result.push_back(f.bottom); + result.push_back(f.score); + } + return result; +} + +void RetinaFaceEngineRKNN::release() { + if (ctx_ != 0) { + rknn_destroy(ctx_); + ctx_ = 0; + } + outputAttrs_.clear(); + std::memset(&ioNum_, 0, sizeof(ioNum_)); + std::memset(&inputAttr_, 0, sizeof(inputAttr_)); + initialized_ = false; +} diff --git a/app/src/main/cpp/RetinaFaceEngineRKNN.h b/app/src/main/cpp/RetinaFaceEngineRKNN.h new file mode 100644 index 0000000..2f92430 --- /dev/null +++ b/app/src/main/cpp/RetinaFaceEngineRKNN.h @@ -0,0 +1,55 @@ +#ifndef DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H +#define DIGITAL_PERSON_RETINAFACE_ENGINE_RKNN_H + +#include <cstddef> +#include <cstdint> +#include <vector> + +#include "rknn_api.h" + +class RetinaFaceEngineRKNN { +public: + RetinaFaceEngineRKNN(); + ~RetinaFaceEngineRKNN(); + + int init(const char* modelPath, int inputSize, float scoreThreshold, float nmsThreshold); + std::vector<float> detect(const uint32_t* argbPixels, int width, int height, int strideBytes); + void release(); + +private: + struct PriorBox { + float cx; + float cy; + float w; + float h; + }; + + struct FaceCandidate { + float left; + float top; + float right; + float bottom; + float score; + }; + + static size_t tensorElemCount(const rknn_tensor_attr& attr); + static float iou(const FaceCandidate& a, const FaceCandidate& b); + static std::vector<FaceCandidate> 
nms(const std::vector<FaceCandidate>& boxes, float threshold); + + std::vector<PriorBox> buildPriors() const; + bool parseRetinaOutputs( + rknn_output* outputs, + std::vector<float>* locOut, + std::vector<float>* scoreOut) const; + + rknn_context ctx_ = 0; + bool initialized_ = false; + int inputSize_ = 320; + float scoreThreshold_ = 0.6f; + float nmsThreshold_ = 0.4f; + rknn_input_output_num ioNum_{}; + rknn_tensor_attr inputAttr_{}; + std::vector<rknn_tensor_attr> outputAttrs_; +}; + +#endif diff --git a/app/src/main/cpp/RetinaFaceEngineRKNNJNI.cpp b/app/src/main/cpp/RetinaFaceEngineRKNNJNI.cpp new file mode 100644 index 0000000..cf1b2ae --- /dev/null +++ b/app/src/main/cpp/RetinaFaceEngineRKNNJNI.cpp @@ -0,0 +1,100 @@ +#include <jni.h> +#include <android/bitmap.h> +#include <android/log.h> + +#include "RetinaFaceEngineRKNN.h" + +#define LOG_TAG "RetinaFaceJNI" +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) + +extern "C" { + +JNIEXPORT jlong JNICALL +Java_com_digitalperson_engine_RetinaFaceEngineRKNN_createEngineNative(JNIEnv* env, jobject thiz) { + auto* engine = new RetinaFaceEngineRKNN(); + if (engine == nullptr) { + return 0; + } + return reinterpret_cast<jlong>(engine); +} + +JNIEXPORT jint JNICALL +Java_com_digitalperson_engine_RetinaFaceEngineRKNN_initNative( + JNIEnv* env, + jobject thiz, + jlong ptr, + jstring modelPath, + jint inputSize, + jfloat scoreThreshold, + jfloat nmsThreshold) { + auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr); + if (engine == nullptr || modelPath == nullptr) { + return -1; + } + const char* model = env->GetStringUTFChars(modelPath, nullptr); + if (model == nullptr) { + return -1; + } + int ret = engine->init(model, static_cast<int>(inputSize), scoreThreshold, nmsThreshold); + env->ReleaseStringUTFChars(modelPath, model); + return ret; +} + +JNIEXPORT jfloatArray JNICALL +Java_com_digitalperson_engine_RetinaFaceEngineRKNN_detectNative( + JNIEnv* env, + jobject thiz, + jlong ptr, + jobject bitmapObj) { + auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr); + if (engine == nullptr || bitmapObj == nullptr) { + return 
env->NewFloatArray(0); + } + + AndroidBitmapInfo info{}; + if (AndroidBitmap_getInfo(env, bitmapObj, &info) < 0) { + LOGE("AndroidBitmap_getInfo failed"); + return env->NewFloatArray(0); + } + if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) { + LOGE("Unsupported bitmap format: %d", info.format); + return env->NewFloatArray(0); + } + + void* pixels = nullptr; + if (AndroidBitmap_lockPixels(env, bitmapObj, &pixels) < 0 || pixels == nullptr) { + LOGE("AndroidBitmap_lockPixels failed"); + return env->NewFloatArray(0); + } + + std::vector<float> result = engine->detect( + reinterpret_cast<const uint32_t*>(pixels), + static_cast<int>(info.width), + static_cast<int>(info.height), + static_cast<int>(info.stride)); + + AndroidBitmap_unlockPixels(env, bitmapObj); + + jfloatArray out = env->NewFloatArray(static_cast<jsize>(result.size())); + if (out == nullptr) { + return env->NewFloatArray(0); + } + if (!result.empty()) { + env->SetFloatArrayRegion(out, 0, static_cast<jsize>(result.size()), result.data()); + } + return out; +} + +JNIEXPORT void JNICALL +Java_com_digitalperson_engine_RetinaFaceEngineRKNN_releaseNative( + JNIEnv* env, + jobject thiz, + jlong ptr) { + auto* engine = reinterpret_cast<RetinaFaceEngineRKNN*>(ptr); + if (engine != nullptr) { + engine->release(); + delete engine; + } +} + +} // extern "C" diff --git a/app/src/main/cpp/zipformer_headers/rkllm.h b/app/src/main/cpp/zipformer_headers/rkllm.h new file mode 100644 index 0000000..3678287 --- /dev/null +++ b/app/src/main/cpp/zipformer_headers/rkllm.h @@ -0,0 +1,409 @@ +#ifndef _RKLLM_H_ +#define _RKLLM_H_ +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +#define CPU0 (1 << 0) // 0x01 +#define CPU1 (1 << 1) // 0x02 +#define CPU2 (1 << 2) // 0x04 +#define CPU3 (1 << 3) // 0x08 +#define CPU4 (1 << 4) // 0x10 +#define CPU5 (1 << 5) // 0x20 +#define CPU6 (1 << 6) // 0x40 +#define CPU7 (1 << 7) // 0x80 + +/** + * @typedef LLMHandle + * @brief A handle used to manage and interact with the large language model. 
+ */ +typedef void* LLMHandle; + +/** + * @enum LLMCallState + * @brief Describes the possible states of an LLM call. + */ +typedef enum { + RKLLM_RUN_NORMAL = 0, /**< The LLM call is in a normal running state. */ + RKLLM_RUN_WAITING = 1, /**< The LLM call is waiting for complete UTF-8 encoded character. */ + RKLLM_RUN_FINISH = 2, /**< The LLM call has finished execution. */ + RKLLM_RUN_ERROR = 3, /**< An error occurred during the LLM call. */ +} LLMCallState; + +/** + * @enum RKLLMInputType + * @brief Defines the types of inputs that can be fed into the LLM. + */ +typedef enum { + RKLLM_INPUT_PROMPT = 0, /**< Input is a text prompt. */ + RKLLM_INPUT_TOKEN = 1, /**< Input is a sequence of tokens. */ + RKLLM_INPUT_EMBED = 2, /**< Input is an embedding vector. */ + RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */ +} RKLLMInputType; + +/** + * @enum RKLLMInferMode + * @brief Specifies the inference modes of the LLM. + */ +typedef enum { + RKLLM_INFER_GENERATE = 0, /**< The LLM generates text based on input. */ + RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. */ + RKLLM_INFER_GET_LOGITS = 2, /**< The LLM retrieves logits for further processing. */ +} RKLLMInferMode; + +/** + * @struct RKLLMExtendParam + * @brief The extend parameters for configuring an LLM instance. + */ +typedef struct { + int32_t base_domain_id; /**< base_domain_id */ + int8_t embed_flash; /**< Indicates whether to query word embedding vectors from flash memory (1) or not (0). */ + int8_t enabled_cpus_num; /**< Number of CPUs enabled for inference. */ + uint32_t enabled_cpus_mask; /**< Bitmask indicating which CPUs to enable for inference. */ + uint8_t n_batch; /**< Number of input samples processed concurrently in one forward pass. Set to >1 to enable batched inference. Default is 1. */ + int8_t use_cross_attn; /**< Whether to enable cross attention (non-zero to enable, 0 to disable). 
*/ + uint8_t reserved[104]; /**< reserved */ +} RKLLMExtendParam; + +/** + * @struct RKLLMParam + * @brief Defines the parameters for configuring an LLM instance. + */ +typedef struct { + const char* model_path; /**< Path to the model file. */ + int32_t max_context_len; /**< Maximum number of tokens in the context window. */ + int32_t max_new_tokens; /**< Maximum number of new tokens to generate. */ + int32_t top_k; /**< Top-K sampling parameter for token generation. */ + int32_t n_keep; /** number of kv cache to keep at the beginning when shifting context window */ + float top_p; /**< Top-P (nucleus) sampling parameter. */ + float temperature; /**< Sampling temperature, affecting the randomness of token selection. */ + float repeat_penalty; /**< Penalty for repeating tokens in generation. */ + float frequency_penalty; /**< Penalizes frequent tokens during generation. */ + float presence_penalty; /**< Penalizes tokens based on their presence in the input. */ + int32_t mirostat; /**< Mirostat sampling strategy flag (0 to disable). */ + float mirostat_tau; /**< Tau parameter for Mirostat sampling. */ + float mirostat_eta; /**< Eta parameter for Mirostat sampling. */ + bool skip_special_token; /**< Whether to skip special tokens during generation. */ + bool is_async; /**< Whether to run inference asynchronously. */ + const char* img_start; /**< Starting position of an image in multimodal input. */ + const char* img_end; /**< Ending position of an image in multimodal input. */ + const char* img_content; /**< Pointer to the image content. */ + RKLLMExtendParam extend_param; /**< Extend parameters. */ +} RKLLMParam; + +/** + * @struct RKLLMLoraAdapter + * @brief Defines parameters for a Lora adapter used in model fine-tuning. + */ +typedef struct { + const char* lora_adapter_path; /**< Path to the Lora adapter file. */ + const char* lora_adapter_name; /**< Name of the Lora adapter. */ + float scale; /**< Scaling factor for applying the Lora adapter. 
*/ +} RKLLMLoraAdapter; + +/** + * @struct RKLLMEmbedInput + * @brief Represents an embedding input to the LLM. + */ +typedef struct { + float* embed; /**< Pointer to the embedding vector (of size n_tokens * n_embed). */ + size_t n_tokens; /**< Number of tokens represented in the embedding. */ +} RKLLMEmbedInput; + +/** + * @struct RKLLMTokenInput + * @brief Represents token input to the LLM. + */ +typedef struct { + int32_t* input_ids; /**< Array of token IDs. */ + size_t n_tokens; /**< Number of tokens in the input. */ +} RKLLMTokenInput; + +/** + * @struct RKLLMMultiModalInput + * @brief Represents multimodal input (e.g., text and image). + */ +typedef struct { + char* prompt; /**< Text prompt input. */ + float* image_embed; /**< Embedding of the images (of size n_image * n_image_tokens * image_embed_length). */ + size_t n_image_tokens; /**< Number of image_token. */ + size_t n_image; /**< Number of image. */ + size_t image_width; /**< Width of image. */ + size_t image_height; /**< Height of image. */ +} RKLLMMultiModalInput; + +/** + * @struct RKLLMInput + * @brief Represents different types of input to the LLM via a union. + */ +typedef struct { + const char* role; /**< Message role: "user" (user input), "tool" (function result) */ + bool enable_thinking; /**< Controls whether "thinking mode" is enabled for the Qwen3 model. */ + RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */ + union { + const char* prompt_input; /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */ + RKLLMEmbedInput embed_input; /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */ + RKLLMTokenInput token_input; /**< Token input if input_type is RKLLM_INPUT_TOKEN. */ + RKLLMMultiModalInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */ + }; +} RKLLMInput; + +/** + * @struct RKLLMLoraParam + * @brief Structure defining parameters for Lora adapters. 
+ */ +typedef struct { + const char* lora_adapter_name; /**< Name of the Lora adapter. */ +} RKLLMLoraParam; + +/** + * @struct RKLLMPromptCacheParam + * @brief Structure to define parameters for caching prompts. + */ +typedef struct { + int save_prompt_cache; /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */ + const char* prompt_cache_path; /**< Path to the prompt cache file. */ +} RKLLMPromptCacheParam; + +/** + * @struct RKLLMCrossAttnParam + * @brief Structure holding parameters for cross-attention inference. + * + * This structure is used when performing cross-attention in the decoder. + * It provides the encoder output (key/value caches), position indices, + * and attention mask. + * + * - `encoder_k_cache` must be stored in contiguous memory with layout: + * [num_layers][num_tokens][num_kv_heads][head_dim] + * - `encoder_v_cache` must be stored in contiguous memory with layout: + * [num_layers][num_kv_heads][head_dim][num_tokens] + */ +typedef struct { + float* encoder_k_cache; /**< Pointer to encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim). */ + float* encoder_v_cache; /**< Pointer to encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens). */ + float* encoder_mask; /**< Pointer to encoder attention mask (array of size num_tokens). */ + int32_t* encoder_pos; /**< Pointer to encoder token positions (array of size num_tokens). */ + int num_tokens; /**< Number of tokens in the encoder sequence. */ +} RKLLMCrossAttnParam; + +/** + * @struct RKLLMInferParam + * @brief Structure for defining parameters during inference. + */ +typedef struct { + RKLLMInferMode mode; /**< Inference mode (e.g., generate or get last hidden layer). */ + RKLLMLoraParam* lora_params; /**< Pointer to Lora adapter parameters. */ + RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. 
*/ + int keep_history; /**Flag to determine history retention (1: keep history, 0: discard history).*/ +} RKLLMInferParam; + +/** + * @struct RKLLMResultLastHiddenLayer + * @brief Structure to hold the hidden states from the last layer. + */ +typedef struct { + const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */ + int embd_size; /**< Size of the embedding vector. */ + int num_tokens; /**< Number of tokens for which hidden states are stored. */ +} RKLLMResultLastHiddenLayer; + +/** + * @struct RKLLMResultLogits + * @brief Structure to hold the logits. + */ +typedef struct { + const float* logits; /**< Pointer to the logits (of size num_tokens * vocab_size). */ + int vocab_size; /**< Size of the vocab. */ + int num_tokens; /**< Number of tokens for which logits are stored. */ +} RKLLMResultLogits; + +/** + * @struct RKLLMPerfStat + * @brief Structure to hold performance statistics for prefill and generate stages. + */ +typedef struct { + float prefill_time_ms; /**< Total time taken for the prefill stage in milliseconds. */ + int prefill_tokens; /**< Number of tokens processed during the prefill stage. */ + float generate_time_ms; /**< Total time taken for the generate stage in milliseconds. */ + int generate_tokens; /**< Number of tokens processed during the generate stage. */ + float memory_usage_mb; /**< VmHWM resident memory usage during inference, in megabytes. */ +} RKLLMPerfStat; + +/** + * @struct RKLLMResult + * @brief Structure to represent the result of LLM inference. + */ +typedef struct { + const char* text; /**< Generated text result. */ + int32_t token_id; /**< ID of the generated token. */ + RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */ + RKLLMResultLogits logits; /**< Model output logits. */ + RKLLMPerfStat perf; /**< Pointer to performance statistics (prefill and generate). 
*/ +} RKLLMResult; + +/** + * @typedef LLMResultCallback + * @brief Callback function to handle LLM results. + * @param result Pointer to the LLM result. + * @param userdata Pointer to user data for the callback. + * @param state State of the LLM call (e.g., finished, error). + * @return int Return value indicating the handling status: + * - 0: Continue inference normally. + * - 1: Pause inference. If the user wants to modify or intervene in the result (e.g., editing output, injecting new prompt), + * return 1 to suspend the current inference. Later, call `rkllm_run` with updated content to resume inference. + */ +typedef int(*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state); + +/** + * @brief Creates a default RKLLMParam structure with preset values. + * @return A default RKLLMParam structure. + */ +RKLLMParam rkllm_createDefaultParam(); + +/** + * @brief Initializes the LLM with the given parameters. + * @param handle Pointer to the LLM handle. + * @param param Configuration parameters for the LLM. + * @param callback Callback function to handle LLM results. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback); + +/** + * @brief Loads a Lora adapter into the LLM. + * @param handle LLM handle. + * @param lora_adapter Pointer to the Lora adapter structure. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter); + +/** + * @brief Loads a prompt cache from a file. + * @param handle LLM handle. + * @param prompt_cache_path Path to the prompt cache file. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path); + +/** + * @brief Releases the prompt cache from memory. + * @param handle LLM handle. + * @return Status code (0 for success, non-zero for failure). 
+ */ +int rkllm_release_prompt_cache(LLMHandle handle); + +/** + * @brief Destroys the LLM instance and releases resources. + * @param handle LLM handle. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_destroy(LLMHandle handle); + +/** + * @brief Runs an LLM inference task synchronously. + * @param handle LLM handle. + * @param rkllm_input Input data for the LLM. + * @param rkllm_infer_params Parameters for the inference task. + * @param userdata Pointer to user data for the callback. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata); + +/** + * @brief Runs an LLM inference task asynchronously. + * @param handle LLM handle. + * @param rkllm_input Input data for the LLM. + * @param rkllm_infer_params Parameters for the inference task. + * @param userdata Pointer to user data for the callback. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata); + +/** + * @brief Aborts an ongoing LLM task. + * @param handle LLM handle. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_abort(LLMHandle handle); + +/** + * @brief Checks if an LLM task is currently running. + * @param handle LLM handle. + * @return Status code (0 if a task is running, non-zero for otherwise). + */ +int rkllm_is_running(LLMHandle handle); + +/** + * @brief Clear the key-value cache for a given LLM handle. + * + * This function is used to clear part or all of the KV cache. + * + * @param handle LLM handle. + * @param keep_system_prompt Flag indicating whether to retain the system prompt in the cache (1 to retain, 0 to clear). + * This flag is ignored if a specific range [start_pos, end_pos) is provided. 
+ * @param start_pos Array of start positions (inclusive) of the KV cache ranges to clear, one per batch. + * @param end_pos Array of end positions (exclusive) of the KV cache ranges to clear, one per batch. + * If both start_pos and end_pos are set to nullptr, the entire cache will be cleared and keep_system_prompt will take effect, + * If start_pos[i] < end_pos[i], only the specified range will be cleared, and keep_system_prompt will be ignored. + * @note: start_pos or end_pos is only valid when keep_history == 0 and the generation has been paused by returning 1 in the callback + * @return Status code (0 if cache was cleared successfully, non-zero otherwise). + */ +int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos); + +/** + * @brief Get the current size of the key-value cache for a given LLM handle. + * + * This function returns the total number of positions currently stored in the model's KV cache. + * + * @param handle LLM handle. + * @param cache_sizes Pointer to an array where the per-batch cache sizes will be stored. + * The array must be preallocated with space for `n_batch` elements. + */ +int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes); + +/** + * @brief Sets the chat template for the LLM, including system prompt, prefix, and postfix. + * + * This function allows you to customize the chat template by providing a system prompt, a prompt prefix, and a prompt postfix. + * The system prompt is typically used to define the behavior or context of the language model, + * while the prefix and postfix are used to format the user input and output respectively. + * + * @param handle LLM handle. + * @param system_prompt The system prompt that defines the context or behavior of the language model. + * @param prompt_prefix The prefix added before the user input in the chat. + * @param prompt_postfix The postfix added after the user input in the chat. 
+ * + * @return Status code (0 if the template was set successfully, non-zero for errors). + */ +int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix); + +/** + * @brief Sets the function calling configuration for the LLM, including system prompt, tool definitions, and tool response token. + * + * @param handle LLM handle. + * @param system_prompt The system prompt that defines the context or behavior of the language model. + * @param tools A JSON-formatted string that defines the available functions, including their names, descriptions, and parameters. + * @param tool_response_str A unique tag used to identify function call results within a conversation. It acts as the marker tag, + * allowing tokenizer to recognize tool outputs separately from normal dialogue turns. + * @return Status code (0 if the configuration was set successfully, non-zero for errors). + */ +int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str); + +/** + * @brief Sets the cross-attention parameters for the LLM decoder. + * + * @param handle LLM handle. + * @param cross_attn_params Pointer to the structure containing encoder-related input data + * used for cross-attention (see RKLLMCrossAttnParam for details). + * + * @return Status code (0 if the parameters were set successfully, non-zero for errors). 
+ */ +int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/app/src/main/cpp/zipformer_headers/rkllm.h.2 b/app/src/main/cpp/zipformer_headers/rkllm.h.2 new file mode 100644 index 0000000..3678287 --- /dev/null +++ b/app/src/main/cpp/zipformer_headers/rkllm.h.2 @@ -0,0 +1,409 @@ +#ifndef _RKLLM_H_ +#define _RKLLM_H_ +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +#define CPU0 (1 << 0) // 0x01 +#define CPU1 (1 << 1) // 0x02 +#define CPU2 (1 << 2) // 0x04 +#define CPU3 (1 << 3) // 0x08 +#define CPU4 (1 << 4) // 0x10 +#define CPU5 (1 << 5) // 0x20 +#define CPU6 (1 << 6) // 0x40 +#define CPU7 (1 << 7) // 0x80 + +/** + * @typedef LLMHandle + * @brief A handle used to manage and interact with the large language model. + */ +typedef void* LLMHandle; + +/** + * @enum LLMCallState + * @brief Describes the possible states of an LLM call. + */ +typedef enum { + RKLLM_RUN_NORMAL = 0, /**< The LLM call is in a normal running state. */ + RKLLM_RUN_WAITING = 1, /**< The LLM call is waiting for complete UTF-8 encoded character. */ + RKLLM_RUN_FINISH = 2, /**< The LLM call has finished execution. */ + RKLLM_RUN_ERROR = 3, /**< An error occurred during the LLM call. */ +} LLMCallState; + +/** + * @enum RKLLMInputType + * @brief Defines the types of inputs that can be fed into the LLM. + */ +typedef enum { + RKLLM_INPUT_PROMPT = 0, /**< Input is a text prompt. */ + RKLLM_INPUT_TOKEN = 1, /**< Input is a sequence of tokens. */ + RKLLM_INPUT_EMBED = 2, /**< Input is an embedding vector. */ + RKLLM_INPUT_MULTIMODAL = 3, /**< Input is multimodal (e.g., text and image). */ +} RKLLMInputType; + +/** + * @enum RKLLMInferMode + * @brief Specifies the inference modes of the LLM. + */ +typedef enum { + RKLLM_INFER_GENERATE = 0, /**< The LLM generates text based on input. */ + RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1, /**< The LLM retrieves the last hidden layer for further processing. 
*/ + RKLLM_INFER_GET_LOGITS = 2, /**< The LLM retrieves logits for further processing. */ +} RKLLMInferMode; + +/** + * @struct RKLLMExtendParam + * @brief The extend parameters for configuring an LLM instance. + */ +typedef struct { + int32_t base_domain_id; /**< base_domain_id */ + int8_t embed_flash; /**< Indicates whether to query word embedding vectors from flash memory (1) or not (0). */ + int8_t enabled_cpus_num; /**< Number of CPUs enabled for inference. */ + uint32_t enabled_cpus_mask; /**< Bitmask indicating which CPUs to enable for inference. */ + uint8_t n_batch; /**< Number of input samples processed concurrently in one forward pass. Set to >1 to enable batched inference. Default is 1. */ + int8_t use_cross_attn; /**< Whether to enable cross attention (non-zero to enable, 0 to disable). */ + uint8_t reserved[104]; /**< reserved */ +} RKLLMExtendParam; + +/** + * @struct RKLLMParam + * @brief Defines the parameters for configuring an LLM instance. + */ +typedef struct { + const char* model_path; /**< Path to the model file. */ + int32_t max_context_len; /**< Maximum number of tokens in the context window. */ + int32_t max_new_tokens; /**< Maximum number of new tokens to generate. */ + int32_t top_k; /**< Top-K sampling parameter for token generation. */ + int32_t n_keep; /** number of kv cache to keep at the beginning when shifting context window */ + float top_p; /**< Top-P (nucleus) sampling parameter. */ + float temperature; /**< Sampling temperature, affecting the randomness of token selection. */ + float repeat_penalty; /**< Penalty for repeating tokens in generation. */ + float frequency_penalty; /**< Penalizes frequent tokens during generation. */ + float presence_penalty; /**< Penalizes tokens based on their presence in the input. */ + int32_t mirostat; /**< Mirostat sampling strategy flag (0 to disable). */ + float mirostat_tau; /**< Tau parameter for Mirostat sampling. */ + float mirostat_eta; /**< Eta parameter for Mirostat sampling. 
*/ + bool skip_special_token; /**< Whether to skip special tokens during generation. */ + bool is_async; /**< Whether to run inference asynchronously. */ + const char* img_start; /**< Starting position of an image in multimodal input. */ + const char* img_end; /**< Ending position of an image in multimodal input. */ + const char* img_content; /**< Pointer to the image content. */ + RKLLMExtendParam extend_param; /**< Extend parameters. */ +} RKLLMParam; + +/** + * @struct RKLLMLoraAdapter + * @brief Defines parameters for a Lora adapter used in model fine-tuning. + */ +typedef struct { + const char* lora_adapter_path; /**< Path to the Lora adapter file. */ + const char* lora_adapter_name; /**< Name of the Lora adapter. */ + float scale; /**< Scaling factor for applying the Lora adapter. */ +} RKLLMLoraAdapter; + +/** + * @struct RKLLMEmbedInput + * @brief Represents an embedding input to the LLM. + */ +typedef struct { + float* embed; /**< Pointer to the embedding vector (of size n_tokens * n_embed). */ + size_t n_tokens; /**< Number of tokens represented in the embedding. */ +} RKLLMEmbedInput; + +/** + * @struct RKLLMTokenInput + * @brief Represents token input to the LLM. + */ +typedef struct { + int32_t* input_ids; /**< Array of token IDs. */ + size_t n_tokens; /**< Number of tokens in the input. */ +} RKLLMTokenInput; + +/** + * @struct RKLLMMultiModalInput + * @brief Represents multimodal input (e.g., text and image). + */ +typedef struct { + char* prompt; /**< Text prompt input. */ + float* image_embed; /**< Embedding of the images (of size n_image * n_image_tokens * image_embed_length). */ + size_t n_image_tokens; /**< Number of image_token. */ + size_t n_image; /**< Number of image. */ + size_t image_width; /**< Width of image. */ + size_t image_height; /**< Height of image. */ +} RKLLMMultiModalInput; + +/** + * @struct RKLLMInput + * @brief Represents different types of input to the LLM via a union. 
+ */ +typedef struct { + const char* role; /**< Message role: "user" (user input), "tool" (function result) */ + bool enable_thinking; /**< Controls whether "thinking mode" is enabled for the Qwen3 model. */ + RKLLMInputType input_type; /**< Specifies the type of input provided (e.g., prompt, token, embed, multimodal). */ + union { + const char* prompt_input; /**< Text prompt input if input_type is RKLLM_INPUT_PROMPT. */ + RKLLMEmbedInput embed_input; /**< Embedding input if input_type is RKLLM_INPUT_EMBED. */ + RKLLMTokenInput token_input; /**< Token input if input_type is RKLLM_INPUT_TOKEN. */ + RKLLMMultiModalInput multimodal_input; /**< Multimodal input if input_type is RKLLM_INPUT_MULTIMODAL. */ + }; +} RKLLMInput; + +/** + * @struct RKLLMLoraParam + * @brief Structure defining parameters for Lora adapters. + */ +typedef struct { + const char* lora_adapter_name; /**< Name of the Lora adapter. */ +} RKLLMLoraParam; + +/** + * @struct RKLLMPromptCacheParam + * @brief Structure to define parameters for caching prompts. + */ +typedef struct { + int save_prompt_cache; /**< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save). */ + const char* prompt_cache_path; /**< Path to the prompt cache file. */ +} RKLLMPromptCacheParam; + +/** + * @struct RKLLMCrossAttnParam + * @brief Structure holding parameters for cross-attention inference. + * + * This structure is used when performing cross-attention in the decoder. + * It provides the encoder output (key/value caches), position indices, + * and attention mask. + * + * - `encoder_k_cache` must be stored in contiguous memory with layout: + * [num_layers][num_tokens][num_kv_heads][head_dim] + * - `encoder_v_cache` must be stored in contiguous memory with layout: + * [num_layers][num_kv_heads][head_dim][num_tokens] + */ +typedef struct { + float* encoder_k_cache; /**< Pointer to encoder key cache (size: num_layers * num_tokens * num_kv_heads * head_dim). 
*/ + float* encoder_v_cache; /**< Pointer to encoder value cache (size: num_layers * num_kv_heads * head_dim * num_tokens). */ + float* encoder_mask; /**< Pointer to encoder attention mask (array of size num_tokens). */ + int32_t* encoder_pos; /**< Pointer to encoder token positions (array of size num_tokens). */ + int num_tokens; /**< Number of tokens in the encoder sequence. */ +} RKLLMCrossAttnParam; + +/** + * @struct RKLLMInferParam + * @brief Structure for defining parameters during inference. + */ +typedef struct { + RKLLMInferMode mode; /**< Inference mode (e.g., generate or get last hidden layer). */ + RKLLMLoraParam* lora_params; /**< Pointer to Lora adapter parameters. */ + RKLLMPromptCacheParam* prompt_cache_params; /**< Pointer to prompt cache parameters. */ + int keep_history; /**Flag to determine history retention (1: keep history, 0: discard history).*/ +} RKLLMInferParam; + +/** + * @struct RKLLMResultLastHiddenLayer + * @brief Structure to hold the hidden states from the last layer. + */ +typedef struct { + const float* hidden_states; /**< Pointer to the hidden states (of size num_tokens * embd_size). */ + int embd_size; /**< Size of the embedding vector. */ + int num_tokens; /**< Number of tokens for which hidden states are stored. */ +} RKLLMResultLastHiddenLayer; + +/** + * @struct RKLLMResultLogits + * @brief Structure to hold the logits. + */ +typedef struct { + const float* logits; /**< Pointer to the logits (of size num_tokens * vocab_size). */ + int vocab_size; /**< Size of the vocab. */ + int num_tokens; /**< Number of tokens for which logits are stored. */ +} RKLLMResultLogits; + +/** + * @struct RKLLMPerfStat + * @brief Structure to hold performance statistics for prefill and generate stages. + */ +typedef struct { + float prefill_time_ms; /**< Total time taken for the prefill stage in milliseconds. */ + int prefill_tokens; /**< Number of tokens processed during the prefill stage. 
*/ + float generate_time_ms; /**< Total time taken for the generate stage in milliseconds. */ + int generate_tokens; /**< Number of tokens processed during the generate stage. */ + float memory_usage_mb; /**< VmHWM resident memory usage during inference, in megabytes. */ +} RKLLMPerfStat; + +/** + * @struct RKLLMResult + * @brief Structure to represent the result of LLM inference. + */ +typedef struct { + const char* text; /**< Generated text result. */ + int32_t token_id; /**< ID of the generated token. */ + RKLLMResultLastHiddenLayer last_hidden_layer; /**< Hidden states of the last layer (if requested). */ + RKLLMResultLogits logits; /**< Model output logits. */ + RKLLMPerfStat perf; /**< Pointer to performance statistics (prefill and generate). */ +} RKLLMResult; + +/** + * @typedef LLMResultCallback + * @brief Callback function to handle LLM results. + * @param result Pointer to the LLM result. + * @param userdata Pointer to user data for the callback. + * @param state State of the LLM call (e.g., finished, error). + * @return int Return value indicating the handling status: + * - 0: Continue inference normally. + * - 1: Pause inference. If the user wants to modify or intervene in the result (e.g., editing output, injecting new prompt), + * return 1 to suspend the current inference. Later, call `rkllm_run` with updated content to resume inference. + */ +typedef int(*LLMResultCallback)(RKLLMResult* result, void* userdata, LLMCallState state); + +/** + * @brief Creates a default RKLLMParam structure with preset values. + * @return A default RKLLMParam structure. + */ +RKLLMParam rkllm_createDefaultParam(); + +/** + * @brief Initializes the LLM with the given parameters. + * @param handle Pointer to the LLM handle. + * @param param Configuration parameters for the LLM. + * @param callback Callback function to handle LLM results. + * @return Status code (0 for success, non-zero for failure). 
+ */ +int rkllm_init(LLMHandle* handle, RKLLMParam* param, LLMResultCallback callback); + +/** + * @brief Loads a Lora adapter into the LLM. + * @param handle LLM handle. + * @param lora_adapter Pointer to the Lora adapter structure. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_load_lora(LLMHandle handle, RKLLMLoraAdapter* lora_adapter); + +/** + * @brief Loads a prompt cache from a file. + * @param handle LLM handle. + * @param prompt_cache_path Path to the prompt cache file. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_load_prompt_cache(LLMHandle handle, const char* prompt_cache_path); + +/** + * @brief Releases the prompt cache from memory. + * @param handle LLM handle. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_release_prompt_cache(LLMHandle handle); + +/** + * @brief Destroys the LLM instance and releases resources. + * @param handle LLM handle. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_destroy(LLMHandle handle); + +/** + * @brief Runs an LLM inference task synchronously. + * @param handle LLM handle. + * @param rkllm_input Input data for the LLM. + * @param rkllm_infer_params Parameters for the inference task. + * @param userdata Pointer to user data for the callback. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_run(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata); + +/** + * @brief Runs an LLM inference task asynchronously. + * @param handle LLM handle. + * @param rkllm_input Input data for the LLM. + * @param rkllm_infer_params Parameters for the inference task. + * @param userdata Pointer to user data for the callback. + * @return Status code (0 for success, non-zero for failure). 
+ */ +int rkllm_run_async(LLMHandle handle, RKLLMInput* rkllm_input, RKLLMInferParam* rkllm_infer_params, void* userdata); + +/** + * @brief Aborts an ongoing LLM task. + * @param handle LLM handle. + * @return Status code (0 for success, non-zero for failure). + */ +int rkllm_abort(LLMHandle handle); + +/** + * @brief Checks if an LLM task is currently running. + * @param handle LLM handle. + * @return Status code (0 if a task is running, non-zero for otherwise). + */ +int rkllm_is_running(LLMHandle handle); + +/** + * @brief Clear the key-value cache for a given LLM handle. + * + * This function is used to clear part or all of the KV cache. + * + * @param handle LLM handle. + * @param keep_system_prompt Flag indicating whether to retain the system prompt in the cache (1 to retain, 0 to clear). + * This flag is ignored if a specific range [start_pos, end_pos) is provided. + * @param start_pos Array of start positions (inclusive) of the KV cache ranges to clear, one per batch. + * @param end_pos Array of end positions (exclusive) of the KV cache ranges to clear, one per batch. + * If both start_pos and end_pos are set to nullptr, the entire cache will be cleared and keep_system_prompt will take effect, + * If start_pos[i] < end_pos[i], only the specified range will be cleared, and keep_system_prompt will be ignored. + * @note: start_pos or end_pos is only valid when keep_history == 0 and the generation has been paused by returning 1 in the callback + * @return Status code (0 if cache was cleared successfully, non-zero otherwise). + */ +int rkllm_clear_kv_cache(LLMHandle handle, int keep_system_prompt, int* start_pos, int* end_pos); + +/** + * @brief Get the current size of the key-value cache for a given LLM handle. + * + * This function returns the total number of positions currently stored in the model's KV cache. + * + * @param handle LLM handle. + * @param cache_sizes Pointer to an array where the per-batch cache sizes will be stored. 
+ * The array must be preallocated with space for `n_batch` elements. + */ +int rkllm_get_kv_cache_size(LLMHandle handle, int* cache_sizes); + +/** + * @brief Sets the chat template for the LLM, including system prompt, prefix, and postfix. + * + * This function allows you to customize the chat template by providing a system prompt, a prompt prefix, and a prompt postfix. + * The system prompt is typically used to define the behavior or context of the language model, + * while the prefix and postfix are used to format the user input and output respectively. + * + * @param handle LLM handle. + * @param system_prompt The system prompt that defines the context or behavior of the language model. + * @param prompt_prefix The prefix added before the user input in the chat. + * @param prompt_postfix The postfix added after the user input in the chat. + * + * @return Status code (0 if the template was set successfully, non-zero for errors). + */ +int rkllm_set_chat_template(LLMHandle handle, const char* system_prompt, const char* prompt_prefix, const char* prompt_postfix); + +/** + * @brief Sets the function calling configuration for the LLM, including system prompt, tool definitions, and tool response token. + * + * @param handle LLM handle. + * @param system_prompt The system prompt that defines the context or behavior of the language model. + * @param tools A JSON-formatted string that defines the available functions, including their names, descriptions, and parameters. + * @param tool_response_str A unique tag used to identify function call results within a conversation. It acts as the marker tag, + * allowing tokenizer to recognize tool outputs separately from normal dialogue turns. + * @return Status code (0 if the configuration was set successfully, non-zero for errors). 
+ */ +int rkllm_set_function_tools(LLMHandle handle, const char* system_prompt, const char* tools, const char* tool_response_str); + +/** + * @brief Sets the cross-attention parameters for the LLM decoder. + * + * @param handle LLM handle. + * @param cross_attn_params Pointer to the structure containing encoder-related input data + * used for cross-attention (see RKLLMCrossAttnParam for details). + * + * @return Status code (0 if the parameters were set successfully, non-zero for errors). + */ +int rkllm_set_cross_attn_params(LLMHandle handle, RKLLMCrossAttnParam* cross_attn_params); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/app/src/main/java/com/digitalperson/EntryActivity.kt b/app/src/main/java/com/digitalperson/EntryActivity.kt index 396b9a9..656790a 100644 --- a/app/src/main/java/com/digitalperson/EntryActivity.kt +++ b/app/src/main/java/com/digitalperson/EntryActivity.kt @@ -2,10 +2,15 @@ package com.digitalperson import android.content.Intent import android.os.Bundle +import android.util.Log import androidx.appcompat.app.AppCompatActivity import com.digitalperson.config.AppConfig class EntryActivity : AppCompatActivity() { + companion object { + private const val TAG = "EntryActivity" + } + override fun onCreate(savedInstanceState: Bundle?) 
{ super.onCreate(savedInstanceState) @@ -14,6 +19,7 @@ class EntryActivity : AppCompatActivity() { } else { MainActivity::class.java } + Log.i(TAG, "USE_LIVE2D=${AppConfig.Avatar.USE_LIVE2D}, target=${target.simpleName}") startActivity(Intent(this, target)) finish() } diff --git a/app/src/main/java/com/digitalperson/Live2DChatActivity.kt b/app/src/main/java/com/digitalperson/Live2DChatActivity.kt index 7b28f46..3961625 100644 --- a/app/src/main/java/com/digitalperson/Live2DChatActivity.kt +++ b/app/src/main/java/com/digitalperson/Live2DChatActivity.kt @@ -2,20 +2,37 @@ package com.digitalperson import android.Manifest import android.content.pm.PackageManager +import android.graphics.Bitmap import android.os.Bundle import android.util.Log import android.widget.Toast +import androidx.camera.core.CameraSelector +import androidx.camera.core.ImageAnalysis +import androidx.camera.core.ImageProxy +import androidx.camera.core.Preview +import androidx.camera.lifecycle.ProcessCameraProvider +import androidx.camera.view.PreviewView import androidx.appcompat.app.AppCompatActivity import androidx.core.app.ActivityCompat +import androidx.core.content.ContextCompat import com.digitalperson.cloud.CloudApiManager import com.digitalperson.audio.AudioProcessor import com.digitalperson.vad.VadManager import com.digitalperson.asr.AsrManager -import com.digitalperson.tts.TtsManager import com.digitalperson.ui.Live2DUiManager import com.digitalperson.config.AppConfig +import com.digitalperson.face.FaceDetectionPipeline +import com.digitalperson.face.FaceOverlayView +import com.digitalperson.face.ImageProxyBitmapConverter import com.digitalperson.metrics.TraceManager import com.digitalperson.metrics.TraceSession +import com.digitalperson.tts.TtsController +import com.digitalperson.llm.LLMManager +import com.digitalperson.llm.LLMManagerCallback +import com.digitalperson.util.FileHelper +import java.io.File +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors 
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -26,14 +43,24 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.withContext class Live2DChatActivity : AppCompatActivity() { + companion object { + private const val TAG_ACTIVITY = "Live2DChatActivity" + private const val TAG_LLM = "LLM_ROUTE" + } private lateinit var uiManager: Live2DUiManager private lateinit var vadManager: VadManager private lateinit var asrManager: AsrManager - private lateinit var ttsManager: TtsManager + private lateinit var ttsController: TtsController private lateinit var audioProcessor: AudioProcessor + private var llmManager: LLMManager? = null + private var useLocalLLM = false // 默认使用云端 LLM - private val permissions: Array = arrayOf(Manifest.permission.RECORD_AUDIO) + private val appPermissions: Array = arrayOf( + Manifest.permission.RECORD_AUDIO, + Manifest.permission.CAMERA + ) + private val micPermissions: Array = arrayOf(Manifest.permission.RECORD_AUDIO) @Volatile private var isRecording: Boolean = false @@ -55,23 +82,46 @@ class Live2DChatActivity : AppCompatActivity() { @Volatile private var llmInFlight: Boolean = false private var enableStreaming = false + private lateinit var cameraPreviewView: PreviewView + private lateinit var faceOverlayView: FaceOverlayView + private lateinit var faceDetectionPipeline: FaceDetectionPipeline + private var facePipelineReady: Boolean = false + private var cameraProvider: ProcessCameraProvider? 
= null + private lateinit var cameraAnalyzerExecutor: ExecutorService + override fun onRequestPermissionsResult( requestCode: Int, permissions: Array, grantResults: IntArray ) { super.onRequestPermissionsResult(requestCode, permissions, grantResults) - val ok = requestCode == AppConfig.REQUEST_RECORD_AUDIO_PERMISSION && - grantResults.isNotEmpty() && - grantResults[0] == PackageManager.PERMISSION_GRANTED - if (!ok) { + if (requestCode != AppConfig.REQUEST_RECORD_AUDIO_PERMISSION) return + if (grantResults.isEmpty()) { + finish() + return + } + val granted = permissions.zip(grantResults.toTypedArray()).associate { it.first to it.second } + val micGranted = granted[Manifest.permission.RECORD_AUDIO] == PackageManager.PERMISSION_GRANTED + val cameraGranted = granted[Manifest.permission.CAMERA] == PackageManager.PERMISSION_GRANTED + + if (!micGranted) { Log.e(AppConfig.TAG, "Audio record is disallowed") finish() + return + } + if (!cameraGranted) { + uiManager.showToast("未授予相机权限,暂不启用人脸检测") + Log.w(AppConfig.TAG, "Camera permission denied") + return + } + if (facePipelineReady) { + startCameraPreviewAndDetection() } } override fun onCreate(savedInstanceState: Bundle?) 
{ super.onCreate(savedInstanceState) + Log.i(TAG_ACTIVITY, "onCreate") setContentView(R.layout.activity_live2d_chat) uiManager = Live2DUiManager(this) @@ -82,10 +132,28 @@ class Live2DChatActivity : AppCompatActivity() { stopButtonId = R.id.stop_button, recordButtonId = R.id.record_button, traditionalButtonsId = R.id.traditional_buttons, + llmModeSwitchId = R.id.llm_mode_switch, + llmModeSwitchRowId = R.id.llm_mode_switch_row, silentPlayerViewId = 0, speakingPlayerViewId = 0, live2dViewId = R.id.live2d_view ) + + cameraPreviewView = findViewById(R.id.camera_preview) + cameraPreviewView.implementationMode = PreviewView.ImplementationMode.COMPATIBLE + faceOverlayView = findViewById(R.id.face_overlay) + cameraAnalyzerExecutor = Executors.newSingleThreadExecutor() + faceDetectionPipeline = FaceDetectionPipeline( + context = applicationContext, + onResult = { result -> + faceOverlayView.updateResult(result) + }, + onGreeting = { greeting -> + uiManager.appendToUi("\n[Face] $greeting\n") + ttsController.enqueueSegment(greeting) + ttsController.enqueueEnd() + } + ) // 根据配置选择交互方式 uiManager.setUseHoldToSpeak(AppConfig.USE_HOLD_TO_SPEAK) @@ -105,7 +173,7 @@ class Live2DChatActivity : AppCompatActivity() { uiManager.setStopButtonListener { onStopClicked(userInitiated = true) } } - ActivityCompat.requestPermissions(this, permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION) + ActivityCompat.requestPermissions(this, appPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION) try { val streamingSwitch = findViewById(R.id.streaming_switch) @@ -119,6 +187,27 @@ class Live2DChatActivity : AppCompatActivity() { Log.w(AppConfig.TAG, "Streaming switch not found in layout: ${e.message}") } + try { + val ttsModeSwitch = findViewById(R.id.tts_mode_switch) + ttsModeSwitch.isChecked = false // 默认使用本地TTS + ttsModeSwitch.setOnCheckedChangeListener { _, isChecked -> + ttsController.setUseQCloudTts(isChecked) + uiManager.showToast("TTS模式已切换到${if (isChecked) "腾讯云" else "本地"}") + } + } catch 
(e: Exception) { + Log.w(AppConfig.TAG, "TTS mode switch not found in layout: ${e.message}") + } + + // 设置 LLM 模式开关 + uiManager.setLLMSwitchListener { isChecked -> + useLocalLLM = isChecked + Log.i(TAG_LLM, "LLM mode switched: useLocalLLM=$useLocalLLM") + uiManager.showToast("LLM模式已切换到${if (isChecked) "本地" else "云端"}") + // 重新初始化 LLM + initLLM() + } + // 默认不显示 LLM 开关,等模型下载完成后再显示 + if (AppConfig.USE_HOLD_TO_SPEAK) { uiManager.setButtonsEnabled(recordEnabled = false) } else { @@ -127,8 +216,8 @@ class Live2DChatActivity : AppCompatActivity() { uiManager.setText("初始化中…") audioProcessor = AudioProcessor(this) - ttsManager = TtsManager(this) - ttsManager.setCallback(createTtsCallback()) + ttsController = TtsController(this) + ttsController.setCallback(createTtsCallback()) asrManager = AsrManager(this) asrManager.setAudioProcessor(audioProcessor) @@ -137,6 +226,64 @@ class Live2DChatActivity : AppCompatActivity() { vadManager = VadManager(this) vadManager.setCallback(createVadCallback()) + // 初始化 LLM 管理器 + initLLM() + + // 检查是否需要下载模型 + if (!FileHelper.isLocalLLMAvailable(this)) { + // 显示下载进度对话框 + uiManager.showDownloadProgressDialog() + + // 异步下载模型文件 + FileHelper.downloadModelFilesWithProgress( + this, + onProgress = { fileName, downloaded, total, progress -> + runOnUiThread { + val downloadedMB = downloaded / (1024 * 1024) + val totalMB = total / (1024 * 1024) + uiManager.updateDownloadProgress( + fileName, + downloadedMB, + totalMB, + progress + ) + } + }, + onComplete = { success, message -> + runOnUiThread { + uiManager.dismissDownloadProgressDialog() + if (success) { + Log.i(AppConfig.TAG, "Model files downloaded successfully") + uiManager.showToast("模型下载完成", Toast.LENGTH_SHORT) + // 检查本地 LLM 是否可用 + if (FileHelper.isLocalLLMAvailable(this)) { + Log.i(AppConfig.TAG, "Local LLM is available, enabling local LLM switch") + // 显示本地 LLM 开关,并同步状态 + uiManager.showLLMSwitch(true) + uiManager.setLLMSwitchChecked(useLocalLLM) + } + } else { + Log.e(AppConfig.TAG, "Failed to 
download model files: $message") + uiManager.showToast("模型下载失败: $message", Toast.LENGTH_LONG) + } + // 下载完成后初始化其他组件 + initializeOtherComponents() + } + } + ) + } else { + // 模型已存在,直接初始化其他组件 + initializeOtherComponents() + // 显示本地 LLM 开关,并同步状态 + uiManager.showLLMSwitch(true) + uiManager.setLLMSwitchChecked(useLocalLLM) + } + } + + /** + * 初始化其他组件(VAD、ASR、TTS、人脸检测等) + */ + private fun initializeOtherComponents() { ioScope.launch { try { Log.i(AppConfig.TAG, "Init VAD + SenseVoice(RKNN) + TTS (background)") @@ -144,7 +291,8 @@ class Live2DChatActivity : AppCompatActivity() { vadManager.initVadModel() asrManager.initSenseVoiceModel() } - val ttsOk = ttsManager.initTtsAndAudioTrack() + val ttsOk = ttsController.init() + facePipelineReady = faceDetectionPipeline.initialize() withContext(Dispatchers.Main) { if (!ttsOk) { uiManager.showToast( @@ -152,6 +300,11 @@ class Live2DChatActivity : AppCompatActivity() { Toast.LENGTH_LONG ) } + if (!facePipelineReady) { + uiManager.showToast("RetinaFace 初始化失败,请检查模型和 rknn 运行库", Toast.LENGTH_LONG) + } else if (allPermissionsGranted()) { + startCameraPreviewAndDetection() + } uiManager.setText(getString(R.string.hint)) if (AppConfig.USE_HOLD_TO_SPEAK) { uiManager.setButtonsEnabled(recordEnabled = true) @@ -203,14 +356,22 @@ class Live2DChatActivity : AppCompatActivity() { Log.d(AppConfig.TAG, "ASR segment skipped: $reason") } - override fun shouldSkipAsr(): Boolean = ttsManager.isPlaying() + override fun shouldSkipAsr(): Boolean = ttsController.isPlaying() override fun isLlmInFlight(): Boolean = llmInFlight override fun onLlmCalled(text: String) { llmInFlight = true Log.d(AppConfig.TAG, "Calling LLM with text: $text") - cloudApiManager.callLLM(text) + if (useLocalLLM) { + Log.i(TAG_LLM, "Routing to LOCAL LLM") + // 使用本地 LLM 生成回复 + generateResponse(text) + } else { + Log.i(TAG_LLM, "Routing to CLOUD LLM") + // 使用云端 LLM 生成回复 + cloudApiManager.callLLM(text) + } } } @@ -220,7 +381,7 @@ class Live2DChatActivity : AppCompatActivity() { 
asrManager.enqueueAudioSegment(originalAudio, processedAudio) } - override fun shouldSkipProcessing(): Boolean = ttsManager.isPlaying() || llmInFlight + override fun shouldSkipProcessing(): Boolean = ttsController.isPlaying() || llmInFlight } private fun createCloudApiListener() = object : CloudApiManager.CloudApiListener { @@ -232,9 +393,9 @@ class Live2DChatActivity : AppCompatActivity() { if (enableStreaming) { for (seg in segmenter.flush()) { - ttsManager.enqueueSegment(seg) + ttsController.enqueueSegment(seg) } - ttsManager.enqueueEnd() + ttsController.enqueueEnd() } else { val previousMood = com.digitalperson.mood.MoodManager.getCurrentMood() val (filteredText, mood) = com.digitalperson.mood.MoodManager.extractAndFilterMood(response) @@ -247,8 +408,8 @@ class Live2DChatActivity : AppCompatActivity() { runOnUiThread { uiManager.appendToUi("${filteredText}\n") } - ttsManager.enqueueSegment(filteredText) - ttsManager.enqueueEnd() + ttsController.enqueueSegment(filteredText) + ttsController.enqueueEnd() } } @@ -271,7 +432,7 @@ class Live2DChatActivity : AppCompatActivity() { val segments = segmenter.processChunk(filteredText) for (seg in segments) { - ttsManager.enqueueSegment(seg) + ttsController.enqueueSegment(seg) } } } @@ -285,7 +446,7 @@ class Live2DChatActivity : AppCompatActivity() { } } - private fun createTtsCallback() = object : TtsManager.TtsCallback { + private fun createTtsCallback() = object : TtsController.TtsCallback { override fun onTtsStarted(text: String) { runOnUiThread { uiManager.appendToUi("\n[TTS] 开始合成...\n") @@ -310,32 +471,6 @@ class Live2DChatActivity : AppCompatActivity() { uiManager.setSpeaking(speaking) } - override fun getCurrentTrace(): TraceSession? 
= currentTrace - - override fun onTraceMarkTtsRequestEnqueued() { - currentTrace?.markTtsRequestEnqueued() - } - - override fun onTraceMarkTtsSynthesisStart() { - currentTrace?.markTtsSynthesisStart() - } - - override fun onTraceMarkTtsFirstPcmReady() { - currentTrace?.markTtsFirstPcmReady() - } - - override fun onTraceMarkTtsFirstAudioPlay() { - currentTrace?.markTtsFirstAudioPlay() - } - - override fun onTraceMarkTtsDone() { - currentTrace?.markTtsDone() - } - - override fun onTraceAddDuration(name: String, value: Long) { - currentTrace?.addDuration(name, value) - } - override fun onEndTurn() { TraceManager.getInstance().endTurn() currentTrace = null @@ -344,27 +479,97 @@ class Live2DChatActivity : AppCompatActivity() { override fun onDestroy() { super.onDestroy() + stopCameraPreviewAndDetection() onStopClicked(userInitiated = false) ioScope.cancel() synchronized(nativeLock) { try { vadManager.release() } catch (_: Throwable) {} try { asrManager.release() } catch (_: Throwable) {} } - try { ttsManager.release() } catch (_: Throwable) {} + try { faceDetectionPipeline.release() } catch (_: Throwable) {} + try { cameraAnalyzerExecutor.shutdown() } catch (_: Throwable) {} + try { ttsController.release() } catch (_: Throwable) {} + try { llmManager?.destroy() } catch (_: Throwable) {} try { uiManager.release() } catch (_: Throwable) {} try { audioProcessor.release() } catch (_: Throwable) {} } override fun onResume() { super.onResume() + Log.i(TAG_ACTIVITY, "onResume") uiManager.onResume() + if (facePipelineReady && allPermissionsGranted()) { + startCameraPreviewAndDetection() + } } override fun onPause() { + Log.i(TAG_ACTIVITY, "onPause") + stopCameraPreviewAndDetection() uiManager.onPause() super.onPause() } + private fun allPermissionsGranted(): Boolean { + return appPermissions.all { + ContextCompat.checkSelfPermission(this, it) == PackageManager.PERMISSION_GRANTED + } + } + + private fun startCameraPreviewAndDetection() { + val cameraProviderFuture = 
ProcessCameraProvider.getInstance(this) + cameraProviderFuture.addListener({ + try { + val provider = cameraProviderFuture.get() + cameraProvider = provider + provider.unbindAll() + + val preview = Preview.Builder().build().apply { + setSurfaceProvider(cameraPreviewView.surfaceProvider) + } + cameraPreviewView.scaleType = PreviewView.ScaleType.FIT_CENTER + + val analyzer = ImageAnalysis.Builder() + .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST) + .build() + + analyzer.setAnalyzer(cameraAnalyzerExecutor) { imageProxy -> + analyzeCameraFrame(imageProxy) + } + + val selector = CameraSelector.Builder() + .requireLensFacing(CameraSelector.LENS_FACING_FRONT) + .build() + + provider.bindToLifecycle(this, selector, preview, analyzer) + } catch (t: Throwable) { + Log.e(AppConfig.TAG, "startCameraPreviewAndDetection failed: ${t.message}", t) + } + }, ContextCompat.getMainExecutor(this)) + } + + private fun stopCameraPreviewAndDetection() { + try { + cameraProvider?.unbindAll() + } catch (_: Throwable) { + } finally { + cameraProvider = null + } + } + + private fun analyzeCameraFrame(imageProxy: ImageProxy) { + try { + val bitmap: Bitmap? 
= ImageProxyBitmapConverter.toBitmap(imageProxy) + if (bitmap != null) { + faceDetectionPipeline.submitFrame(bitmap) + } + } catch (t: Throwable) { + Log.w(AppConfig.TAG, "analyzeCameraFrame error: ${t.message}") + } finally { + imageProxy.close() + } + } + private fun onStartClicked() { Log.d(AppConfig.TAG, "onStartClicked called") if (isRecording) { @@ -372,7 +577,7 @@ class Live2DChatActivity : AppCompatActivity() { return } - if (!audioProcessor.initMicrophone(permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) { + if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) { uiManager.showToast("麦克风初始化失败/无权限") return } @@ -383,8 +588,7 @@ class Live2DChatActivity : AppCompatActivity() { uiManager.clearText() - ttsManager.reset() - ttsManager.setCurrentTrace(currentTrace) + ttsController.reset() segmenter.reset() vadManager.reset() @@ -409,12 +613,12 @@ class Live2DChatActivity : AppCompatActivity() { } // 如果TTS正在播放,打断它 - val interrupted = ttsManager.interruptForNewTurn() + val interrupted = ttsController.interruptForNewTurn() if (interrupted) { uiManager.appendToUi("\n[LOG] 已打断TTS播放\n") } - if (!audioProcessor.initMicrophone(permissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) { + if (!audioProcessor.initMicrophone(micPermissions, AppConfig.REQUEST_RECORD_AUDIO_PERMISSION)) { uiManager.showToast("麦克风初始化失败/无权限") return } @@ -427,7 +631,7 @@ class Live2DChatActivity : AppCompatActivity() { // interruptForNewTurn() already prepared TTS state for next turn. // Keep reset() only for non-interrupt entry points. 
- ttsManager.setCurrentTrace(currentTrace) + segmenter.reset() // 启动按住说话的动作 @@ -479,7 +683,7 @@ class Live2DChatActivity : AppCompatActivity() { recordingJob?.cancel() recordingJob = null - ttsManager.stop() + ttsController.stop() if (AppConfig.USE_HOLD_TO_SPEAK) { uiManager.setButtonsEnabled(recordEnabled = true) @@ -515,10 +719,10 @@ class Live2DChatActivity : AppCompatActivity() { while (isRecording && ioScope.coroutineContext.isActive) { loopCount++ if (loopCount % 100 == 0) { - Log.d(AppConfig.TAG, "processSamplesLoop running, loopCount=$loopCount, ttsPlaying=${ttsManager.isPlaying()}") + Log.d(AppConfig.TAG, "processSamplesLoop running, loopCount=$loopCount, ttsPlaying=${ttsController.isPlaying()}") } - if (ttsManager.isPlaying()) { + if (ttsController.isPlaying()) { if (vadManager.isInSpeech()) { Log.d(AppConfig.TAG, "TTS playing, resetting VAD state") vadManager.clearState() @@ -546,11 +750,134 @@ class Live2DChatActivity : AppCompatActivity() { } val forced = segmenter.maybeForceByTime() - for (seg in forced) ttsManager.enqueueSegment(seg) + for (seg in forced) ttsController.enqueueSegment(seg) } vadManager.forceFinalize() } Log.d(AppConfig.TAG, "processSamplesLoop stopped") } + + /** + * 初始化 LLM 管理器 + */ + private fun initLLM() { + try { + Log.i(TAG_LLM, "initLLM called, useLocalLLM=$useLocalLLM") + llmManager?.destroy() + llmManager = null + if (useLocalLLM) { + // // 本地 LLM 初始化前,先暂停/释放重模块 + // Log.i(AppConfig.TAG, "Pausing camera and releasing face detection before LLM initialization") + // stopCameraPreviewAndDetection() + // try { + // faceDetectionPipeline.release() + // Log.i(AppConfig.TAG, "Face detection pipeline released") + // } catch (e: Exception) { + // Log.w(AppConfig.TAG, "Failed to release face detection pipeline: ${e.message}") + // } + + // // 释放 VAD 管理器 + // try { + // vadManager.release() + // Log.i(AppConfig.TAG, "VAD manager released") + // } catch (e: Exception) { + // Log.w(AppConfig.TAG, "Failed to release VAD manager: 
${e.message}") + // } + + val modelPath = FileHelper.getLLMModelPath(applicationContext) + if (!File(modelPath).exists()) { + throw IllegalStateException("RKLLM model file missing: $modelPath") + } + Log.i(AppConfig.TAG, "Initializing LLM with model path: $modelPath") + val localLlmResponseBuffer = StringBuilder() + llmManager = LLMManager(modelPath, object : LLMManagerCallback { + override fun onThinking(msg: String, finished: Boolean) { + // 处理思考过程 + Log.d(TAG_LLM, "LOCAL onThinking finished=$finished msg=${msg.take(60)}") + runOnUiThread { + if (!finished && enableStreaming) { + uiManager.appendToUi("\n[LLM] 思考中: $msg\n") + } + } + } + + override fun onResult(msg: String, finished: Boolean) { + // 处理生成结果 + Log.d(TAG_LLM, "LOCAL onResult finished=$finished len=${msg.length}") + runOnUiThread { + if (!finished) { + localLlmResponseBuffer.append(msg) + if (enableStreaming) { + uiManager.appendToUi(msg) + } + } else { + val finalText = localLlmResponseBuffer.toString().trim() + localLlmResponseBuffer.setLength(0) + if (!enableStreaming && finalText.isNotEmpty()) { + uiManager.appendToUi("$finalText\n") + } + uiManager.appendToUi("\n\n[LLM] 生成完成\n") + llmInFlight = false + if (finalText.isNotEmpty()) { + ttsController.enqueueSegment(finalText) + ttsController.enqueueEnd() + } else { + Log.w(TAG_LLM, "LOCAL final text is empty, skip TTS enqueue") + } + } + } + } + }) + Log.i(AppConfig.TAG, "LLM initialized successfully") + Log.i(TAG_LLM, "LOCAL LLM initialized") + } else { + // 使用云端 LLM,不需要初始化本地 LLM + Log.i(AppConfig.TAG, "Using cloud LLM, skipping local LLM initialization") + Log.i(TAG_LLM, "CLOUD mode active") + } + } catch (e: Exception) { + Log.e(AppConfig.TAG, "Failed to initialize LLM: ${e.message}", e) + Log.e(TAG_LLM, "LOCAL init failed: ${e.message}", e) + useLocalLLM = false + runOnUiThread { + uiManager.setLLMSwitchChecked(false) + uiManager.showToast("LLM 初始化失败: ${e.message}", Toast.LENGTH_LONG) + uiManager.appendToUi("\n[错误] LLM 初始化失败: ${e.message}\n") + 
} + } + } + + /** + * 使用 LLM 生成回复 + */ + private fun generateResponse(userInput: String) { + try { + if (useLocalLLM) { + val systemPrompt = "你是一个友好的数字人助手,回答要简洁明了。" + Log.d(AppConfig.TAG, "Generating response for: $userInput") + val local = llmManager + if (local == null) { + Log.e(TAG_LLM, "LOCAL LLM manager is null, fallback to CLOUD") + cloudApiManager.callLLM(userInput) + return + } + Log.i(TAG_LLM, "LOCAL generateResponseWithSystem") + local.generateResponseWithSystem(systemPrompt, userInput) + } else { + // 使用云端 LLM + Log.d(AppConfig.TAG, "Using cloud LLM for response: $userInput") + Log.i(TAG_LLM, "CLOUD callLLM") + // 调用云端 LLM + cloudApiManager.callLLM(userInput) + } + } catch (e: Exception) { + Log.e(AppConfig.TAG, "Failed to generate response: ${e.message}", e) + Log.e(TAG_LLM, "generateResponse failed: ${e.message}", e) + runOnUiThread { + uiManager.appendToUi("\n\n[Error] LLM 生成失败: ${e.message}\n") + llmInFlight = false + } + } + } } \ No newline at end of file diff --git a/app/src/main/java/com/digitalperson/config/AppConfig.kt b/app/src/main/java/com/digitalperson/config/AppConfig.kt index 53c2ed0..f250e1b 100644 --- a/app/src/main/java/com/digitalperson/config/AppConfig.kt +++ b/app/src/main/java/com/digitalperson/config/AppConfig.kt @@ -34,6 +34,25 @@ object AppConfig { const val MAX_TEXT_LENGTH = 50 const val MODEL_DIR = "sensevoice_models" } + + object Face { + const val MODEL_DIR = "RetinaFace" + const val MODEL_NAME = "RetinaFace_mobile320.rknn" + const val INPUT_SIZE = 320 + const val SCORE_THRESHOLD = 0.6f + const val NMS_THRESHOLD = 0.4f + const val TRACK_IOU_THRESHOLD = 0.45f + const val STABLE_MS = 1000L + const val FRONTAL_MIN_FACE_SIZE = 90f + const val FRONTAL_MAX_ASPECT_DIFF = 0.35f + } + + object FaceRecognition { + const val MODEL_DIR = "Insightface" + const val MODEL_NAME = "ms1mv3_arcface_r18.rknn" + const val SIMILARITY_THRESHOLD = 0.5f + const val GREETING_COOLDOWN_MS = 6000L + } object Audio { const val GAIN_SMOOTHING_FACTOR = 
0.1f @@ -48,4 +67,10 @@ object AppConfig { const val MODEL_DIR = "live2d_model/Haru_pro_jp" const val MODEL_JSON = "haru_greeter_t05.model3.json" } + + object QCloud { + const val APP_ID = "1302849512" // 替换为你的腾讯云APP_ID + const val SECRET_ID = "AKIDbBdyBGE5oPuIGA1iDlDYlFallaJ0YODB" // 替换为你的腾讯云SECRET_ID + const val SECRET_KEY = "32vhIl9OQIRclmLjvuleLp9LLAnFVYEp" // 替换为你的腾讯云SECRET_KEY + } } diff --git a/app/src/main/java/com/digitalperson/engine/ArcFaceEngineRKNN.java b/app/src/main/java/com/digitalperson/engine/ArcFaceEngineRKNN.java new file mode 100644 index 0000000..7cf908d --- /dev/null +++ b/app/src/main/java/com/digitalperson/engine/ArcFaceEngineRKNN.java @@ -0,0 +1,79 @@ +package com.digitalperson.engine; + +import android.content.Context; +import android.graphics.Bitmap; +import android.util.Log; + +import com.digitalperson.config.AppConfig; +import com.digitalperson.util.FileHelper; + +import java.io.File; + +public class ArcFaceEngineRKNN { + private static final String TAG = "ArcFaceEngineRKNN"; + + static { + try { + System.loadLibrary("rknnrt"); + System.loadLibrary("sensevoiceEngine"); + Log.d(TAG, "Loaded native libs for ArcFace RKNN"); + } catch (UnsatisfiedLinkError e) { + Log.e(TAG, "Failed to load native libraries for ArcFace", e); + throw e; + } + } + + private final long nativePtr; + private boolean initialized = false; + private boolean released = false; + + public ArcFaceEngineRKNN() { + nativePtr = createEngineNative(); + if (nativePtr == 0) { + throw new RuntimeException("Failed to create native ArcFace engine"); + } + } + + public boolean initialize(Context context) { + if (released) return false; + File modelDir = FileHelper.copyInsightFaceAssets(context); + File modelFile = new File(modelDir, AppConfig.FaceRecognition.MODEL_NAME); + int ret = initNative(nativePtr, modelFile.getAbsolutePath()); + initialized = ret == 0; + if (!initialized) { + Log.e(TAG, "ArcFace init failed, code=" + ret + ", model=" + modelFile.getAbsolutePath()); + } + 
return initialized; + } + + public float[] extractEmbedding(Bitmap bitmap, float left, float top, float right, float bottom) { + Log.d(TAG, "extractEmbedding called: initialized=" + initialized + ", released=" + released + ", bitmap=" + (bitmap != null)); + if (!initialized || released || bitmap == null) { + Log.w(TAG, "extractEmbedding failed: initialized=" + initialized + ", released=" + released + ", bitmap=" + (bitmap != null)); + return new float[0]; + } + float[] emb = extractEmbeddingNative(nativePtr, bitmap, left, top, right, bottom); + Log.d(TAG, "extractEmbeddingNative returned: " + (emb != null ? emb.length : "null")); + return emb != null ? emb : new float[0]; + } + + public void release() { + if (!released && nativePtr != 0) { + releaseNative(nativePtr); + } + released = true; + initialized = false; + } + + private native long createEngineNative(); + private native int initNative(long ptr, String modelPath); + private native float[] extractEmbeddingNative( + long ptr, + Bitmap bitmap, + float left, + float top, + float right, + float bottom + ); + private native void releaseNative(long ptr); +} diff --git a/app/src/main/java/com/digitalperson/engine/RetinaFaceEngineRKNN.java b/app/src/main/java/com/digitalperson/engine/RetinaFaceEngineRKNN.java new file mode 100644 index 0000000..c34a591 --- /dev/null +++ b/app/src/main/java/com/digitalperson/engine/RetinaFaceEngineRKNN.java @@ -0,0 +1,77 @@ +package com.digitalperson.engine; + +import android.content.Context; +import android.graphics.Bitmap; +import android.util.Log; + +import com.digitalperson.config.AppConfig; +import com.digitalperson.util.FileHelper; + +import java.io.File; + +public class RetinaFaceEngineRKNN { + private static final String TAG = "RetinaFaceEngineRKNN"; + + static { + try { + System.loadLibrary("rknnrt"); + System.loadLibrary("sensevoiceEngine"); + Log.d(TAG, "Loaded native libs for RetinaFace RKNN"); + } catch (UnsatisfiedLinkError e) { + Log.e(TAG, "Failed to load native 
libraries for RetinaFace", e); + throw e; + } + } + + private final long nativePtr; + private boolean initialized = false; + private boolean released = false; + + public RetinaFaceEngineRKNN() { + nativePtr = createEngineNative(); + if (nativePtr == 0) { + throw new RuntimeException("Failed to create native RetinaFace engine"); + } + } + + public boolean initialize(Context context) { + if (released) { + return false; + } + File modelDir = FileHelper.copyRetinaFaceAssets(context); + File modelFile = new File(modelDir, AppConfig.Face.MODEL_NAME); + int ret = initNative( + nativePtr, + modelFile.getAbsolutePath(), + AppConfig.Face.INPUT_SIZE, + AppConfig.Face.SCORE_THRESHOLD, + AppConfig.Face.NMS_THRESHOLD + ); + initialized = ret == 0; + if (!initialized) { + Log.e(TAG, "RetinaFace init failed, code=" + ret + ", model=" + modelFile.getAbsolutePath()); + } + return initialized; + } + + public float[] detect(Bitmap bitmap) { + if (!initialized || released || bitmap == null) { + return new float[0]; + } + float[] raw = detectNative(nativePtr, bitmap); + return raw != null ? 
raw : new float[0]; + } + + public void release() { + if (!released && nativePtr != 0) { + releaseNative(nativePtr); + } + released = true; + initialized = false; + } + + private native long createEngineNative(); + private native int initNative(long ptr, String modelPath, int inputSize, float scoreThreshold, float nmsThreshold); + private native float[] detectNative(long ptr, Bitmap bitmap); + private native void releaseNative(long ptr); +} diff --git a/app/src/main/java/com/digitalperson/face/FaceDetectionPipeline.kt b/app/src/main/java/com/digitalperson/face/FaceDetectionPipeline.kt new file mode 100644 index 0000000..33c4207 --- /dev/null +++ b/app/src/main/java/com/digitalperson/face/FaceDetectionPipeline.kt @@ -0,0 +1,223 @@ +package com.digitalperson.face + +import android.content.Context +import android.graphics.Bitmap +import android.util.Log +import com.digitalperson.config.AppConfig +import com.digitalperson.engine.RetinaFaceEngineRKNN +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.math.abs +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext + +data class FaceBox( + val left: Float, + val top: Float, + val right: Float, + val bottom: Float, + val score: Float, +) + +data class FaceDetectionResult( + val sourceWidth: Int, + val sourceHeight: Int, + val faces: List, +) + +class FaceDetectionPipeline( + private val context: Context, + private val onResult: (FaceDetectionResult) -> Unit, + private val onGreeting: (String) -> Unit, +) { + private val engine = RetinaFaceEngineRKNN() + private val recognizer = FaceRecognizer(context) + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private val frameInFlight = AtomicBoolean(false) + private val initialized = AtomicBoolean(false) + private var trackFace: FaceBox? 
= null + private var trackId: Long = 0 + private var trackStableSinceMs: Long = 0 + private var greetedTrackId: Long = -1 + private var lastGreetMs: Long = 0 + + fun initialize(): Boolean { + val detectorOk = engine.initialize(context) + val recognizerOk = recognizer.initialize() + val ok = detectorOk && recognizerOk + initialized.set(ok) + Log.i(AppConfig.TAG, "Face pipeline initialize result=$ok detector=$detectorOk recognizer=$recognizerOk") + return ok + } + + fun submitFrame(bitmap: Bitmap) { + if (!initialized.get()) { + bitmap.recycle() + return + } + if (!frameInFlight.compareAndSet(false, true)) { + bitmap.recycle() + return + } + + scope.launch { + try { + val width = bitmap.width + val height = bitmap.height + val raw = engine.detect(bitmap) + + val faceCount = raw.size / 5 + val faces = ArrayList(faceCount) + var i = 0 + while (i + 4 < raw.size) { + faces.add( + FaceBox( + left = raw[i], + top = raw[i + 1], + right = raw[i + 2], + bottom = raw[i + 3], + score = raw[i + 4], + ) + ) + i += 5 + } + // 过滤太小的人脸 + val minFaceSize = 50 // 最小人脸大小(像素) + val filteredFaces = faces.filter { face -> + val width = face.right - face.left + val height = face.bottom - face.top + width > minFaceSize && height > minFaceSize + } + +// if (filteredFaces.isNotEmpty()) { +// Log.d( +// AppConfig.TAG,"[Face] filtered detected ${filteredFaces.size} face(s)" +// ) +// } + + maybeRecognizeAndGreet(bitmap, filteredFaces) + withContext(Dispatchers.Main) { + onResult(FaceDetectionResult(width, height, filteredFaces)) + } + } catch (t: Throwable) { + Log.e(AppConfig.TAG, "Face detection pipeline failed: ${t.message}", t) + } finally { + bitmap.recycle() + frameInFlight.set(false) + } + } + } + + private suspend fun maybeRecognizeAndGreet(bitmap: Bitmap, faces: List) { + val now = System.currentTimeMillis() + if (faces.isEmpty()) { + trackFace = null + trackStableSinceMs = 0 + return + } + + val primary = faces.maxByOrNull { (it.right - it.left) * (it.bottom - it.top) } ?: return + 
val prev = trackFace + if (prev == null || iou(prev, primary) < AppConfig.Face.TRACK_IOU_THRESHOLD) { + trackId += 1 + greetedTrackId = -1 + trackStableSinceMs = now + } + trackFace = primary + + val stableMs = now - trackStableSinceMs + val frontal = isFrontal(primary, bitmap.width, bitmap.height) + val coolingDown = (now - lastGreetMs) < AppConfig.FaceRecognition.GREETING_COOLDOWN_MS + if (stableMs < AppConfig.Face.STABLE_MS || !frontal || greetedTrackId == trackId || coolingDown) { + return + } + + val match = recognizer.identify(bitmap, primary) + + Log.d(AppConfig.TAG, "[Face] Recognition result: matchedName=${match.matchedName}, similarity=${match.similarity}") + + // 检查是否需要保存新人脸 + if (match.matchedName.isNullOrBlank()) { + Log.d(AppConfig.TAG, "[Face] No match found, attempting to add new face") + // 提取人脸特征 + val embedding = extractEmbedding(bitmap, primary) + Log.d(AppConfig.TAG, "[Face] Extracted embedding size: ${embedding.size}") + if (embedding.isNotEmpty()) { + // 尝试添加新人脸 + val added = recognizer.addNewFace(embedding) + Log.d(AppConfig.TAG, "[Face] Add new face result: $added") + if (added) { + Log.i(AppConfig.TAG, "[Face] New face added to database") + } else { + Log.i(AppConfig.TAG, "[Face] Face already exists in database (similar face found)") + } + } else { + Log.w(AppConfig.TAG, "[Face] Failed to extract embedding") + } + } else { + Log.d(AppConfig.TAG, "[Face] Matched existing face: ${match.matchedName}") + } + + val greeting = if (!match.matchedName.isNullOrBlank()) { + "你好,${match.matchedName}!" 
+ } else { + "你好,很高兴见到你。" + } + greetedTrackId = trackId + lastGreetMs = now + Log.i( + AppConfig.TAG, + "[Face] greeting track=$trackId stable=${stableMs}ms frontal=$frontal matched=${match.matchedName} score=${match.similarity}" + ) + withContext(Dispatchers.Main) { + onGreeting(greeting) + } + } + + private fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray { + return recognizer.extractEmbedding(bitmap, face) + } + + private fun isFrontal(face: FaceBox, frameW: Int, frameH: Int): Boolean { + val w = face.right - face.left + val h = face.bottom - face.top + if (w < AppConfig.Face.FRONTAL_MIN_FACE_SIZE || h < AppConfig.Face.FRONTAL_MIN_FACE_SIZE) { + return false + } + val aspectDiff = abs((w / h) - 1f) + if (aspectDiff > AppConfig.Face.FRONTAL_MAX_ASPECT_DIFF) { + return false + } + val cx = (face.left + face.right) * 0.5f + val cy = (face.top + face.bottom) * 0.5f + val minX = frameW * 0.15f + val maxX = frameW * 0.85f + val minY = frameH * 0.15f + val maxY = frameH * 0.85f + return cx in minX..maxX && cy in minY..maxY + } + + private fun iou(a: FaceBox, b: FaceBox): Float { + val left = maxOf(a.left, b.left) + val top = maxOf(a.top, b.top) + val right = minOf(a.right, b.right) + val bottom = minOf(a.bottom, b.bottom) + val w = maxOf(0f, right - left) + val h = maxOf(0f, bottom - top) + val inter = w * h + val areaA = maxOf(0f, a.right - a.left) * maxOf(0f, a.bottom - a.top) + val areaB = maxOf(0f, b.right - b.left) * maxOf(0f, b.bottom - b.top) + val union = areaA + areaB - inter + return if (union <= 0f) 0f else inter / union + } + + fun release() { + scope.cancel() + engine.release() + recognizer.release() + initialized.set(false) + } +} diff --git a/app/src/main/java/com/digitalperson/face/FaceFeatureStore.kt b/app/src/main/java/com/digitalperson/face/FaceFeatureStore.kt new file mode 100644 index 0000000..52cb025 --- /dev/null +++ b/app/src/main/java/com/digitalperson/face/FaceFeatureStore.kt @@ -0,0 +1,93 @@ +package com.digitalperson.face + 
+import android.content.ContentValues +import android.content.Context +import android.database.sqlite.SQLiteDatabase +import android.database.sqlite.SQLiteOpenHelper +import android.util.Log +import com.digitalperson.config.AppConfig +import java.nio.ByteBuffer +import java.nio.ByteOrder + +data class FaceProfile( + val id: Long, + val name: String, + val embedding: FloatArray, +) + +class FaceFeatureStore(context: Context) : SQLiteOpenHelper(context, DB_NAME, null, DB_VERSION) { + override fun onCreate(db: SQLiteDatabase) { + db.execSQL( + """ + CREATE TABLE IF NOT EXISTS face_profiles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + embedding BLOB NOT NULL, + updated_at INTEGER NOT NULL + ) + """.trimIndent() + ) + } + + override fun onUpgrade(db: SQLiteDatabase, oldVersion: Int, newVersion: Int) { + db.execSQL("DROP TABLE IF EXISTS face_profiles") + onCreate(db) + } + + fun loadAllProfiles(): List { + val db = readableDatabase + val list = ArrayList() + db.rawQuery("SELECT id, name, embedding FROM face_profiles", null).use { c -> + val idIdx = c.getColumnIndexOrThrow("id") + val nameIdx = c.getColumnIndexOrThrow("name") + val embIdx = c.getColumnIndexOrThrow("embedding") + while (c.moveToNext()) { + val embBlob = c.getBlob(embIdx) ?: continue + list.add( + FaceProfile( + id = c.getLong(idIdx), + name = c.getString(nameIdx), + embedding = blobToFloatArray(embBlob), + ) + ) + } + } + return list + } + + fun upsertProfile(name: String, embedding: FloatArray) { + // 确保名字不为null,使用空字符串作为默认值 + val safeName = name.takeIf { it.isNotBlank() } ?: "" + val values = ContentValues().apply { + put("name", safeName) + put("embedding", floatArrayToBlob(embedding)) + put("updated_at", System.currentTimeMillis()) + } + val rowId = writableDatabase.insertWithOnConflict( + "face_profiles", + null, + values, + SQLiteDatabase.CONFLICT_REPLACE + ) + Log.i(AppConfig.TAG, "[FaceFeatureStore] upsertProfile name='$safeName' rowId=$rowId dim=${embedding.size}") + } + 
+ private fun floatArrayToBlob(values: FloatArray): ByteArray { + val buf = ByteBuffer.allocate(values.size * 4).order(ByteOrder.LITTLE_ENDIAN) + for (v in values) buf.putFloat(v) + return buf.array() + } + + private fun blobToFloatArray(blob: ByteArray): FloatArray { + if (blob.isEmpty()) return FloatArray(0) + val buf = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN) + val out = FloatArray(blob.size / 4) + for (i in out.indices) out[i] = buf.getFloat() + return out + } + + companion object { + private const val DB_NAME = "face_feature.db" + private const val DB_VERSION = 1 + } +} diff --git a/app/src/main/java/com/digitalperson/face/FaceOverlayView.kt b/app/src/main/java/com/digitalperson/face/FaceOverlayView.kt new file mode 100644 index 0000000..a7f17be --- /dev/null +++ b/app/src/main/java/com/digitalperson/face/FaceOverlayView.kt @@ -0,0 +1,61 @@ +package com.digitalperson.face + +import android.content.Context +import android.graphics.Canvas +import android.graphics.Color +import android.graphics.Paint +import android.graphics.RectF +import android.util.AttributeSet +import android.view.View + +class FaceOverlayView @JvmOverloads constructor( + context: Context, + attrs: AttributeSet? = null, +) : View(context, attrs) { + + private val boxPaint = Paint(Paint.ANTI_ALIAS_FLAG).apply { + color = Color.GREEN + style = Paint.Style.STROKE + strokeWidth = 4f + } + + private val textPaint = Paint(Paint.ANTI_ALIAS_FLAG).apply { + color = Color.GREEN + textSize = 28f + } + + @Volatile + private var latestResult: FaceDetectionResult? 
= null + + fun updateResult(result: FaceDetectionResult) { + latestResult = result + postInvalidateOnAnimation() + } + + override fun onDraw(canvas: Canvas) { + super.onDraw(canvas) + val result = latestResult ?: return + if (result.sourceWidth <= 0 || result.sourceHeight <= 0) return + + val srcW = result.sourceWidth.toFloat() + val srcH = result.sourceHeight.toFloat() + val viewW = width.toFloat() + val viewH = height.toFloat() + if (viewW <= 0f || viewH <= 0f) return + + val scale = minOf(viewW / srcW, viewH / srcH) + val dx = (viewW - srcW * scale) / 2f + val dy = (viewH - srcH * scale) / 2f + + for (face in result.faces) { + val rect = RectF( + dx + face.left * scale, + dy + face.top * scale, + dx + face.right * scale, + dy + face.bottom * scale, + ) + canvas.drawRect(rect, boxPaint) + canvas.drawText(String.format("%.2f", face.score), rect.left, rect.top - 8f, textPaint) + } + } +} diff --git a/app/src/main/java/com/digitalperson/face/FaceRecognizer.kt b/app/src/main/java/com/digitalperson/face/FaceRecognizer.kt new file mode 100644 index 0000000..9ccbd91 --- /dev/null +++ b/app/src/main/java/com/digitalperson/face/FaceRecognizer.kt @@ -0,0 +1,129 @@ +package com.digitalperson.face + +import android.content.Context +import android.graphics.Bitmap +import android.util.Log +import com.digitalperson.config.AppConfig +import com.digitalperson.engine.ArcFaceEngineRKNN +import kotlin.math.sqrt + +data class FaceRecognitionResult( + val matchedName: String?, + val similarity: Float, + val embeddingDim: Int, +) + +class FaceRecognizer(context: Context) { + private val appContext = context.applicationContext + private val engine = ArcFaceEngineRKNN() + private val store = FaceFeatureStore(appContext) + private val cache = ArrayList() + + @Volatile + private var initialized = false + + fun initialize(): Boolean { + Log.d(AppConfig.TAG, "[FaceRecognizer] initialize: starting...") + val ok = engine.initialize(appContext) + Log.d(AppConfig.TAG, "[FaceRecognizer] 
initialize: engine.initialize() returned $ok") + if (!ok) { + initialized = false + Log.e(AppConfig.TAG, "[FaceRecognizer] initialize: failed - engine initialization failed") + return false + } + cache.clear() + val profiles = store.loadAllProfiles() + cache.addAll(profiles) + initialized = true + Log.i(AppConfig.TAG, "[FaceRecognizer] initialized, profiles=${cache.size}") + return true + } + + fun identify(bitmap: Bitmap, face: FaceBox): FaceRecognitionResult { + if (!initialized) return FaceRecognitionResult(null, 0f, 0) + val embedding = extractEmbedding(bitmap, face) + if (embedding.isEmpty()) return FaceRecognitionResult(null, 0f, 0) + + var bestName: String? = null + var bestScore = -1f + for (p in cache) { + if (p.embedding.size != embedding.size) continue + val score = cosineSimilarity(embedding, p.embedding) + if (score > bestScore) { + bestScore = score + bestName = p.name + } + } + if (bestScore >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) { + return FaceRecognitionResult(bestName, bestScore, embedding.size) + } + return FaceRecognitionResult(null, bestScore, embedding.size) + } + + fun extractEmbedding(bitmap: Bitmap, face: FaceBox): FloatArray { + if (!initialized) return FloatArray(0) + return engine.extractEmbedding(bitmap, face.left, face.top, face.right, face.bottom) + } + + fun addOrUpdateProfile(name: String?, embedding: FloatArray) { + val normalized = normalize(embedding) + store.upsertProfile(name ?: "", normalized) + // 移除旧的记录(如果存在) + if (name != null) { + cache.removeAll { it.name == name } + } + cache.add(FaceProfile(id = -1L, name = name ?: "", embedding = normalized)) + } + + fun addNewFace(embedding: FloatArray): Boolean { + Log.d(AppConfig.TAG, "[FaceRecognizer] addNewFace: embedding size=${embedding.size}, cache size=${cache.size}") + + // 检查是否已经存在相似的人脸 + for (p in cache) { + if (p.embedding.size != embedding.size) { + Log.d(AppConfig.TAG, "[FaceRecognizer] Skipping profile with different embedding size: ${p.embedding.size}") + 
continue + } + val score = cosineSimilarity(embedding, p.embedding) + Log.d(AppConfig.TAG, "[FaceRecognizer] Comparing with profile '${p.name}': similarity=$score, threshold=${AppConfig.FaceRecognition.SIMILARITY_THRESHOLD}") + if (score >= AppConfig.FaceRecognition.SIMILARITY_THRESHOLD) { + // 已经存在相似的人脸,不需要添加 + Log.i(AppConfig.TAG, "[FaceRecognizer] Similar face found: ${p.name} with similarity=$score, not adding new face") + return false + } + } + + // 添加新人脸,名字为null + Log.i(AppConfig.TAG, "[FaceRecognizer] No similar face found, adding new face") + addOrUpdateProfile(null, embedding) + return true + } + + fun release() { + initialized = false + engine.release() + store.close() + } + + private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { + var dot = 0f + var na = 0f + var nb = 0f + for (i in a.indices) { + dot += a[i] * b[i] + na += a[i] * a[i] + nb += b[i] * b[i] + } + if (na <= 1e-12f || nb <= 1e-12f) return -1f + return (dot / (sqrt(na) * sqrt(nb))).coerceIn(-1f, 1f) + } + + private fun normalize(v: FloatArray): FloatArray { + var sum = 0f + for (x in v) sum += x * x + val norm = sqrt(sum.coerceAtLeast(1e-12f)) + val out = FloatArray(v.size) + for (i in v.indices) out[i] = v[i] / norm + return out + } +} diff --git a/app/src/main/java/com/digitalperson/face/ImageProxyBitmapConverter.kt b/app/src/main/java/com/digitalperson/face/ImageProxyBitmapConverter.kt new file mode 100644 index 0000000..b5dc16d --- /dev/null +++ b/app/src/main/java/com/digitalperson/face/ImageProxyBitmapConverter.kt @@ -0,0 +1,87 @@ +package com.digitalperson.face + +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.graphics.ImageFormat +import android.graphics.Matrix +import android.graphics.Rect +import android.graphics.YuvImage +import androidx.camera.core.ImageProxy +import java.io.ByteArrayOutputStream + +object ImageProxyBitmapConverter { + fun toBitmap(image: ImageProxy): Bitmap? 
{ + val nv21 = yuv420ToNv21(image) ?: return null + val yuvImage = YuvImage(nv21, ImageFormat.NV21, image.width, image.height, null) + val out = ByteArrayOutputStream() + if (!yuvImage.compressToJpeg(Rect(0, 0, image.width, image.height), 80, out)) { + return null + } + val bytes = out.toByteArray() + var bitmap = BitmapFactory.decodeByteArray(bytes, 0, bytes.size) ?: return null + if (bitmap.config != Bitmap.Config.ARGB_8888) { + val converted = bitmap.copy(Bitmap.Config.ARGB_8888, false) + bitmap.recycle() + bitmap = converted + } + + val matrix = Matrix() + + // 前置摄像头需要水平翻转 + // 注意:这里假设我们使用的是前置摄像头 + // 如果需要支持后置摄像头,需要根据实际使用的摄像头类型来决定是否翻转 + matrix.postScale(-1f, 1f, bitmap.width / 2f, bitmap.height / 2f) + + // 处理旋转 + val rotation = image.imageInfo.rotationDegrees + if (rotation != 0) { + matrix.postRotate(rotation.toFloat()) + } + + // 应用变换 + val transformed = Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true) + bitmap.recycle() + bitmap = transformed + + return bitmap + } + + private fun yuv420ToNv21(image: ImageProxy): ByteArray? 
{ + val planes = image.planes + if (planes.size < 3) return null + val width = image.width + val height = image.height + val ySize = width * height + val uvSize = width * height / 4 + val nv21 = ByteArray(ySize + uvSize * 2) + + val yPlane = planes[0] + val yBuffer = yPlane.buffer + val yRowStride = yPlane.rowStride + var dst = 0 + for (row in 0 until height) { + yBuffer.position(row * yRowStride) + yBuffer.get(nv21, dst, width) + dst += width + } + + val uPlane = planes[1] + val vPlane = planes[2] + val uBuffer = uPlane.buffer + val vBuffer = vPlane.buffer + val uRowStride = uPlane.rowStride + val vRowStride = vPlane.rowStride + val uPixelStride = uPlane.pixelStride + val vPixelStride = vPlane.pixelStride + + for (row in 0 until height / 2) { + for (col in 0 until width / 2) { + val uIndex = row * uRowStride + col * uPixelStride + val vIndex = row * vRowStride + col * vPixelStride + nv21[dst++] = vBuffer.get(vIndex) + nv21[dst++] = uBuffer.get(uIndex) + } + } + return nv21 + } +} diff --git a/app/src/main/java/com/digitalperson/live2d/Live2DAvatarManager.kt b/app/src/main/java/com/digitalperson/live2d/Live2DAvatarManager.kt index 23a8264..924b39f 100644 --- a/app/src/main/java/com/digitalperson/live2d/Live2DAvatarManager.kt +++ b/app/src/main/java/com/digitalperson/live2d/Live2DAvatarManager.kt @@ -7,6 +7,7 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) { init { glSurfaceView.setEGLContextClientVersion(2) + glSurfaceView.setPreserveEGLContextOnPause(true) glSurfaceView.setRenderer(renderer) glSurfaceView.renderMode = GLSurfaceView.RENDERMODE_CONTINUOUSLY } @@ -16,11 +17,15 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) { } fun setMood(mood: String) { - renderer.setMood(mood) + glSurfaceView.queueEvent { + renderer.setMood(mood) + } } fun startSpecificMotion(motionName: String) { - renderer.startSpecificMotion(motionName) + glSurfaceView.queueEvent { + renderer.startSpecificMotion(motionName) + } } fun onResume() { @@ 
-32,6 +37,8 @@ class Live2DAvatarManager(private val glSurfaceView: GLSurfaceView) { } fun release() { - renderer.release() + glSurfaceView.queueEvent { + renderer.release() + } } } \ No newline at end of file diff --git a/app/src/main/java/com/digitalperson/live2d/Live2DCharacter.kt b/app/src/main/java/com/digitalperson/live2d/Live2DCharacter.kt index 9585625..300b0c0 100644 --- a/app/src/main/java/com/digitalperson/live2d/Live2DCharacter.kt +++ b/app/src/main/java/com/digitalperson/live2d/Live2DCharacter.kt @@ -214,32 +214,8 @@ class Live2DCharacter : CubismUserModel() { } private fun loadMoodMotions(assets: AssetManager, modelDir: String) { - // 开心心情动作 - moodMotions["开心"] = listOf( - "haru_g_m22.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m22.motion3.json"), - "haru_g_m21.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m21.motion3.json"), - "haru_g_m18.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m18.motion3.json") - ).mapNotNull { (fileName, motion) -> - motion?.let { - motionFileMap[it] = fileName - it - } - } - - // 伤心心情动作 - moodMotions["伤心"] = listOf( - "haru_g_m25.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m25.motion3.json"), - "haru_g_m24.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m24.motion3.json"), - "haru_g_m05.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m05.motion3.json") - ).mapNotNull { (fileName, motion) -> - motion?.let { - motionFileMap[it] = fileName - it - } - } - - // 平和心情动作 - moodMotions["平和"] = listOf( + // 中性心情动作 + moodMotions["中性"] = listOf( "haru_g_m15.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m15.motion3.json"), "haru_g_m07.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m07.motion3.json"), "haru_g_m06.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m06.motion3.json"), @@ -252,8 +228,50 @@ class Live2DCharacter : CubismUserModel() { } } - // 惊讶心情动作 - moodMotions["惊讶"] = listOf( + // 悲伤心情动作 + 
moodMotions["悲伤"] = listOf( + "haru_g_m25.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m25.motion3.json"), + "haru_g_m24.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m24.motion3.json"), + "haru_g_m05.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m05.motion3.json"), + "haru_g_m16.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m16.motion3.json"), + "haru_g_m20.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m20.motion3.json"), + ).mapNotNull { (fileName, motion) -> + motion?.let { + motionFileMap[it] = fileName + it + } + } + + // 高兴心情动作 + moodMotions["高兴"] = listOf( + "haru_g_m22.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m22.motion3.json"), + "haru_g_m21.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m21.motion3.json"), + "haru_g_m18.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m18.motion3.json"), + "haru_g_m09.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m09.motion3.json"), + "haru_g_m08.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m08.motion3.json") + ).mapNotNull { (fileName, motion) -> + motion?.let { + motionFileMap[it] = fileName + it + } + } + + // 生气心情动作 + moodMotions["生气"] = listOf( + "haru_g_m10.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m10.motion3.json"), + "haru_g_m11.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m11.motion3.json"), + "haru_g_m04.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m04.motion3.json"), + "haru_g_m03.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m03.motion3.json"), + + ).mapNotNull { (fileName, motion) -> + motion?.let { + motionFileMap[it] = fileName + it + } + } + + // 恐惧心情动作 + moodMotions["恐惧"] = listOf( "haru_g_m26.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m26.motion3.json"), "haru_g_m12.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m12.motion3.json") ).mapNotNull { (fileName, motion) -> @@ 
-263,18 +281,8 @@ class Live2DCharacter : CubismUserModel() { } } - // 关心心情动作 - moodMotions["关心"] = listOf( - "haru_g_m17.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m17.motion3.json") - ).mapNotNull { (fileName, motion) -> - motion?.let { - motionFileMap[it] = fileName - it - } - } - - // 害羞心情动作 - moodMotions["害羞"] = listOf( + // 撒娇心情动作 + moodMotions["撒娇"] = listOf( "haru_g_m19.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m19.motion3.json") ).mapNotNull { (fileName, motion) -> motion?.let { @@ -282,6 +290,38 @@ class Live2DCharacter : CubismUserModel() { it } } + + // 震惊心情动作 + moodMotions["震惊"] = listOf( + "haru_g_m26.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m26.motion3.json"), + "haru_g_m12.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m12.motion3.json") + ).mapNotNull { (fileName, motion) -> + motion?.let { + motionFileMap[it] = fileName + it + } + } + + // 厌恶心情动作 + moodMotions["厌恶"] = listOf( + "haru_g_m14.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m14.motion3.json"), + "haru_g_m13.motion3.json" to loadMotionByName(assets, modelDir, "haru_g_m13.motion3.json") + ).mapNotNull { (fileName, motion) -> + motion?.let { + motionFileMap[it] = fileName + it + } + } + + + + // 兼容旧的心情名称 + moodMotions["开心"] = moodMotions["高兴"] ?: emptyList() + moodMotions["伤心"] = moodMotions["悲伤"] ?: emptyList() + moodMotions["平和"] = moodMotions["平静"] ?: emptyList() + moodMotions["惊讶"] = moodMotions["震惊"] ?: emptyList() + moodMotions["关心"] = moodMotions["中性"] ?: emptyList() + moodMotions["害羞"] = moodMotions["撒娇"] ?: emptyList() } private fun loadSpecificMotions(assets: AssetManager, modelDir: String) { diff --git a/app/src/main/java/com/digitalperson/live2d/Live2DRenderer.kt b/app/src/main/java/com/digitalperson/live2d/Live2DRenderer.kt index e665318..2a24e45 100644 --- a/app/src/main/java/com/digitalperson/live2d/Live2DRenderer.kt +++ b/app/src/main/java/com/digitalperson/live2d/Live2DRenderer.kt @@ -14,6 
+14,9 @@ import javax.microedition.khronos.opengles.GL10 class Live2DRenderer( private val context: Context ) : GLSurfaceView.Renderer { + companion object { + private const val TAG = "Live2DRenderer" + } @Volatile private var speaking = false @@ -25,6 +28,7 @@ class Live2DRenderer( GLES20.glClearColor(0f, 0f, 0f, 0f) ensureFrameworkInitialized() startTimeMs = SystemClock.elapsedRealtime() + Log.i(TAG, "onSurfaceCreated") runCatching { val model = Live2DCharacter() @@ -35,6 +39,7 @@ class Live2DRenderer( ) model.bindTextures(context.assets, AppConfig.Avatar.MODEL_DIR) character = model + Log.i(TAG, "Live2D model loaded and textures bound") }.onFailure { Log.e(AppConfig.TAG, "Load Live2D model failed: ${it.message}", it) character = null diff --git a/app/src/main/java/com/digitalperson/llm/LLMManager.kt b/app/src/main/java/com/digitalperson/llm/LLMManager.kt new file mode 100644 index 0000000..c923332 --- /dev/null +++ b/app/src/main/java/com/digitalperson/llm/LLMManager.kt @@ -0,0 +1,46 @@ +package com.digitalperson.llm + +interface LLMManagerCallback { + fun onThinking(msg: String, finished: Boolean) + fun onResult(msg: String, finished: Boolean) +} + +class LLMManager(modelPath: String, callback: LLMManagerCallback) : + RKLLM(modelPath, object : LLMCallback { + var inThinking = false + + override fun onCallback(data: String, state: LLMCallback.State) { + if (state == LLMCallback.State.NORMAL) { + if (data == "") { + inThinking = true + return + } else if (data == "") { + inThinking = false + callback.onThinking("", true) + return + } + + if (inThinking) { + callback.onThinking(data, false) + } else { + if (data == "\n") return + callback.onResult(data, false) + } + } else { + callback.onThinking("", true) + callback.onResult("", true) + } + } + }) + +{ + fun generateResponse(prompt: String) { + val msg = "<|User|>$prompt<|Assistant|>" + say(msg) + } + + fun generateResponseWithSystem(systemPrompt: String, userPrompt: String) { + val msg = 
"<|System|>$systemPrompt<|User|>$userPrompt<|Assistant|>" + say(msg) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/digitalperson/llm/RKLLM.kt b/app/src/main/java/com/digitalperson/llm/RKLLM.kt new file mode 100644 index 0000000..c109f5e --- /dev/null +++ b/app/src/main/java/com/digitalperson/llm/RKLLM.kt @@ -0,0 +1,52 @@ +package com.digitalperson.llm + +interface LLMCallback { + enum class State { + ERROR, NORMAL, FINISH + } + fun onCallback(data: String, state: State) +} + +open class RKLLM(modelPath: String, callback: LLMCallback) { + companion object { + init { + System.loadLibrary("rkllmrt") + } + } + + private var mInstance: Long + private var mCallback: LLMCallback + + init { + mInstance = initLLM(modelPath) + mCallback = callback + if (mInstance == 0L) { + throw IllegalStateException("RKLLM init failed: native handle is null") + } + } + + fun destroy() { + deinitLLM(mInstance) + mInstance = 0 + } + + protected fun say(text: String) { + if (mInstance == 0L) { + mCallback.onCallback("RKLLM is not initialized", LLMCallback.State.ERROR) + return + } + infer(mInstance, text) + } + + fun callbackFromNative(data: String, state: Int) { + var s = LLMCallback.State.ERROR + s = if (state == 0) LLMCallback.State.FINISH + else if (state < 0) LLMCallback.State.ERROR + else LLMCallback.State.NORMAL + mCallback.onCallback(data, s) + } + + private external fun initLLM(modelPath: String): Long + private external fun deinitLLM(handle: Long) + private external fun infer(handle: Long, text: String) +} \ No newline at end of file diff --git a/app/src/main/java/com/digitalperson/tts/QCloudTtsManager.kt b/app/src/main/java/com/digitalperson/tts/QCloudTtsManager.kt new file mode 100644 index 0000000..39425eb --- /dev/null +++ b/app/src/main/java/com/digitalperson/tts/QCloudTtsManager.kt @@ -0,0 +1,330 @@ +package com.digitalperson.tts + +import android.content.Context +import android.media.AudioAttributes +import android.media.AudioFormat +import 
android.media.AudioManager +import android.media.AudioTrack +import android.util.Log +import com.digitalperson.config.AppConfig +import com.digitalperson.mood.MoodManager +import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizer +import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizerListener +import com.tencent.cloud.realtime.tts.RealTimeSpeechSynthesizerRequest +import com.tencent.cloud.realtime.tts.SpeechSynthesizerResponse +import com.tencent.cloud.realtime.tts.core.ws.Credential +import com.tencent.cloud.realtime.tts.core.ws.SpeechClient +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import java.nio.ByteBuffer +import java.util.UUID +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.AtomicBoolean + +class QCloudTtsManager(private val context: Context) { + + companion object { + private const val TAG = "QCloudTtsManager" + private const val SAMPLE_RATE = 16000 + private val proxy = SpeechClient() + } + private var audioTrack: AudioTrack? = null + private var synthesizer: RealTimeSpeechSynthesizer? = null + + private sealed class TtsQueueItem { + data class Segment(val text: String) : TtsQueueItem() + data object End : TtsQueueItem() + } + + private val ttsQueue = LinkedBlockingQueue() + private val ttsStopped = AtomicBoolean(false) + private val ttsWorkerRunning = AtomicBoolean(false) + private val ttsPlaying = AtomicBoolean(false) + private val interrupting = AtomicBoolean(false) + + private val ioScope = CoroutineScope(Dispatchers.IO) + + interface TtsCallback { + fun onTtsStarted(text: String) + fun onTtsCompleted() + fun onTtsSegmentCompleted(durationMs: Long) + fun isTtsStopped(): Boolean + fun onClearAsrQueue() + fun onSetSpeaking(speaking: Boolean) + fun onEndTurn() + } + + private var callback: TtsCallback? 
= null + + fun setCallback(callback: TtsCallback) { + this.callback = callback + } + + fun init(): Boolean { + return try { + initAudioTrack() + true + } catch (e: Exception) { + Log.e(TAG, "Init QCloud TTS failed: ${e.message}", e) + false + } + } + + private fun initAudioTrack() { + val bufferSize = AudioTrack.getMinBufferSize( + SAMPLE_RATE, + AudioFormat.CHANNEL_OUT_MONO, + AudioFormat.ENCODING_PCM_16BIT + ) + val attr = AudioAttributes.Builder() + .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH) + .setUsage(AudioAttributes.USAGE_MEDIA) + .build() + val format = AudioFormat.Builder() + .setEncoding(AudioFormat.ENCODING_PCM_16BIT) + .setChannelMask(AudioFormat.CHANNEL_OUT_MONO) + .setSampleRate(SAMPLE_RATE) + .build() + audioTrack = AudioTrack( + attr, + format, + bufferSize, + AudioTrack.MODE_STREAM, + AudioManager.AUDIO_SESSION_ID_GENERATE + ) + } + + fun enqueueSegment(seg: String) { + if (ttsStopped.get()) { + ttsStopped.set(false) + } + val cleanedSeg = seg.trimEnd('.', '。', '!', '!', '?', '?', ',', ',', ';', ';', ':', ':') + ttsQueue.offer(TtsQueueItem.Segment(cleanedSeg)) + ensureTtsWorker() + } + + fun enqueueEnd() { + ttsQueue.offer(TtsQueueItem.End) + } + + fun isPlaying(): Boolean = ttsPlaying.get() + + fun reset() { + val workerRunning = ttsWorkerRunning.get() + val wasStopped = ttsStopped.get() + + ttsStopped.set(false) + ttsPlaying.set(false) + ttsQueue.clear() + + if (wasStopped && workerRunning) { + ttsQueue.offer(TtsQueueItem.End) + } + } + + fun stop() { + ttsStopped.set(true) + ttsPlaying.set(false) + ttsQueue.clear() + ttsQueue.offer(TtsQueueItem.End) + + try { + synthesizer?.cancel() + synthesizer = null + audioTrack?.pause() + audioTrack?.flush() + } catch (_: Throwable) { + } + } + + fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean { + if (!interrupting.compareAndSet(false, true)) return false + try { + val hadPendingPlayback = ttsPlaying.get() || ttsWorkerRunning.get() || ttsQueue.isNotEmpty() + if (!hadPendingPlayback) { + 
ttsStopped.set(false) + ttsPlaying.set(false) + return false + } + + ttsStopped.set(true) + ttsPlaying.set(false) + ttsQueue.clear() + ttsQueue.offer(TtsQueueItem.End) + + try { + synthesizer?.cancel() + synthesizer = null + audioTrack?.pause() + audioTrack?.flush() + } catch (_: Throwable) { + } + + val deadline = System.currentTimeMillis() + waitTimeoutMs + while (ttsWorkerRunning.get() && System.currentTimeMillis() < deadline) { + Thread.sleep(10) + } + + if (ttsWorkerRunning.get()) { + Log.w(TAG, "interruptForNewTurn timeout: worker still running") + } + + ttsQueue.clear() + ttsStopped.set(false) + ttsPlaying.set(false) + callback?.onSetSpeaking(false) + return true + } finally { + interrupting.set(false) + } + } + + fun release() { + try { + synthesizer?.cancel() + synthesizer = null + } catch (_: Throwable) { + } + try { + audioTrack?.release() + audioTrack = null + } catch (_: Throwable) { + } + } + + private fun ensureTtsWorker() { + if (!ttsWorkerRunning.compareAndSet(false, true)) return + ioScope.launch { + try { + runTtsWorker() + } finally { + ttsWorkerRunning.set(false) + if (!ttsStopped.get() && ttsQueue.isNotEmpty()) { + ensureTtsWorker() + } + } + } + } + + private fun runTtsWorker() { + val audioTrack = audioTrack ?: return + + while (true) { + val item = ttsQueue.take() + if (ttsStopped.get()) break + + when (item) { + is TtsQueueItem.Segment -> { + ttsPlaying.set(true) + callback?.onSetSpeaking(true) + Log.d(TAG, "QCloud TTS started: processing segment '${item.text}'") + callback?.onTtsStarted(item.text) + + val startMs = System.currentTimeMillis() + + try { + if (audioTrack.playState != AudioTrack.PLAYSTATE_PLAYING) { + audioTrack.play() + } + + val credential = Credential( + AppConfig.QCloud.APP_ID, + AppConfig.QCloud.SECRET_ID, + AppConfig.QCloud.SECRET_KEY, + "" + ) + + val request = RealTimeSpeechSynthesizerRequest() + request.setVolume(0f) // 音量大小,范围[-10,10] + request.setSpeed(0f) // 语速,范围:[-2,6] + request.setCodec("pcm") // 返回音频格式:pcm + 
request.setSampleRate(SAMPLE_RATE) // 音频采样率 + request.setVoiceType(601010) // 音色ID + request.setEnableSubtitle(true) // 是否开启时间戳功能 + // 根据当前心情设置情感类别 + val currentMood = MoodManager.getCurrentMood() + val emotionCategory = when (currentMood) { + "中性" -> "neutral" + "悲伤" -> "sad" + "高兴" -> "happy" + "生气" -> "angry" + "恐惧" -> "fear" + "撒娇" -> "sajiao" + "震惊" -> "amaze" + "厌恶" -> "disgusted" + "平静" -> "peaceful" + // 兼容旧的心情名称 + "开心" -> "happy" + "伤心" -> "sad" + "平和" -> "peaceful" + "惊讶" -> "amaze" + "关心" -> "neutral" + "害羞" -> "sajiao" + else -> "neutral" + } + request.setEmotionCategory(emotionCategory) // 控制合成音频的情感 + request.setEmotionIntensity(100) // 控制合成音频情感程度 + request.setSessionId(UUID.randomUUID().toString()) // sessionId + request.setText(item.text) // 合成文本 + + val listener = object : RealTimeSpeechSynthesizerListener() { + override fun onSynthesisStart(response: SpeechSynthesizerResponse) { + Log.d(TAG, "onSynthesisStart: ${response.sessionId}") + } + + override fun onSynthesisEnd(response: SpeechSynthesizerResponse) { + Log.d(TAG, "onSynthesisEnd: ${response.sessionId}") + val ttsMs = System.currentTimeMillis() - startMs + callback?.onTtsSegmentCompleted(ttsMs) + } + + override fun onAudioResult(buffer: ByteBuffer) { + val data = ByteArray(buffer.remaining()) + buffer.get(data) + // 播放pcm + audioTrack.write(data, 0, data.size) + } + + override fun onTextResult(response: SpeechSynthesizerResponse) { + Log.d(TAG, "onTextResult: ${response.sessionId}") + } + + override fun onSynthesisCancel() { + Log.d(TAG, "onSynthesisCancel") + } + + override fun onSynthesisFail(response: SpeechSynthesizerResponse) { + Log.e(TAG, "onSynthesisFail: ${response.sessionId}, error: ${response.message}") + } + } + + synthesizer = RealTimeSpeechSynthesizer(proxy, credential, request, listener) + synthesizer?.start() + + } catch (e: Exception) { + Log.e(TAG, "QCloud TTS error: ${e.message}", e) + } + } + + TtsQueueItem.End -> { + callback?.onClearAsrQueue() + + 
waitForPlaybackComplete(audioTrack) + + callback?.onTtsCompleted() + + ttsPlaying.set(false) + callback?.onSetSpeaking(false) + callback?.onEndTurn() + break + } + } + } + } + + private fun waitForPlaybackComplete(audioTrack: AudioTrack) { + // 等待音频播放完成 + Thread.sleep(1000) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/digitalperson/tts/TtsController.kt b/app/src/main/java/com/digitalperson/tts/TtsController.kt new file mode 100644 index 0000000..ce27da9 --- /dev/null +++ b/app/src/main/java/com/digitalperson/tts/TtsController.kt @@ -0,0 +1,181 @@ +package com.digitalperson.tts + +import android.content.Context +import android.util.Log + +class TtsController(private val context: Context) { + + companion object { + private const val TAG = "TtsController" + } + + private var localTts: TtsManager? = null + private var qcloudTts: QCloudTtsManager? = null + private var useQCloudTts = false + + interface TtsCallback { + fun onTtsStarted(text: String) + fun onTtsCompleted() + fun onTtsSegmentCompleted(durationMs: Long) + fun isTtsStopped(): Boolean + fun onClearAsrQueue() + fun onSetSpeaking(speaking: Boolean) + fun onEndTurn() + } + + private var callback: TtsCallback? 
= null + + fun setCallback(callback: TtsCallback) { + this.callback = callback + localTts?.setCallback(object : TtsManager.TtsCallback { + override fun onTtsStarted(text: String) { + callback.onTtsStarted(text) + } + + override fun onTtsCompleted() { + callback.onTtsCompleted() + } + + override fun onTtsSegmentCompleted(durationMs: Long) { + callback.onTtsSegmentCompleted(durationMs) + } + + override fun isTtsStopped(): Boolean { + return callback.isTtsStopped() + } + + override fun onClearAsrQueue() { + callback.onClearAsrQueue() + } + + override fun onSetSpeaking(speaking: Boolean) { + callback.onSetSpeaking(speaking) + } + + override fun getCurrentTrace() = null + + override fun onTraceMarkTtsRequestEnqueued() { + } + + override fun onTraceMarkTtsSynthesisStart() { + } + + override fun onTraceMarkTtsFirstPcmReady() { + } + + override fun onTraceMarkTtsFirstAudioPlay() { + } + + override fun onTraceMarkTtsDone() { + } + + override fun onTraceAddDuration(name: String, value: Long) { + } + + override fun onEndTurn() { + callback.onEndTurn() + } + }) + qcloudTts?.setCallback(object : QCloudTtsManager.TtsCallback { + override fun onTtsStarted(text: String) { + callback.onTtsStarted(text) + } + + override fun onTtsCompleted() { + callback.onTtsCompleted() + } + + override fun onTtsSegmentCompleted(durationMs: Long) { + callback.onTtsSegmentCompleted(durationMs) + } + + override fun isTtsStopped(): Boolean { + return callback.isTtsStopped() + } + + override fun onClearAsrQueue() { + callback.onClearAsrQueue() + } + + override fun onSetSpeaking(speaking: Boolean) { + callback.onSetSpeaking(speaking) + } + + override fun onEndTurn() { + callback.onEndTurn() + } + }) + } + + fun init(): Boolean { + // 初始化本地TTS + localTts = TtsManager(context) + val localInit = localTts?.initTtsAndAudioTrack() ?: false + Log.d(TAG, "Local TTS init: $localInit") + + // 初始化腾讯云TTS + qcloudTts = QCloudTtsManager(context) + val qcloudInit = qcloudTts?.init() ?: false + Log.d(TAG, "QCloud TTS 
init: $qcloudInit") + + return localInit || qcloudInit + } + + fun setUseQCloudTts(useQCloud: Boolean) { + this.useQCloudTts = useQCloud + Log.d(TAG, "TTS mode changed: ${if (useQCloud) "QCloud" else "Local"}") + } + + fun enqueueSegment(seg: String) { + if (useQCloudTts) { + qcloudTts?.enqueueSegment(seg) + } else { + localTts?.enqueueSegment(seg) + } + } + + fun enqueueEnd() { + if (useQCloudTts) { + qcloudTts?.enqueueEnd() + } else { + localTts?.enqueueEnd() + } + } + + fun isPlaying(): Boolean { + return if (useQCloudTts) { + qcloudTts?.isPlaying() ?: false + } else { + localTts?.isPlaying() ?: false + } + } + + fun reset() { + if (useQCloudTts) { + qcloudTts?.reset() + } else { + localTts?.reset() + } + } + + fun stop() { + if (useQCloudTts) { + qcloudTts?.stop() + } else { + localTts?.stop() + } + } + + fun interruptForNewTurn(waitTimeoutMs: Long = 300): Boolean { + return if (useQCloudTts) { + qcloudTts?.interruptForNewTurn(waitTimeoutMs) ?: false + } else { + localTts?.interruptForNewTurn(waitTimeoutMs) ?: false + } + } + + fun release() { + localTts?.release() + qcloudTts?.release() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/digitalperson/ui/Live2DUiManager.kt b/app/src/main/java/com/digitalperson/ui/Live2DUiManager.kt index 9c38944..f390391 100644 --- a/app/src/main/java/com/digitalperson/ui/Live2DUiManager.kt +++ b/app/src/main/java/com/digitalperson/ui/Live2DUiManager.kt @@ -1,12 +1,14 @@ package com.digitalperson.ui import android.app.Activity +import android.app.ProgressDialog import android.opengl.GLSurfaceView import android.text.method.ScrollingMovementMethod import android.view.MotionEvent import android.widget.Button import android.widget.LinearLayout import android.widget.ScrollView +import android.widget.Switch import android.widget.TextView import android.widget.Toast import com.digitalperson.live2d.Live2DAvatarManager @@ -18,7 +20,10 @@ class Live2DUiManager(private val activity: Activity) { private var stopButton: 
Button? = null private var recordButton: Button? = null private var traditionalButtons: LinearLayout? = null + private var llmModeSwitch: Switch? = null + private var llmModeSwitchRow: LinearLayout? = null private var avatarManager: Live2DAvatarManager? = null + private var downloadProgressDialog: ProgressDialog? = null private var lastUiText: String = "" @@ -29,6 +34,8 @@ class Live2DUiManager(private val activity: Activity) { stopButtonId: Int = -1, recordButtonId: Int = -1, traditionalButtonsId: Int = -1, + llmModeSwitchId: Int = -1, + llmModeSwitchRowId: Int = -1, silentPlayerViewId: Int, speakingPlayerViewId: Int, live2dViewId: Int @@ -39,12 +46,17 @@ class Live2DUiManager(private val activity: Activity) { if (stopButtonId != -1) stopButton = activity.findViewById(stopButtonId) if (recordButtonId != -1) recordButton = activity.findViewById(recordButtonId) if (traditionalButtonsId != -1) traditionalButtons = activity.findViewById(traditionalButtonsId) + if (llmModeSwitchId != -1) llmModeSwitch = activity.findViewById(llmModeSwitchId) + if (llmModeSwitchRowId != -1) llmModeSwitchRow = activity.findViewById(llmModeSwitchRowId) textView?.movementMethod = ScrollingMovementMethod() val glView = activity.findViewById(live2dViewId) avatarManager = Live2DAvatarManager(glView) avatarManager?.setSpeaking(false) + + // 默认隐藏本地 LLM 开关 + llmModeSwitchRow?.visibility = LinearLayout.GONE } fun setStartButtonListener(listener: () -> Unit) { @@ -131,6 +143,72 @@ class Live2DUiManager(private val activity: Activity) { } } + /** + * 显示或隐藏本地 LLM 开关 + */ + fun showLLMSwitch(show: Boolean) { + activity.runOnUiThread { + llmModeSwitchRow?.visibility = if (show) LinearLayout.VISIBLE else LinearLayout.GONE + } + } + + /** + * 设置 LLM 模式开关的监听器 + */ + fun setLLMSwitchListener(listener: (Boolean) -> Unit) { + llmModeSwitch?.setOnCheckedChangeListener { _, isChecked -> + listener(isChecked) + } + } + + /** + * 设置 LLM 模式开关的状态 + */ + fun setLLMSwitchChecked(checked: Boolean) { + 
activity.runOnUiThread { + llmModeSwitch?.isChecked = checked + } + } + + /** + * 显示下载进度对话框 + */ + fun showDownloadProgressDialog() { + activity.runOnUiThread { + downloadProgressDialog = ProgressDialog(activity).apply { + setTitle("下载模型") + setMessage("正在下载 LLM 模型文件,请稍候...") + setProgressStyle(ProgressDialog.STYLE_HORIZONTAL) + isIndeterminate = false + setCancelable(false) + setCanceledOnTouchOutside(false) + show() + } + } + } + + /** + * 更新下载进度 + */ + fun updateDownloadProgress(fileName: String, downloadedMB: Long, totalMB: Long, progress: Int) { + activity.runOnUiThread { + downloadProgressDialog?.apply { + setMessage("正在下载: $fileName\n$downloadedMB MB / $totalMB MB") + setProgress(progress) + } + } + } + + /** + * 关闭下载进度对话框 + */ + fun dismissDownloadProgressDialog() { + activity.runOnUiThread { + downloadProgressDialog?.dismiss() + downloadProgressDialog = null + } + } + fun onResume() { avatarManager?.onResume() } diff --git a/app/src/main/java/com/digitalperson/util/FileHelper.kt b/app/src/main/java/com/digitalperson/util/FileHelper.kt index 017918f..b4f81c2 100644 --- a/app/src/main/java/com/digitalperson/util/FileHelper.kt +++ b/app/src/main/java/com/digitalperson/util/FileHelper.kt @@ -1,6 +1,8 @@ package com.digitalperson.util +import android.content.ContentUris import android.content.Context +import android.provider.MediaStore import android.util.Log import com.digitalperson.config.AppConfig import java.io.File @@ -48,13 +50,270 @@ object FileHelper { ) return copyAssetsToInternal(context, AppConfig.Asr.MODEL_DIR, outDir, files) } + + @JvmStatic + fun copyRetinaFaceAssets(context: Context): File { + val outDir = File(context.filesDir, AppConfig.Face.MODEL_DIR) + val files = arrayOf(AppConfig.Face.MODEL_NAME) + return copyAssetsToInternal(context, AppConfig.Face.MODEL_DIR, outDir, files) + } + + @JvmStatic + fun copyInsightFaceAssets(context: Context): File { + val outDir = File(context.filesDir, AppConfig.FaceRecognition.MODEL_DIR) + val files = 
arrayOf(AppConfig.FaceRecognition.MODEL_NAME) + return copyAssetsToInternal(context, AppConfig.FaceRecognition.MODEL_DIR, outDir, files) + } fun ensureDir(dir: File): File { - if (!dir.exists()) dir.mkdirs() + if (!dir.exists()) { + val created = dir.mkdirs() + if (!created) { + Log.e(TAG, "Failed to create directory: ${dir.absolutePath}") + // 如果创建失败,使用应用内部存储 + return File("/data/data/${dir.parentFile?.parentFile?.name}/files/llm") + } + } return dir } fun getAsrAudioDir(context: Context): File { return ensureDir(File(context.filesDir, "asr_audio")) } + +// @JvmStatic + // 当前使用的模型文件名 + private const val MODEL_FILE_NAME = "Qwen3-0.6B-rk3588-w8a8.rkllm" + + fun getLLMModelPath(context: Context): String { + Log.d(TAG, "=== getLLMModelPath START ===") + + // 从应用内部存储目录加载模型 + val llmDir = ensureDir(File(context.filesDir, "llm")) + + Log.d(TAG, "Loading models from: ${llmDir.absolutePath}") + + // 检查文件是否存在 + val rkllmFile = File(llmDir, MODEL_FILE_NAME) + + if (!rkllmFile.exists()) { + Log.e(TAG, "RKLLM model not found: ${rkllmFile.absolutePath}") + } else { + Log.i(TAG, "RKLLM model exists, size: ${rkllmFile.length() / (1024*1024)} MB") + } + + val modelPath = rkllmFile.absolutePath + Log.i(TAG, "Using RKLLM model path: $modelPath") + Log.d(TAG, "=== getLLMModelPath END ===") + return modelPath + } + + /** + * 异步下载模型文件,带进度回调 + * @param context 上下文 + * @param onProgress 进度回调 (currentFile, downloadedBytes, totalBytes, progressPercent) + * @param onComplete 完成回调 (success, message) + */ + @JvmStatic + fun downloadModelFilesWithProgress( + context: Context, + onProgress: (String, Long, Long, Int) -> Unit, + onComplete: (Boolean, String) -> Unit + ) { + Log.d(TAG, "=== downloadModelFilesWithProgress START ===") + + val llmDir = ensureDir(File(context.filesDir, "llm")) + + // 模型文件列表 - 使用 DeepSeek-R1-Distill-Qwen-1.5B 模型 + val modelFiles = listOf( + MODEL_FILE_NAME + ) + + // 在后台线程下载 + Thread { + try { + var allSuccess = true + var totalDownloaded: Long = 0 + var totalSize: 
Long = 0 + + // 首先计算总大小 + for (fileName in modelFiles) { + val modelFile = File(llmDir, fileName) + if (!modelFile.exists() || modelFile.length() == 0L) { + val size = getFileSizeFromServer("http://192.168.1.19:5000/download/$fileName") + if (size > 0) { + totalSize += size + } else { + // 如果无法获取文件大小,使用估计值 + when (fileName) { + MODEL_FILE_NAME -> totalSize += 1L * 1024 * 1024 * 1024 // 1.5B模型约1GB + else -> totalSize += 1L * 1024 * 1024 * 1024 // 1GB 默认 + } + Log.i(TAG, "Using estimated size for $fileName: ${totalSize / (1024*1024)} MB") + } + } + } + + for (fileName in modelFiles) { + val modelFile = File(llmDir, fileName) + if (!modelFile.exists() || modelFile.length() == 0L) { + Log.i(TAG, "Downloading model file: $fileName") + try { + downloadFileWithProgress( + "http://192.168.1.19:5000/download/$fileName", + modelFile + ) { downloaded, total -> + val progress = if (totalSize > 0) { + ((totalDownloaded + downloaded) * 100 / totalSize).toInt() + } else 0 + onProgress(fileName, downloaded, total, progress) + } + totalDownloaded += modelFile.length() + Log.i(TAG, "Downloaded model file: $fileName, size: ${modelFile.length() / (1024*1024)} MB") + } catch (e: Exception) { + Log.e(TAG, "Failed to download model file $fileName: ${e.message}") + allSuccess = false + } + } else { + totalDownloaded += modelFile.length() + Log.i(TAG, "Model file exists: $fileName, size: ${modelFile.length() / (1024*1024)} MB") + } + } + Log.d(TAG, "=== downloadModelFilesWithProgress END ===") + if (allSuccess) { + onComplete(true, "模型下载完成") + } else { + onComplete(false, "部分模型下载失败") + } + } catch (e: Exception) { + Log.e(TAG, "Download failed: ${e.message}") + onComplete(false, "下载失败: ${e.message}") + } + }.start() + } + + /** + * 从服务器获取文件大小 + */ + private fun getFileSizeFromServer(url: String): Long { + return try { + val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection + connection.requestMethod = "HEAD" + connection.connectTimeout = 15000 + 
connection.readTimeout = 15000 + + // 从响应头获取 Content-Length,避免 int 溢出 + val contentLengthStr = connection.getHeaderField("Content-Length") + var size = 0L + + if (contentLengthStr != null) { + try { + size = contentLengthStr.toLong() + if (size < 0) { + Log.w(TAG, "Invalid Content-Length value: $size") + size = 0 + } + } catch (e: NumberFormatException) { + Log.w(TAG, "Invalid Content-Length format: $contentLengthStr") + size = 0 + } + } else { + val contentLength = connection.contentLength + if (contentLength > 0) { + size = contentLength.toLong() + } else { + Log.w(TAG, "Content-Length not available or invalid: $contentLength") + size = 0 + } + } + + connection.disconnect() + Log.i(TAG, "File size for $url: $size bytes") + size + } catch (e: Exception) { + Log.w(TAG, "Failed to get file size: ${e.message}") + 0 + } + } + + /** + * 从网络下载文件,带进度回调 + */ + private fun downloadFileWithProgress( + url: String, + destination: File, + onProgress: (Long, Long) -> Unit + ) { + val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection + connection.connectTimeout = 30000 + connection.readTimeout = 6000000 + + // 从响应头获取 Content-Length,避免 int 溢出 + val contentLengthStr = connection.getHeaderField("Content-Length") + val totalSize = if (contentLengthStr != null) { + try { + contentLengthStr.toLong() + } catch (e: NumberFormatException) { + Log.w(TAG, "Invalid Content-Length format: $contentLengthStr") + 0 + } + } else { + connection.contentLength.toLong() + } + Log.i(TAG, "Downloading file $url, size: $totalSize bytes") + + try { + connection.inputStream.use { input -> + FileOutputStream(destination).use { output -> + val buffer = ByteArray(8192) + var downloaded: Long = 0 + var bytesRead: Int + + while (input.read(buffer).also { bytesRead = it } != -1) { + output.write(buffer, 0, bytesRead) + downloaded += bytesRead + onProgress(downloaded, totalSize) + } + } + } + } finally { + connection.disconnect() + } + } + + /** + * 检查本地 LLM 模型是否可用 + */ + @JvmStatic 
+ fun isLocalLLMAvailable(context: Context): Boolean { + val llmDir = File(context.filesDir, "llm") + + val rkllmFile = File(llmDir, MODEL_FILE_NAME) + + val rkllmExists = rkllmFile.exists() && rkllmFile.length() > 0 + + Log.i(TAG, "LLM model check: rkllm=$rkllmExists") + Log.i(TAG, "RKLLM file: ${rkllmFile.absolutePath}, size: ${if (rkllmFile.exists()) rkllmFile.length() / (1024*1024) else 0} MB") + + return rkllmExists + } + + /** + * 从网络下载文件 + */ + private fun downloadFile(url: String, destination: File) { + val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection + connection.connectTimeout = 30000 // 30秒超时 + connection.readTimeout = 60000 // 60秒读取超时 + + try { + connection.inputStream.use { input -> + FileOutputStream(destination).use { output -> + input.copyTo(output) + } + } + } finally { + connection.disconnect() + } + } } diff --git a/app/src/main/jniLibs/arm64-v8a/libomp.so b/app/src/main/jniLibs/arm64-v8a/libomp.so new file mode 100644 index 0000000..674d9bb Binary files /dev/null and b/app/src/main/jniLibs/arm64-v8a/libomp.so differ diff --git a/app/src/main/jniLibs/arm64-v8a/librkllmrt.so b/app/src/main/jniLibs/arm64-v8a/librkllmrt.so new file mode 100644 index 0000000..348986e Binary files /dev/null and b/app/src/main/jniLibs/arm64-v8a/librkllmrt.so differ diff --git a/app/src/main/jniLibs/arm64-v8a/librknnrt.so.new b/app/src/main/jniLibs/arm64-v8a/librknnrt.so.new new file mode 100644 index 0000000..843401f Binary files /dev/null and b/app/src/main/jniLibs/arm64-v8a/librknnrt.so.new differ diff --git a/app/src/main/res/layout/activity_live2d_chat.xml b/app/src/main/res/layout/activity_live2d_chat.xml index fe608e2..4607961 100644 --- a/app/src/main/res/layout/activity_live2d_chat.xml +++ b/app/src/main/res/layout/activity_live2d_chat.xml @@ -16,11 +16,36 @@ app:layout_constraintStart_toStartOf="parent" app:layout_constraintTop_toTopOf="parent" /> + + + + + + + + android:fillViewport="true" + 
app:layout_constraintBottom_toTopOf="@+id/llm_mode_switch_row" + app:layout_constraintEnd_toEndOf="parent" + app:layout_constraintStart_toStartOf="parent"> + + + + + + + + + + + + + + @@ -48,11 +123,11 @@ android:layout_width="wrap_content" android:layout_height="wrap_content" android:layout_marginEnd="16dp" - android:text="流式输出" + android:text="腾讯云TTS" android:textSize="16sp" /> diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index a140ee9..217898c 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -3,5 +3,5 @@ 开始 结束 点击“开始”说话;识别后会请求大模型并用 TTS 播放回复。 - 你是一名小学女老师,喜欢回答学生的各种问题,请简洁但温柔地回答,每个回答不超过30字。在每次回复的最前面,用方括号标注你的心情,格式为[开心/伤心/愤怒/平和/惊讶/关心/害羞],例如:[开心]同学你好呀!请问有什么问题吗? + 你是一名小学女老师,喜欢回答学生的各种问题,请简洁但温柔地回答,每个回答不超过30字。在每次回复的最前面,用方括号标注你的心情,格式为[中性、悲伤、高兴、生气、恐惧、撒娇、震惊、厌恶],例如:[高兴]同学你好呀!请问有什么问题吗? diff --git a/gradle.properties b/gradle.properties index 6f40ab1..7cbdada 100644 --- a/gradle.properties +++ b/gradle.properties @@ -6,7 +6,7 @@ # http://www.gradle.org/docs/current/userguide/build_environment.html # Specifies the JVM arguments used for the daemon process. # The setting is particularly useful for tweaking memory settings. -org.gradle.jvmargs=-Xmx6g -Dfile.encoding=UTF-8 +org.gradle.jvmargs=-Xmx8g -Dfile.encoding=UTF-8 # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. More details, visit # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects