264 lines
11 KiB
Java
264 lines
11 KiB
Java
package com.digitalperson.cloud;
|
||
|
||
import android.content.Context;
|
||
import android.os.Handler;
|
||
import android.os.Looper;
|
||
import android.util.Log;
|
||
|
||
import com.digitalperson.BuildConfig;
|
||
import com.digitalperson.R;
|
||
|
||
import org.json.JSONArray;
|
||
import org.json.JSONException;
|
||
import org.json.JSONObject;
|
||
|
||
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
|
||
|
||
public class CloudApiManager {
|
||
private static final String TAG = "CloudApiManager";
|
||
|
||
// 火山引擎OpenAI兼容API配置
|
||
private static final String LLM_API_URL = BuildConfig.LLM_API_URL;
|
||
private static final String API_KEY = BuildConfig.LLM_API_KEY;
|
||
private static final String LLM_MODEL = BuildConfig.LLM_MODEL;
|
||
|
||
private CloudApiListener mListener;
|
||
private Handler mMainHandler; // 用于在主线程执行UI更新
|
||
private JSONArray mConversationHistory; // 存储对话历史
|
||
private boolean mEnableStreaming = true; // 默认启用流式输出
|
||
|
||
public interface CloudApiListener {
|
||
void onLLMResponseReceived(String response);
|
||
void onLLMStreamingChunkReceived(String chunk);
|
||
void onTTSAudioReceived(String audioFilePath);
|
||
void onError(String errorMessage);
|
||
}
|
||
|
||
public CloudApiManager(CloudApiListener listener, Context context) {
|
||
this.mListener = listener;
|
||
this.mMainHandler = new Handler(Looper.getMainLooper()); // 初始化主线程Handler
|
||
this.mConversationHistory = new JSONArray(); // 初始化对话历史
|
||
|
||
// 添加 system message,要求回答简洁
|
||
try {
|
||
JSONObject systemMessage = new JSONObject();
|
||
systemMessage.put("role", "system");
|
||
String systemPrompt = context.getString(R.string.system_prompt);
|
||
systemMessage.put("content", systemPrompt);
|
||
mConversationHistory.put(systemMessage);
|
||
} catch (JSONException e) {
|
||
Log.e(TAG, "Failed to add system message: " + e.getMessage());
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 设置是否启用流式输出
|
||
* @param enableStreaming true: 启用流式输出,false: 禁用流式输出(整段输出)
|
||
*/
|
||
public void setEnableStreaming(boolean enableStreaming) {
|
||
this.mEnableStreaming = enableStreaming;
|
||
}
|
||
|
||
/**
|
||
* 获取当前是否启用流式输出
|
||
* @return true: 启用流式输出,false: 禁用流式输出(整段输出)
|
||
*/
|
||
public boolean isEnableStreaming() {
|
||
return mEnableStreaming;
|
||
}
|
||
|
||
public void callLLM(String userInput) {
|
||
new Thread(() -> {
|
||
try {
|
||
// 添加用户输入到对话历史
|
||
addMessageToHistory("user", userInput);
|
||
|
||
// 创建HTTP连接
|
||
URL url = new URL(LLM_API_URL);
|
||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
||
conn.setRequestMethod("POST");
|
||
conn.setRequestProperty("Content-Type", "application/json");
|
||
conn.setRequestProperty("Authorization", "Bearer " + API_KEY);
|
||
conn.setDoOutput(true);
|
||
conn.setConnectTimeout(10000);
|
||
conn.setReadTimeout(60000); // 延长读取超时以支持流式响应
|
||
|
||
// 构建请求体
|
||
JSONObject requestBody = new JSONObject();
|
||
requestBody.put("model", LLM_MODEL);
|
||
requestBody.put("messages", mConversationHistory);
|
||
requestBody.put("stream", mEnableStreaming); // 根据配置决定是否启用流式响应
|
||
|
||
String jsonBody = requestBody.toString();
|
||
|
||
Log.d(TAG, "LLM Request: " + jsonBody);
|
||
|
||
// 发送请求
|
||
try (DataOutputStream dos = new DataOutputStream(conn.getOutputStream())) {
|
||
dos.write(jsonBody.getBytes(StandardCharsets.UTF_8));
|
||
dos.flush();
|
||
}
|
||
|
||
// 读取响应
|
||
int responseCode = conn.getResponseCode();
|
||
StringBuilder fullResponse = new StringBuilder();
|
||
StringBuilder accumulatedContent = new StringBuilder();
|
||
|
||
Log.d(TAG, "LLM Response Code: " + responseCode);
|
||
|
||
if (responseCode == 200) {
|
||
if (mEnableStreaming) {
|
||
// 逐行读取流式响应
|
||
try (BufferedReader br = new BufferedReader(
|
||
new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
|
||
String line;
|
||
while ((line = br.readLine()) != null) {
|
||
Log.d(TAG, "LLM Streaming Line: " + line);
|
||
|
||
// 处理SSE格式的响应
|
||
if (line.startsWith("data: ")) {
|
||
String dataPart = line.substring(6);
|
||
if (dataPart.equals("[DONE]")) {
|
||
// 流式响应结束
|
||
break;
|
||
}
|
||
|
||
try {
|
||
// 解析JSON
|
||
JSONObject chunkObj = new JSONObject(dataPart);
|
||
JSONArray choices = chunkObj.getJSONArray("choices");
|
||
if (choices.length() > 0) {
|
||
JSONObject choice = choices.getJSONObject(0);
|
||
JSONObject delta = choice.getJSONObject("delta");
|
||
|
||
if (delta.has("content")) {
|
||
String chunkContent = delta.getString("content");
|
||
accumulatedContent.append(chunkContent);
|
||
|
||
// 发送流式chunk到监听器
|
||
if (mListener != null) {
|
||
mMainHandler.post(() -> {
|
||
mListener.onLLMStreamingChunkReceived(chunkContent);
|
||
});
|
||
}
|
||
}
|
||
}
|
||
} catch (JSONException e) {
|
||
Log.e(TAG, "Failed to parse streaming chunk: " + e.getMessage());
|
||
}
|
||
}
|
||
|
||
fullResponse.append(line).append("\n");
|
||
}
|
||
}
|
||
} else {
|
||
// 读取完整响应
|
||
try (BufferedReader br = new BufferedReader(
|
||
new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
|
||
String line;
|
||
while ((line = br.readLine()) != null) {
|
||
fullResponse.append(line);
|
||
}
|
||
}
|
||
|
||
// 解析完整JSON响应
|
||
try {
|
||
JSONObject responseObj = new JSONObject(fullResponse.toString());
|
||
JSONArray choices = responseObj.getJSONArray("choices");
|
||
if (choices.length() > 0) {
|
||
JSONObject choice = choices.getJSONObject(0);
|
||
JSONObject message = choice.getJSONObject("message");
|
||
if (message.has("content")) {
|
||
String content = message.getString("content");
|
||
accumulatedContent.append(content);
|
||
}
|
||
}
|
||
} catch (JSONException e) {
|
||
Log.e(TAG, "Failed to parse full response: " + e.getMessage());
|
||
}
|
||
}
|
||
|
||
String content = accumulatedContent.toString();
|
||
Log.d(TAG, "Full LLM Response: " + content);
|
||
|
||
// 添加AI回复到对话历史
|
||
addMessageToHistory("assistant", content);
|
||
|
||
if (mListener != null) {
|
||
mMainHandler.post(() -> {
|
||
mListener.onLLMResponseReceived(content);
|
||
});
|
||
}
|
||
} else {
|
||
// 读取错误响应
|
||
StringBuilder errorResponse = new StringBuilder();
|
||
try (BufferedReader br = new BufferedReader(
|
||
new InputStreamReader(conn.getErrorStream(), StandardCharsets.UTF_8))) {
|
||
String line;
|
||
while ((line = br.readLine()) != null) {
|
||
errorResponse.append(line);
|
||
}
|
||
}
|
||
throw new IOException("HTTP " + responseCode + ": " + errorResponse.toString());
|
||
}
|
||
|
||
} catch (Exception e) {
|
||
Log.e(TAG, "LLM call failed: " + e.getMessage());
|
||
if (mListener != null) {
|
||
mMainHandler.post(() -> {
|
||
mListener.onError("LLM调用失败: " + e.getMessage());
|
||
});
|
||
}
|
||
}
|
||
}).start();
|
||
}
|
||
|
||
/**
|
||
* 添加消息到对话历史
|
||
*/
|
||
private void addMessageToHistory(String role, String content) {
|
||
try {
|
||
JSONObject message = new JSONObject();
|
||
message.put("role", role);
|
||
message.put("content", content);
|
||
mConversationHistory.put(message);
|
||
} catch (JSONException e) {
|
||
Log.e(TAG, "Failed to add message to history: " + e.getMessage());
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 清空对话历史
|
||
*/
|
||
public void clearConversationHistory() {
|
||
mConversationHistory = new JSONArray();
|
||
}
|
||
|
||
public void callTTS(String text, File outputFile) {
|
||
if (mListener != null) {
|
||
mMainHandler.post(() -> {
|
||
mListener.onError("TTS功能暂未实现");
|
||
});
|
||
}
|
||
}
|
||
|
||
private String extractContentFromResponse(String response) {
|
||
try {
|
||
int contentStart = response.indexOf("\"content\":\"") + 11;
|
||
int contentEnd = response.indexOf("\"", contentStart);
|
||
if (contentStart > 10 && contentEnd > contentStart) {
|
||
return response.substring(contentStart, contentEnd);
|
||
}
|
||
} catch (Exception e) {
|
||
Log.e(TAG, "Failed to parse response: " + e.getMessage());
|
||
}
|
||
return "抱歉,无法解析响应";
|
||
}
|
||
} |