From bd38ad2eaa5800270ccd094086b7b55489b65b9c Mon Sep 17 00:00:00 2001 From: xion Date: Tue, 1 Oct 2024 00:47:14 +0800 Subject: [PATCH] feat: add chat message --- packages/ai-lang/src/module/create-agent.ts | 4 +- src/models/agent.ts | 1 + src/models/chat-history.ts | 10 +- src/models/chat-prompt.ts | 7 +- src/models/chat-session.ts | 61 +++++++ src/models/code.ts | 11 +- src/models/prompt.ts | 1 + src/models/user.ts | 5 +- src/routes/chat-history/chat-io.ts | 184 ++++++++++++++++++++ src/routes/chat-history/index.ts | 3 + src/routes/chat-history/list.ts | 33 ++++ src/routes/chat-history/session-list.ts | 83 +++++++++ src/routes/chat-prompt/index.ts | 1 + src/routes/chat-prompt/list.ts | 126 ++++++++++++++ src/routes/container/models/index.ts | 1 + src/routes/index.ts | 4 + src/routes/page/models/index.ts | 1 + src/routes/resource/models/index.ts | 1 + src/routes/user/list.ts | 2 +- src/scripts/recover.ts | 27 +++ 20 files changed, 556 insertions(+), 10 deletions(-) create mode 100644 src/models/chat-session.ts create mode 100644 src/routes/chat-history/chat-io.ts create mode 100644 src/routes/chat-history/index.ts create mode 100644 src/routes/chat-history/list.ts create mode 100644 src/routes/chat-history/session-list.ts create mode 100644 src/routes/chat-prompt/index.ts create mode 100644 src/routes/chat-prompt/list.ts create mode 100644 src/scripts/recover.ts diff --git a/packages/ai-lang/src/module/create-agent.ts b/packages/ai-lang/src/module/create-agent.ts index cdecfab..e1cad9b 100644 --- a/packages/ai-lang/src/module/create-agent.ts +++ b/packages/ai-lang/src/module/create-agent.ts @@ -64,11 +64,11 @@ export class AiAgent { } const agentModel = this.agentModel; - const memoerSaver = this.memorySaver; + const memorySaver = this.memorySaver; this.agent = createReactAgent({ llm: agentModel, tools: [], - checkpointSaver: memoerSaver, + checkpointSaver: memorySaver, }); this.status = 'ready'; } diff --git a/src/models/agent.ts b/src/models/agent.ts index 331e253..34de2f0 100644 --- a/src/models/agent.ts +++ b/src/models/agent.ts @@ -86,6 +86,7 @@ AiAgent.init( { sequelize, tableName: 'ai_agent', + paranoid: true, }, ); AiAgent.sync({ alter: true, logging: false }).catch((e) => { diff --git a/src/models/chat-history.ts b/src/models/chat-history.ts index 8d40fd0..adcb7e8 100644 --- a/src/models/chat-history.ts +++ b/src/models/chat-history.ts @@ -12,6 +12,9 @@ export class ChatHistory extends Model { declare root: boolean; declare show: boolean; declare uid: string; + declare chatId: string; + declare chatPromptId: string; + declare role: string; } ChatHistory.init( @@ -37,6 +40,11 @@ ChatHistory.init( type: DataTypes.BOOLEAN, // 是否是根节点 defaultValue: false, }, + role: { + type: DataTypes.STRING, // 角色 + allowNull: true, + defaultValue: 'user', + }, show: { type: DataTypes.BOOLEAN, // 当创建返回的时候,配置是否显示 defaultValue: true, @@ -53,6 +61,6 @@ ChatHistory.init( ); // force 只能run一次,否则会删除表 -ChatHistory.sync({ alter: true, force: true, logging: false }).catch((e) => { +ChatHistory.sync({ alter: true, logging: false }).catch((e) => { console.error('History sync error', e); }); diff --git a/src/models/chat-prompt.ts b/src/models/chat-prompt.ts index d2fd59a..c25a7b8 100644 --- a/src/models/chat-prompt.ts +++ b/src/models/chat-prompt.ts @@ -7,6 +7,7 @@ export type ChatPromptData = { aiAgentId: string; // 使用那个初始化的prompt,如果不存在则纯粹的白对话。 promptId?: string; + chainPromptId?: string; }; /** * chat绑定就的agent和prompt @@ -18,7 +19,7 @@ export class ChatPrompt extends Model { declare description: string; declare uid: string; declare key: string; - declare data: string; + declare data: ChatPromptData; } ChatPrompt.init( @@ -43,6 +44,7 @@ ChatPrompt.init( key: { type: DataTypes.STRING, // 页面属于 /container/edit/list allowNull: false, + defaultValue: '', }, uid: { type: DataTypes.STRING, @@ -52,10 +54,11 @@ ChatPrompt.init( { sequelize, // 传入 Sequelize 实例 modelName: 'chat_prompt', // 模型名称 + paranoid: true, }, ); // force 只能run一次,否则会删除表 -ChatPrompt.sync({ alter: true, force: true, logging: false }).catch((e) => { +ChatPrompt.sync({ alter: true, logging: false }).catch((e) => { console.error('Prompt sync error', e); }); diff --git a/src/models/chat-session.ts b/src/models/chat-session.ts new file mode 100644 index 0000000..6430b6b --- /dev/null +++ b/src/models/chat-session.ts @@ -0,0 +1,61 @@ +import { sequelize } from '../modules/sequelize.ts'; +import { DataTypes, Model } from 'sequelize'; + +/** + * chat 回话记录 + * 有一些内容是预置的。 + */ +export class ChatSession extends Model { + declare id: string; + declare data: string; + declare chatPromptId: string; + declare type: string; + declare key: string; + declare title: string; + declare uid: string; +} + +ChatSession.init( + { + id: { + type: DataTypes.UUID, + primaryKey: true, + defaultValue: DataTypes.UUIDV4, + }, + data: { + type: DataTypes.JSON, + allowNull: true, + defaultValue: {}, + }, + chatPromptId: { + type: DataTypes.UUID, // 属于哪一个prompt + allowNull: true, + }, + type: { + type: DataTypes.STRING, // 属于测试的,还是正式的 + defaultValue: 'production', + }, + title: { + type: DataTypes.STRING, + allowNull: true, + defaultValue: '', + }, + key: { + type: DataTypes.STRING, // 页面属于 /container/edit/list + allowNull: true, + }, + uid: { + type: DataTypes.STRING, + allowNull: true, + }, + }, + { + sequelize, // 传入 Sequelize 实例 + modelName: 'chat_session', // 模型名称 + }, +); + +// force 只能run一次,否则会删除表 +ChatSession.sync({ alter: true, logging: false }).catch((e) => { + console.error('Sessuib sync error', e); +}); diff --git a/src/models/code.ts b/src/models/code.ts index dc7ca2a..94a4146 100644 --- a/src/models/code.ts +++ b/src/models/code.ts @@ -91,9 +91,14 @@ RouterCodeModel.init( { sequelize, tableName: 'cf_router_code', // container flow router code + paranoid: true, }, ); -RouterCodeModel.sync({ alter: true, logging: false }).catch((e) => { - console.error('RouterCodeModel sync', e); -}); +RouterCodeModel.sync({ alter: true, logging: false }) + .then((res) => { + console.log('RouterCodeModel sync', res); + }) + .catch((e) => { + console.error('RouterCodeModel sync', e); + }); // RouterCodeModel.sync({force: true}); diff --git a/src/models/prompt.ts b/src/models/prompt.ts index 414d47c..c6d4644 100644 --- a/src/models/prompt.ts +++ b/src/models/prompt.ts @@ -57,6 +57,7 @@ Prompt.init( { sequelize, // 传入 Sequelize 实例 modelName: 'prompt', // 模型名称 + paranoid: true, }, ); diff --git a/src/models/user.ts b/src/models/user.ts index ea37e92..e9a808c 100644 --- a/src/models/user.ts +++ b/src/models/user.ts @@ -24,7 +24,9 @@ export class User extends Model { return { token, expireTime: now + expireTime }; } static async verifyToken(token: string) { - return await checkToken(token, config.tokenSecret); + const ct = await checkToken(token, config.tokenSecret); + const tokenUser = ct.payload; + return tokenUser; } static createUser(username: string, password?: string, description?: string) { const user = User.findOne({ where: { username } }); @@ -79,6 +81,7 @@ User.init( { sequelize, tableName: 'cf_user', + paranoid: true, }, ); User.sync({ alter: true, logging: false }) diff --git a/src/routes/chat-history/chat-io.ts b/src/routes/chat-history/chat-io.ts new file mode 100644 index 0000000..e5831f9 --- /dev/null +++ b/src/routes/chat-history/chat-io.ts @@ -0,0 +1,184 @@ +import { app } from '@/app.ts'; +import { AiAgent } from '@/models/agent.ts'; +import { ChatPrompt } from '@/models/chat-prompt.ts'; +import { ChatSession } from '@/models/chat-session.ts'; +import { Prompt } from '@/models/prompt.ts'; +import { agentManger } from '@kevisual/ai-lang'; +import { PromptTemplate } from '@kevisual/ai-graph'; +import { v4 } from 'uuid'; +import { ChatHistory } from '@/models/chat-history.ts'; +import { User } from '@/models/user.ts'; +const clients = []; +export const compotedToken = () => { + // 计算token消耗 +}; +export const getConfigByKey = async (key) => { + const chatPrompt = await ChatPrompt.findOne({ where: { key } }); + const { promptId, aiAgentId } = chatPrompt.data; + const prompt = await Prompt.findByPk(promptId); + let aiAgent = agentManger.getAgent(aiAgentId); + if (!aiAgent) { + // throw new Error('aiAgent not found'); + const aiAgnetModel = await AiAgent.findByPk(aiAgentId); + aiAgent = agentManger.createAgent({ + id: aiAgnetModel.id, + type: aiAgnetModel.type as any, + model: aiAgnetModel.model as any, + baseUrl: aiAgnetModel.baseUrl, + apiKey: aiAgnetModel.apiKey, + temperature: aiAgnetModel.temperature, + cache: aiAgnetModel.cache as any, + cacheName: aiAgnetModel.cacheName, + }); + } + return { chatPrompt, prompt, aiAgent }; +}; +export const getTemplate = async ({ data, inputs }) => { + const promptTemplate = new PromptTemplate({ + prompt: data.prompt, + inputVariables: inputs.map((item) => { + return { + key: item.key, + value: item.value, + }; + }), + localVariables: [], + }); // 传入参数 + return await promptTemplate.getTemplate(); +}; +const onMessage = async ({ data, end, ws }) => { + // messages的 data + const client = clients.find((client) => client.ws === ws); + if (!client) { + end({ code: 404, data: {}, message: 'client not found' }); + return; + } + const { uid, id, key } = client.data; + const { inputs, message: sendMessage } = data; + let root = data.root || false; + let chatSession = await ChatSession.findByPk(id); + const config = await getConfigByKey(key); + const { prompt, aiAgent, chatPrompt } = config; + if (!chatSession) { + chatSession = await ChatSession.create({ key, id, data: {}, uid, chatPromptId: chatPrompt.id }); + root = true; + } else { + root = false; + const chatHistory = await ChatHistory.findAll({ where: { chatId: id }, logging: false }); + chatHistory.forEach((item) => { + end({ code: 200, data: item, message: 'success', type: 'messages' }); + }); + return; + } + + const template = await getTemplate({ data: prompt.presetData.data, inputs }); + if (!template) { + end({ code: 404, data: {}, message: 'template not found' }); + return; + } + const userQuestion = template || sendMessage; + // 保存到数据库 + const roleUser = await ChatHistory.create({ + data: { + message: userQuestion, + }, + chatId: id, + chatPromptId: chatPrompt.id, + root: root, + uid: uid, + show: true, + role: 'user', + }); + end({ code: 200, data: roleUser, message: 'success', type: 'messages' }); + const result = await aiAgent.sendHumanMessage(template, { thread_id: id }); + const lastMessage = result.messages[result.messages.length - 1]; + const message = result.messages[result.messages.length - 1].content; + // 根据key找到对应的prompt + // 保存到数据库 + const roleAi = await ChatHistory.create({ + data: { + message, + result: lastMessage, + }, + chatId: id, + chatPromptId: chatPrompt.id, + root: false, + uid: uid, + show: true, + role: 'ai', + }); + end({ code: 200, data: roleAi, message: 'success', type: 'messages' }); +}; +const getHistory = async (id: string, { data, end, ws }) => { + const chatHistory = await ChatHistory.findAll({ where: { chatId: id }, logging: false }); + chatHistory.forEach((item) => { + end({ code: 200, data: item, message: 'success', type: 'messages' }); + }); +}; +app.io.addListener('chat', async ({ data, end, ws }) => { + const { type } = data || {}; + if (type === 'subscribe') { + const token = data?.token; + if (!token) { + end({ code: 401, data: {}, message: 'need token' }); + return; + } + let tokenUesr; + try { + tokenUesr = await User.verifyToken(token); + } catch (e) { + end({ code: 401, data: {}, message: 'token is invaild' }); + return; + } + const uid = tokenUesr.id; + const id = v4(); + const clientData = { ...data?.data, uid }; + if (!clientData.id) { + clientData.id = id; + } + const client = clients.find((client) => client.ws === ws); + if (!client) { + clients.push({ ws, data: clientData }); // 拆包,里面包含的type信息,去掉 + } + end({ code: 200, data: clientData, message: 'subscribe success' }); + } else if (type === 'unsubscribe') { + // 需要手动取消订阅 + const index = clients.findIndex((client) => client.ws === ws); + if (index > -1) { + const data = clients[index]?.data; + clients.splice(index, 1); + end({ code: 200, data, message: 'unsubscribe success' }); + return; + } + end({ code: 200, data: {}, message: 'unsubscribe success' }); + return; + } else if (type === 'messages') { + try { + await onMessage({ data: data.data, end, ws }); + } catch (e) { + console.error('onMessage error', e); + end({ code: 500, data: {}, message: 'onMessage error' }); + } + return; + } else if (type === 'changeSession') { + // 修改client的session的id + const client = clients.find((client) => client.ws === ws); + if (!client) { + end({ code: 404, data: {}, message: 'client not found' }); + return; + } + const { id } = data?.data; + client.data.id = id || v4(); + // 返回修改后的history的内容 + end({ code: 200, data: client.data, message: 'changeSession success' }); + getHistory(id, { data, end, ws }); + return; + } else { + end({ code: 404, data: {}, message: 'subscribe fail' }); + return; + } + ws.on('close', () => { + const index = clients.findIndex((client) => client.ws === ws); + if (index > -1) clients.splice(index, 1); + }); +}); diff --git a/src/routes/chat-history/index.ts b/src/routes/chat-history/index.ts new file mode 100644 index 0000000..4da34c8 --- /dev/null +++ b/src/routes/chat-history/index.ts @@ -0,0 +1,3 @@ +import './list.ts' +import './session-list.ts' +import './chat-io.ts' \ No newline at end of file diff --git a/src/routes/chat-history/list.ts b/src/routes/chat-history/list.ts new file mode 100644 index 0000000..4ad45a7 --- /dev/null +++ b/src/routes/chat-history/list.ts @@ -0,0 +1,33 @@ +import { app } from '@/app.ts'; +import { ChatHistory } from '@/models/chat-history.ts'; +import { CustomError } from '@abearxiong/router'; + +// Admin only +app + .route({ + path: 'chat-history', + key: 'list', + }) + .define(async (ctx) => { + const chatPrompt = await ChatHistory.findAll({ + order: [['updatedAt', 'DESC']], + }); + ctx.body = chatPrompt; + }) + .addTo(app); + +app + .route({ + path: 'chat-history', + key: 'delete', + }) + .define(async (ctx) => { + const { id } = ctx.query; + const chatHistory = await ChatHistory.findByPk(id); + if (!chatHistory) { + throw new CustomError('ChatHistory not found'); + } + await chatHistory.destroy(); + ctx.body = chatHistory; + }) + .addTo(app); diff --git a/src/routes/chat-history/session-list.ts b/src/routes/chat-history/session-list.ts new file mode 100644 index 0000000..666ceb7 --- /dev/null +++ b/src/routes/chat-history/session-list.ts @@ -0,0 +1,83 @@ +import { app } from '@/app.ts'; +import { ChatSession } from '@/models/chat-session.ts'; +import { ChatPrompt } from '@/models/chat-prompt.ts'; +import { CustomError } from '@abearxiong/router'; +app + .route({ + path: 'chat-session', + key: 'list', + }) + .define(async (ctx) => { + const chatSession = await ChatSession.findAll({ + order: [['updatedAt', 'DESC']], + }); + ctx.body = chatSession; + }) + .addTo(app); +// Admin only +app + .route({ + path: 'chat-session', + key: 'list-history', + }) + .define(async (ctx) => { + const data = ctx.query.data; + const chatPrompt = await ChatPrompt.findOne({ + where: { + key: data.key, + }, + }); + if (!chatPrompt) { + throw new CustomError('ChatPrompt not found'); + } + console.log('chatPrompt', chatPrompt.id); + const chatSession = await ChatSession.findAll({ + order: [['updatedAt', 'DESC']], + where: { + chatPromptId: chatPrompt.id, + }, + limit: data.limit || 10, + }); + ctx.body = chatSession; + }) + .addTo(app); + +app + .route({ + path: 'chat-session', + key: 'update', + middleware: ['auth'], + }) + .define(async (ctx) => { + const tokenUser = ctx.state.tokenUser; + const uid = tokenUser.id; + const { id, ...data } = ctx.query.data; + if (id) { + const session = await ChatSession.findByPk(id); + if (session) { + await session.update(data); + } else { + throw new CustomError('Session not found'); + } + ctx.body = session; + return; + } + const session = await ChatSession.create({ ...data, uid }); + ctx.body = session; + }) + .addTo(app); +app + .route({ + path: 'chat-session', + key: 'delete', + }) + .define(async (ctx) => { + const { id } = ctx.query; + const session = await ChatSession.findByPk(id); + if (!session) { + throw new CustomError('Session not found'); + } + await session.destroy(); + ctx.body = session; + }) + .addTo(app); diff --git a/src/routes/chat-prompt/index.ts b/src/routes/chat-prompt/index.ts new file mode 100644 index 0000000..83ec5cd --- /dev/null +++ b/src/routes/chat-prompt/index.ts @@ -0,0 +1 @@ +import './list.ts'; diff --git a/src/routes/chat-prompt/list.ts b/src/routes/chat-prompt/list.ts new file mode 100644 index 0000000..7a6c84b --- /dev/null +++ b/src/routes/chat-prompt/list.ts @@ -0,0 +1,126 @@ +import { app } from '@/app.ts'; +import { AiAgent } from '@/models/agent.ts'; +import { ChatPrompt } from '@/models/chat-prompt.ts'; +import { Prompt } from '@/models/prompt.ts'; +import { CustomError } from '@abearxiong/router'; + +// Admin only +app + .route({ + path: 'chat-prompt', + key: 'list', + }) + .define(async (ctx) => { + const chatPrompt = await ChatPrompt.findAll({ + order: [['updatedAt', 'DESC']], + }); + ctx.body = chatPrompt; + }) + .addTo(app); + +app + .route({ + path: 'chat-prompt', + key: 'get', + validator: { + id: { + type: 'string', + required: true, + }, + }, + }) + .define(async (ctx) => { + const { id } = ctx.query; + const chatPrompt = await ChatPrompt.findByPk(id); + if (!chatPrompt) { + throw new CustomError('ChatPrompt not found'); + } + ctx.body = chatPrompt; + }) + .addTo(app); + +app + .route({ + path: 'chat-prompt', + key: 'update', + middleware: ['auth'], + }) + .define(async (ctx) => { + const tokenUser = ctx.state.tokenUser; + const uid = tokenUser.id; + const { data } = ctx.query; + const { id, ...rest } = data; + if (id) { + const page = await ChatPrompt.findByPk(id); + if (page) { + if (rest.data) { + rest.data = { ...page.data, ...rest.data }; + } + const newPage = await page.update(rest); + ctx.body = newPage; + } else { + throw new CustomError('page not found'); + } + } else if (data) { + const page = await ChatPrompt.create({ ...rest, uid }); + ctx.body = page; + } + }) + .addTo(app); + +app + .route({ + path: 'chat-prompt', + key: 'delete', + validator: { + id: { + type: 'string', + required: true, + }, + }, + }) + .define(async (ctx) => { + const id = ctx.query.id; + const chatPrompt = await ChatPrompt.findByPk(id); + if (!chatPrompt) { + throw new CustomError('chatPrompt not found'); + } + await chatPrompt.destroy(); + ctx.body = 'success'; + }) + .addTo(app); +app + .route({ + path: 'chat-prompt', + key: 'getByKey', + }) + .define(async (ctx) => { + const { key } = ctx.query.data || {}; + if (!key) { + throw new CustomError('key is required'); + } + const chatPrompt = await ChatPrompt.findOne({ + where: { key }, + }); + if (!chatPrompt) { + throw new CustomError('chatPrompt not found'); + } + const { promptId, aiAgentId } = chatPrompt.data; + if (!aiAgentId) { + throw new CustomError('promptId is required'); + } + const aiAgent = await AiAgent.findByPk(aiAgentId, { + // 只获取 id 和description 字段 + attributes: ['id', 'description', 'key'], + }); + if (!aiAgent) { + throw new CustomError('aiAgent not found'); + } + const prompt = await Prompt.findByPk(promptId); + ctx.body = { + chatPrompt: chatPrompt, + aiAgent, + prompt, + }; + }) + .addTo(app); diff --git a/src/routes/container/models/index.ts b/src/routes/container/models/index.ts index 92bf904..52f7bbc 100644 --- a/src/routes/container/models/index.ts +++ b/src/routes/container/models/index.ts @@ -79,6 +79,7 @@ ContainerModel.init( { sequelize, tableName: 'kv_container', + paranoid: true, }, ); diff --git a/src/routes/index.ts b/src/routes/index.ts index 63ec4ff..afc80cd 100644 --- a/src/routes/index.ts +++ b/src/routes/index.ts @@ -9,3 +9,7 @@ import './prompt-graph/index.ts'; import './agent/index.ts'; import './user/index.ts'; + +import './chat-prompt/index.ts'; + +import './chat-history/index.ts'; diff --git a/src/routes/page/models/index.ts b/src/routes/page/models/index.ts index f6bdbd9..54eb876 100644 --- a/src/routes/page/models/index.ts +++ b/src/routes/page/models/index.ts @@ -68,6 +68,7 @@ PageModel.init( { sequelize, tableName: 'kv_page', + paranoid: true, }, ); diff --git a/src/routes/resource/models/index.ts b/src/routes/resource/models/index.ts index 42a5b93..0a6a352 100644 --- a/src/routes/resource/models/index.ts +++ b/src/routes/resource/models/index.ts @@ -82,6 +82,7 @@ ResourceModel.init( { sequelize, tableName: 'kv_resource', + paranoid: true, }, ); diff --git a/src/routes/user/list.ts b/src/routes/user/list.ts index 23ca556..e9fd977 100644 --- a/src/routes/user/list.ts +++ b/src/routes/user/list.ts @@ -36,7 +36,7 @@ app const { checkToken: token } = ctx.query; try { const result = await User.verifyToken(token); - ctx.body = result?.payload || {}; + ctx.body = result || {}; } catch (e) { new CustomError(401, 'Token InValid '); } diff --git a/src/scripts/recover.ts b/src/scripts/recover.ts new file mode 100644 index 0000000..37243b8 --- /dev/null +++ b/src/scripts/recover.ts @@ -0,0 +1,27 @@ +import { ContainerModel } from '@/routes/container/models/index.ts'; + +const recoverData = async () => { + const data = { + id: '868970a4-8cab-4141-a73c-cc185fd17508', + title: '测试es6每次导入的变量,运行一次+1,并打印', + description: '', + tags: [], + type: '', + code: "let a = 1\n\nexport const main = () => {\n console.log('current a', a);\n return a++\n}", + source: '', + sourceType: '', + data: { + className: '', + style: {}, + showChild: true, + shadowRoot: false, + }, + publish: {}, + uid: '14206305-8b5c-44cc-b177-766cfe2e452f', + createdAt: '2024-09-19T13:27:58.796Z', + updatedAt: '2024-09-28T05:27:05.381Z', + }; + const r = await ContainerModel.create(data); +}; + +recoverData();