fix: 修改错误

This commit is contained in:
xion 2024-10-02 14:13:00 +08:00
parent 8436b0462a
commit 1f81d3400c

View File

@ -54,29 +54,46 @@ const onMessage = async ({ data, end, ws }) => {
return; return;
} }
const { uid, id, key } = client.data; const { uid, id, key } = client.data;
const { inputs, message: sendMessage } = data; const {
inputs,
message: sendMessage,
data: {},
} = data;
let root = data.root || false; let root = data.root || false;
let chatSession = await ChatSession.findByPk(id); let chatSession = await ChatSession.findByPk(id);
const config = await getConfigByKey(key); const config = await getConfigByKey(key);
const { prompt, aiAgent, chatPrompt } = config; const { prompt, aiAgent, chatPrompt } = config;
let userQuestion = sendMessage;
if (!chatSession) { if (!chatSession) {
chatSession = await ChatSession.create({ key, id, data: {}, uid, chatPromptId: chatPrompt.id }); chatSession = await ChatSession.create({ key, id, data: data, uid, chatPromptId: chatPrompt.id });
root = true; root = true;
} else { } else {
root = false; // 更新session context的值
const newData = JSON.parse(data);
if (newData !== '{}' && JSON.stringify(chatSession.data) !== JSON.stringify(data)) {
await chatSession.update({ data: data });
}
if (root) {
const chatHistory = await ChatHistory.findAll({ where: { chatId: id }, logging: false }); const chatHistory = await ChatHistory.findAll({ where: { chatId: id }, logging: false });
chatHistory.forEach((item) => { chatHistory.forEach((item) => {
end({ code: 200, data: item, message: 'success', type: 'messages' }); end({ code: 200, data: item, message: 'success', type: 'messages' });
}); });
// return;
}
root = false;
}
if (!userQuestion) {
if (!prompt?.presetData) {
end({ code: 404, data: {}, message: 'prompt not set, need presetData' });
return; return;
} }
const template = await getTemplate({ data: prompt.presetData.data, inputs }); const template = await getTemplate({ data: prompt.presetData.data, inputs });
if (!template) { if (!template) {
end({ code: 404, data: {}, message: 'template not found' }); end({ code: 404, data: {}, message: 'template not found' });
return; return;
} }
const userQuestion = template || sendMessage; userQuestion = template;
}
// 保存到数据库 // 保存到数据库
const roleUser = await ChatHistory.create({ const roleUser = await ChatHistory.create({
data: { data: {
@ -90,7 +107,7 @@ const onMessage = async ({ data, end, ws }) => {
role: 'user', role: 'user',
}); });
end({ code: 200, data: roleUser, message: 'success', type: 'messages' }); end({ code: 200, data: roleUser, message: 'success', type: 'messages' });
const result = await aiAgent.sendHumanMessage(template, { thread_id: id }); const result = await aiAgent.sendHumanMessage(userQuestion, { thread_id: id });
const lastMessage = result.messages[result.messages.length - 1]; const lastMessage = result.messages[result.messages.length - 1];
const message = result.messages[result.messages.length - 1].content; const message = result.messages[result.messages.length - 1].content;
// 根据key找到对应的prompt // 根据key找到对应的prompt