import { OpenAI } from 'openai';
import type {
  BaseChatInterface,
  ChatMessageComplete,
  ChatMessage,
  ChatMessageOptions,
  BaseChatUsageInterface,
  ChatStream,
  EmbeddingMessage,
  EmbeddingMessageComplete,
} from './type.ts';

/**
 * Constructor options for {@link BaseChat}.
 * `T` lets subclasses extend the options with their own fields.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- loose default so subclasses can read extra options
export type BaseChatOptions<T = Record<string, any>> = {
  /** Default base URL, prepended to request paths that do not start with `http`. */
  baseURL?: string;
  /** Default model name sent with requests. */
  model?: string;
  /** API key, sent as `Authorization: Bearer <apiKey>`. */
  apiKey: string;
  /** Whether the client runs in a browser; auto-detected via {@link getIsBrowser} when omitted. */
  isBrowser?: boolean;
  /**
   * Whether to stream output (documented default: false).
   * NOTE(review): BaseChat itself never reads this option; presumably used by subclasses — confirm.
   */
  stream?: boolean;
} & T;

/**
 * Detects a browser environment by checking for a global `window` with a `document`.
 * @returns `true` when both exist; `false` otherwise, including when the check itself throws.
 */
export const getIsBrowser = (): boolean => {
  try {
    return typeof window !== 'undefined' && typeof window.document !== 'undefined';
  } catch {
    return false;
  }
};

/**
 * Minimal OpenAI-compatible chat client built directly on `fetch`.
 *
 * `chat` replaces the `*_tokens` fields with the usage of the latest call.
 * `generateEmbeddingCore` adds its usage on top of the current values.
 */
export class BaseChat implements BaseChatInterface, BaseChatUsageInterface {
  /** Base URL prepended to request paths that do not start with `http`. */
  baseURL: string;
  /** Default model name. */
  model: string;
  /** API key sent as a Bearer token. */
  apiKey: string;
  /** Whether the client is running in a browser. */
  isBrowser: boolean;
  /**
   * OpenAI SDK instance.
   * NOTE(review): never assigned in this class (its construction is commented out below);
   * presumably set by subclasses — confirm before relying on it.
   */
  openai: OpenAI;

  // Start at 0 / '' so usage accumulation (generateEmbeddingCore) never begins from `undefined` (→ NaN).
  prompt_tokens = 0;
  total_tokens = 0;
  completion_tokens = 0;
  responseText = '';

  constructor(options: BaseChatOptions) {
    this.baseURL = options.baseURL;
    this.model = options.model;
    this.apiKey = options.apiKey;
    this.isBrowser = options.isBrowser ?? getIsBrowser();
    // The SDK client construction is disabled; requests go through `fetch` (see post/get).
    // this.openai = new OpenAI({
    //   apiKey: this.apiKey,
    //   baseURL: this.baseURL,
    //   dangerouslyAllowBrowser: options?.dangerouslyAllowBrowser ?? this.isBrowser,
    // });
  }

  /**
   * Sends a JSON POST request.
   * @param url Absolute URL (starting with `http`) or a path that is appended to `baseURL`.
   * @param opts.headers Extra headers, merged over the defaults from {@link BaseChat.getHeaders}.
   * @param opts.data Payload serialised with `JSON.stringify`; no body is sent when it is `null`/`undefined`.
   * @returns The raw `fetch` response; callers are responsible for checking `response.ok`.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- payload is provider-specific JSON
  post(url = '', opts: { headers?: Record<string, string>; data?: any } = {}): Promise<Response> {
    const target = url.startsWith('http') ? url : this.baseURL + url;
    return fetch(target, {
      method: 'POST',
      headers: this.getHeaders(opts.headers),
      // Only a missing payload means "no body"; falsy JSON values such as 0 or false are still sent.
      body: opts.data == null ? undefined : JSON.stringify(opts.data),
    });
  }

  /**
   * Sends a GET request and parses the response body as JSON.
   * The status is not checked: error payloads are parsed and returned like any other body.
   * @param url Absolute URL (starting with `http`) or a path that is appended to `baseURL`.
   * @param opts.headers Extra headers, merged over the defaults from {@link BaseChat.getHeaders}.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- response shape is caller-defined
  async get<T = any>(url = '', opts: { headers?: Record<string, string> } = {}): Promise<T> {
    const target = url.startsWith('http') ? url : this.baseURL + url;
    const response = await fetch(target, {
      method: 'GET',
      headers: this.getHeaders(opts.headers),
    });
    return (await response.json()) as T;
  }

  /**
   * Sends a non-streaming chat completion request and records its token usage and reply text.
   * @param messages Conversation to send.
   * @param options Extra request fields; they override `model`/`messages`, but `stream` is always forced to false.
   * @returns The parsed completion response.
   * @throws Error when the API responds with a non-2xx status.
   */
  async chat(messages: ChatMessage[], options?: ChatMessageOptions): Promise<ChatMessageComplete> {
    const requestBody = {
      model: this.model,
      messages,
      ...options,
      stream: false,
    };
    // Pass a path, not `${baseURL}/...`: post() prepends baseURL itself, which would double a relative baseURL.
    const response = await this.post('/chat/completions', { data: requestBody });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Chat API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }
    const res = (await response.json()) as ChatMessageComplete;
    // Usage from this call replaces (does not add to) the previous values.
    this.prompt_tokens = res.usage?.prompt_tokens ?? 0;
    this.total_tokens = res.usage?.total_tokens ?? 0;
    this.completion_tokens = res.usage?.completion_tokens ?? 0;
    this.responseText = res.choices?.[0]?.message?.content || '';
    return res;
  }

  /**
   * Sends a streaming chat completion request.
   * @param messages Conversation to send.
   * @param options Extra request fields; `stream` is always forced to true.
   * @returns A stream of decoded text chunks exactly as received (the response text is not parsed into events here).
   * @throws Error if `options.response_format` is set, the API responds non-2xx, or the body is not readable.
   */
  async chatStream(messages: ChatMessage[], options?: ChatMessageOptions): Promise<ChatStream> {
    if (options?.response_format) {
      throw new Error('response_format is not supported in stream mode');
    }
    const requestBody = {
      model: this.model,
      messages,
      ...options,
      stream: true,
    };
    const response = await this.post('/chat/completions', { data: requestBody });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Chat Stream API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }
    const reader = response.body?.getReader();
    if (!reader) {
      throw new Error('Response body is not readable');
    }
    const decoder = new TextDecoder();
    let cancelled = false;

    // Pull-based source: the body is read only when the consumer wants more data (backpressure).
    const stream = new ReadableStream<string>({
      async pull(controller) {
        try {
          const { done, value } = await reader.read();
          // After cancel() the controller is closed; enqueue/close would throw, so stop quietly.
          if (cancelled) return;
          if (done) {
            // Flush any bytes of a multi-byte character held back by `stream: true`.
            const tail = decoder.decode();
            if (tail) controller.enqueue(tail);
            controller.close();
            return;
          }
          // Decode Uint8Array chunks; any other value is passed through unchanged.
          controller.enqueue(value instanceof Uint8Array ? decoder.decode(value, { stream: true }) : value);
        } catch (error) {
          if (!cancelled) controller.error(error);
        }
      },
      cancel(reason) {
        cancelled = true;
        // Cancel (not merely unlock) the response body so the underlying network stream is released.
        return reader.cancel(reason);
      },
    });
    // NOTE(review): ChatStream's declared shape is not visible here; this cast mirrors the original.
    return stream as unknown as ChatStream;
  }

  /**
   * Returns the current token usage counters.
   * @returns The current `prompt_tokens`, `total_tokens` and `completion_tokens` values.
   */
  getChatUsage() {
    return {
      prompt_tokens: this.prompt_tokens,
      total_tokens: this.total_tokens,
      completion_tokens: this.completion_tokens,
    };
  }

  /**
   * Builds the default JSON + Bearer-auth headers, with `headers` merged on top (caller values win).
   */
  getHeaders(headers?: Record<string, string>): Record<string, string> {
    return {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${this.apiKey}`,
      ...headers,
    };
  }

  /**
   * Requests embeddings for one or more texts and adds the reported usage to the token counters.
   * @param text Input text or list of texts.
   * @param options Extra request fields (may override `input`/`encoding_format`); `options.model`, when truthy,
   *   selects the model, otherwise the instance default is used.
   * @returns The parsed embedding response.
   * @throws Error when the API responds with a non-2xx status.
   */
  async generateEmbeddingCore(text: string | string[], options?: EmbeddingMessage): Promise<EmbeddingMessageComplete> {
    const embeddingModel = options?.model || this.model;
    const requestBody = {
      input: text,
      encoding_format: 'float',
      ...options,
      // Set after the spread so an explicit `model: undefined` in options cannot erase the resolved model.
      model: embeddingModel,
    };
    const response = await this.post('/embeddings', { data: requestBody });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Embedding API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }
    const res = (await response.json()) as EmbeddingMessageComplete;
    // Embedding usage is added to the existing counters (chat() overwrites them instead).
    this.prompt_tokens += res.usage?.prompt_tokens ?? 0;
    this.total_tokens += res.usage?.total_tokens ?? 0;
    return res;
  }
}