Files
digital_person/app/src/main/java/com/digitalperson/cloud/CloudApiManager.java
2026-02-28 10:14:03 +08:00

264 lines
11 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.digitalperson.cloud;
import android.content.Context;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import com.digitalperson.BuildConfig;
import com.digitalperson.R;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
/**
 * Sends chat requests to an OpenAI-compatible chat-completions endpoint and reports the
 * results to a {@link CloudApiListener}.
 *
 * <p>Network work runs on a freshly started background thread per call; every listener
 * callback is posted to the main looper. The conversation history (including the system
 * prompt) is kept in memory and sent with every request.
 */
public class CloudApiManager {
    private static final String TAG = "CloudApiManager";

    // Volcano Engine OpenAI-compatible API configuration.
    private static final String LLM_API_URL = BuildConfig.LLM_API_URL;
    private static final String API_KEY = BuildConfig.LLM_API_KEY;
    private static final String LLM_MODEL = BuildConfig.LLM_MODEL;

    // Server-sent-events framing used by the streaming response.
    private static final String SSE_DATA_PREFIX = "data:";
    private static final String SSE_DONE_MARKER = "[DONE]";

    private final CloudApiListener mListener;
    /** Used to run listener callbacks (UI updates) on the main thread. */
    private final Handler mMainHandler;
    /** System prompt text, kept so the history can be re-seeded after a clear. */
    private final String mSystemPrompt;
    /** Guards {@link #mConversationHistory}, which is touched by worker and caller threads. */
    private final Object mHistoryLock = new Object();
    /** Conversation history sent with every request; guarded by {@link #mHistoryLock}. */
    private JSONArray mConversationHistory;
    /** Streaming output is enabled by default; volatile because worker threads read it. */
    private volatile boolean mEnableStreaming = true;

    /**
     * Receiver for results. All methods are invoked on the main thread.
     */
    public interface CloudApiListener {
        /** Called once per successful request with the complete reply text. */
        void onLLMResponseReceived(String response);

        /** Called for each content fragment while a streaming reply is being read. */
        void onLLMStreamingChunkReceived(String chunk);

        /** Declared for TTS results; this class currently never invokes it. */
        void onTTSAudioReceived(String audioFilePath);

        /** Called when a request fails, with a human-readable message. */
        void onError(String errorMessage);
    }

    /**
     * Creates a manager whose history starts with the system prompt from
     * {@code R.string.system_prompt}.
     *
     * @param listener receiver of results; may be {@code null}, in which case nothing is reported
     * @param context  used only to resolve the system prompt string
     */
    public CloudApiManager(CloudApiListener listener, Context context) {
        this.mListener = listener;
        this.mMainHandler = new Handler(Looper.getMainLooper());
        this.mSystemPrompt = context.getString(R.string.system_prompt);
        this.mConversationHistory = newHistoryWithSystemPrompt();
    }

    /**
     * Sets whether replies are requested as a stream.
     *
     * @param enableStreaming {@code true} to stream the reply chunk by chunk, {@code false} to
     *                        receive it as a single response
     */
    public void setEnableStreaming(boolean enableStreaming) {
        this.mEnableStreaming = enableStreaming;
    }

    /**
     * Returns whether replies are currently requested as a stream.
     *
     * @return {@code true} if streaming is enabled, {@code false} for whole-response mode
     */
    public boolean isEnableStreaming() {
        return mEnableStreaming;
    }

    /**
     * Appends {@code userInput} to the history and sends the whole history to the LLM on a
     * background thread.
     *
     * <p>On success the reply is appended to the history and delivered through
     * {@link CloudApiListener#onLLMResponseReceived(String)}; in streaming mode each fragment
     * is additionally delivered through
     * {@link CloudApiListener#onLLMStreamingChunkReceived(String)}. Any failure (network,
     * non-200 status, unparsable non-streaming body) is delivered through
     * {@link CloudApiListener#onError(String)}.
     *
     * @param userInput the user's message
     */
    public void callLLM(String userInput) {
        new Thread(() -> {
            HttpURLConnection conn = null;
            try {
                // Snapshot the flag once so the request and the response parser always agree,
                // even if setEnableStreaming() is called while this request is in flight.
                final boolean streaming = mEnableStreaming;
                final String jsonBody = appendUserMessageAndBuildRequest(userInput, streaming);
                Log.d(TAG, "LLM Request: " + jsonBody);

                conn = (HttpURLConnection) new URL(LLM_API_URL).openConnection();
                conn.setRequestMethod("POST");
                conn.setRequestProperty("Content-Type", "application/json");
                conn.setRequestProperty("Authorization", "Bearer " + API_KEY);
                conn.setDoOutput(true);
                conn.setConnectTimeout(10000);
                conn.setReadTimeout(60000); // long read timeout to accommodate streaming replies

                try (DataOutputStream dos = new DataOutputStream(conn.getOutputStream())) {
                    dos.write(jsonBody.getBytes(StandardCharsets.UTF_8));
                    dos.flush();
                }

                int responseCode = conn.getResponseCode();
                Log.d(TAG, "LLM Response Code: " + responseCode);
                if (responseCode != HttpURLConnection.HTTP_OK) {
                    // getErrorStream() may be null; readFully() tolerates that.
                    throw new IOException(
                            "HTTP " + responseCode + ": " + readFully(conn.getErrorStream()));
                }

                final String content = streaming
                        ? readStreamingResponse(conn)
                        : readWholeResponse(conn);
                Log.d(TAG, "Full LLM Response: " + content);

                addMessageToHistory("assistant", content);
                if (mListener != null) {
                    mMainHandler.post(() -> mListener.onLLMResponseReceived(content));
                }
            } catch (Exception e) {
                Log.e(TAG, "LLM call failed: " + e.getMessage(), e);
                if (mListener != null) {
                    mMainHandler.post(() -> mListener.onError("LLM调用失败: " + e.getMessage()));
                }
            } finally {
                if (conn != null) {
                    conn.disconnect();
                }
            }
        }).start();
    }

    /**
     * Adds the user message to the history and serializes the request body while holding the
     * history lock, so a concurrent clear or append cannot interleave with serialization.
     */
    private String appendUserMessageAndBuildRequest(String userInput, boolean streaming)
            throws JSONException {
        synchronized (mHistoryLock) {
            addMessageToHistory("user", userInput);
            JSONObject requestBody = new JSONObject();
            requestBody.put("model", LLM_MODEL);
            requestBody.put("messages", mConversationHistory);
            requestBody.put("stream", streaming);
            return requestBody.toString();
        }
    }

    /**
     * Reads an SSE response line by line, forwards each content fragment to the listener and
     * returns the concatenation of all fragments. Malformed chunks are logged and skipped.
     */
    private String readStreamingResponse(HttpURLConnection conn) throws IOException {
        StringBuilder accumulated = new StringBuilder();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                Log.d(TAG, "LLM Streaming Line: " + line);
                if (!line.startsWith(SSE_DATA_PREFIX)) {
                    continue;
                }
                // The space after "data:" is optional in SSE, so strip it rather than require it.
                String data = line.substring(SSE_DATA_PREFIX.length()).trim();
                if (data.isEmpty()) {
                    continue;
                }
                if (SSE_DONE_MARKER.equals(data)) {
                    break; // end of stream
                }

                final String chunk;
                try {
                    chunk = parseFirstChoiceContent(data, "delta");
                } catch (JSONException e) {
                    Log.e(TAG, "Failed to parse streaming chunk: " + e.getMessage(), e);
                    continue;
                }
                if (chunk == null) {
                    continue;
                }
                accumulated.append(chunk);
                if (mListener != null) {
                    mMainHandler.post(() -> mListener.onLLMStreamingChunkReceived(chunk));
                }
            }
        }
        return accumulated.toString();
    }

    /**
     * Reads a non-streaming response and returns the first choice's message content, or an
     * empty string when the response carries no content.
     *
     * @throws JSONException if the body is not the expected JSON, so the caller reports an
     *                       error instead of delivering an empty reply
     */
    private String readWholeResponse(HttpURLConnection conn) throws IOException, JSONException {
        String body = readFully(conn.getInputStream());
        String content = parseFirstChoiceContent(body, "message");
        return content != null ? content : "";
    }

    /**
     * Extracts {@code choices[0].<containerKey>.content} from a chat-completions JSON payload.
     *
     * @return the content string, or {@code null} when there are no choices, the container is
     *         missing, or the content is absent or JSON {@code null}
     */
    private static String parseFirstChoiceContent(String json, String containerKey)
            throws JSONException {
        JSONArray choices = new JSONObject(json).getJSONArray("choices");
        if (choices.length() == 0) {
            return null;
        }
        JSONObject container = choices.getJSONObject(0).optJSONObject(containerKey);
        // isNull() covers both "no such key" and an explicit JSON null; calling getString()
        // on an explicit null would otherwise yield the literal text "null".
        if (container == null || container.isNull("content")) {
            return null;
        }
        return container.getString("content");
    }

    /**
     * Reads all lines of {@code in} as UTF-8 and concatenates them without line separators.
     * Closes the stream. Returns an empty string when {@code in} is {@code null}.
     */
    private static String readFully(java.io.InputStream in) throws IOException {
        if (in == null) {
            return "";
        }
        StringBuilder text = new StringBuilder();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(in, StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                text.append(line);
            }
        }
        return text.toString();
    }

    /** Builds a fresh history containing only the system prompt message. */
    private JSONArray newHistoryWithSystemPrompt() {
        JSONArray history = new JSONArray();
        try {
            JSONObject systemMessage = new JSONObject();
            systemMessage.put("role", "system");
            systemMessage.put("content", mSystemPrompt);
            history.put(systemMessage);
        } catch (JSONException e) {
            Log.e(TAG, "Failed to add system message: " + e.getMessage(), e);
        }
        return history;
    }

    /**
     * Appends a message to the conversation history.
     */
    private void addMessageToHistory(String role, String content) {
        try {
            JSONObject message = new JSONObject();
            message.put("role", role);
            message.put("content", content);
            synchronized (mHistoryLock) {
                mConversationHistory.put(message);
            }
        } catch (JSONException e) {
            Log.e(TAG, "Failed to add message to history: " + e.getMessage(), e);
        }
    }

    /**
     * Clears the conversation history. The system prompt is re-added so that later requests
     * are still sent with it.
     */
    public void clearConversationHistory() {
        synchronized (mHistoryLock) {
            mConversationHistory = newHistoryWithSystemPrompt();
        }
    }

    /**
     * Text-to-speech is not implemented yet; this only reports an error to the listener.
     *
     * @param text       text to synthesize (currently unused)
     * @param outputFile destination for the audio (currently unused)
     */
    public void callTTS(String text, File outputFile) {
        if (mListener != null) {
            mMainHandler.post(() -> mListener.onError("TTS功能暂未实现"));
        }
    }

    /**
     * Extracts the raw (still JSON-escaped) value of the first {@code "content":"..."} pair by
     * plain string scanning. Not called anywhere in this class.
     *
     * @return the raw content, or a fixed apology message when no non-empty content is found
     */
    private String extractContentFromResponse(String response) {
        final String marker = "\"content\":\"";
        final String fallback = "抱歉,无法解析响应";
        if (response == null) {
            return fallback;
        }
        int markerAt = response.indexOf(marker);
        if (markerAt < 0) {
            return fallback;
        }
        int start = markerAt + marker.length();
        int end = start;
        // Find the closing quote, stepping over backslash-escaped characters such as \".
        while (end < response.length() && response.charAt(end) != '"') {
            end += (response.charAt(end) == '\\') ? 2 : 1;
        }
        if (end > start && end < response.length()) {
            return response.substring(start, end);
        }
        return fallback;
    }
}