Files
ai/src/provider/core/chat.ts
2025-11-24 04:26:57 +08:00

240 lines
6.2 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { OpenAI } from 'openai';
import type {
BaseChatInterface,
ChatMessageComplete,
ChatMessage,
ChatMessageOptions,
BaseChatUsageInterface,
ChatStream,
EmbeddingMessage,
EmbeddingMessageComplete,
} from './type.ts';
/**
 * Constructor options for a chat provider.
 *
 * `T` lets a concrete provider mix its own extra option fields into the
 * shared ones declared here.
 */
export type BaseChatOptions<T = Record<string, any>> = {
  /** API key for the provider (required). */
  apiKey: string;
  /** Default base URL that request paths are resolved against. */
  baseURL?: string;
  /** Default model name. */
  model?: string;
  /** Whether the client runs in a browser; when omitted it is auto-detected. */
  isBrowser?: boolean;
  /** Whether to stream output (documented default: false). */
  stream?: boolean;
} & T;
/**
 * Reports whether the code appears to run in a browser: true only when a
 * global `window` exists and it exposes a `document`.
 * Any exception thrown while probing is treated as "not a browser".
 */
export const getIsBrowser = (): boolean => {
  try {
    if (typeof window === 'undefined') {
      return false;
    }
    return typeof window.document !== 'undefined';
  } catch {
    return false;
  }
};
/**
 * Minimal OpenAI-compatible chat/embedding client built on `fetch`.
 *
 * Every request sends `Content-Type: application/json` and an
 * `Authorization: Bearer <apiKey>` header. Request paths are resolved against
 * `baseURL` unless they already start with `http`.
 */
export class BaseChat implements BaseChatInterface, BaseChatUsageInterface {
  /**
   * Default base URL that relative request paths are appended to.
   */
  baseURL: string;
  /**
   * Default model; used for chat requests and as the embedding fallback.
   */
  model: string;
  /**
   * API key sent as a Bearer token.
   */
  apiKey: string;
  /**
   * Whether this instance runs in a browser.
   */
  isBrowser: boolean;
  /**
   * OpenAI client instance.
   * NOTE(review): declared but never assigned by this class; all requests go
   * through `fetch`. Confirm whether subclasses or callers rely on it.
   */
  openai: OpenAI;
  /** Prompt tokens: overwritten by `chat`, accumulated by `generateEmbeddingCore`. */
  prompt_tokens = 0;
  /** Total tokens: overwritten by `chat`, accumulated by `generateEmbeddingCore`. */
  total_tokens = 0;
  /** Completion tokens reported by the last `chat` call. */
  completion_tokens = 0;
  /** Content of the first choice returned by the last `chat` call. */
  responseText = '';

  constructor(options: BaseChatOptions) {
    this.baseURL = options.baseURL;
    this.model = options.model;
    this.apiKey = options.apiKey;
    this.isBrowser = options.isBrowser ?? getIsBrowser();
  }

  /**
   * Resolves a request URL: values starting with `http` are used as-is;
   * anything else is appended to `baseURL`.
   */
  private resolveRequestUrl(url: string): string {
    return url.startsWith('http') ? url : (this.baseURL ?? '') + url;
  }

  /**
   * Sends a JSON POST request.
   * @param url Absolute URL (starting with `http`) or a path appended to `baseURL`.
   * @param opts `headers` are merged over the defaults from `getHeaders`;
   *   `data`, when not null/undefined, is JSON-serialized as the body.
   * @returns The raw fetch `Response`; the HTTP status is not checked here.
   */
  post(url = '', opts: { headers?: Record<string, string>; data?: any } = {}) {
    // Split `data` off so it is not forwarded to fetch as an unknown init key.
    const { headers, data, ...init } = opts;
    return fetch(this.resolveRequestUrl(url), {
      method: 'POST',
      ...init,
      headers: this.getHeaders(headers),
      body: data == null ? undefined : JSON.stringify(data),
    });
  }

  /**
   * Sends a GET request and parses the response body as JSON.
   * @param url Absolute URL (starting with `http`) or a path appended to `baseURL`.
   * @param opts `headers` are merged over the defaults from `getHeaders`.
   * @returns The parsed JSON body. `response.ok` is not checked.
   */
  async get<T = any>(url = '', opts: { headers?: Record<string, string> } = {}): Promise<T> {
    const { headers, ...init } = opts;
    const response = await fetch(this.resolveRequestUrl(url), {
      method: 'GET',
      ...init,
      headers: this.getHeaders(headers),
    });
    return response.json();
  }

  /**
   * Non-streaming chat completion (POST `/chat/completions`).
   * Stores the usage counters and the first choice's content on the instance.
   * @throws Error when the response status is not ok.
   */
  async chat(messages: ChatMessage[], options?: ChatMessageOptions): Promise<ChatMessageComplete> {
    const requestBody = {
      model: this.model,
      messages,
      ...options,
      stream: false,
    };
    // Pass a path only: post() already prefixes baseURL, so passing
    // `${baseURL}/...` would double the prefix for a relative baseURL.
    const response = await this.post('/chat/completions', { data: requestBody });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Chat API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }
    const res = (await response.json()) as ChatMessageComplete;
    this.prompt_tokens = res.usage?.prompt_tokens ?? 0;
    this.total_tokens = res.usage?.total_tokens ?? 0;
    this.completion_tokens = res.usage?.completion_tokens ?? 0;
    this.responseText = res.choices?.[0]?.message?.content || '';
    return res;
  }

  /**
   * Streaming chat completion (POST `/chat/completions` with `stream: true`).
   * Returns a stream of decoded text chunks of the raw response body.
   * Cancelling the returned stream cancels the underlying response body.
   * @throws Error when `response_format` is given, the status is not ok, or
   *   the response has no readable body.
   */
  async chatStream(messages: ChatMessage[], options?: ChatMessageOptions) {
    if (options?.response_format) {
      throw new Error('response_format is not supported in stream mode');
    }
    const requestBody = {
      model: this.model,
      messages,
      ...options,
      stream: true,
    };
    const response = await this.post('/chat/completions', { data: requestBody });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Chat Stream API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }
    const reader = response.body?.getReader();
    if (!reader) {
      throw new Error('Response body is not readable');
    }
    const decoder = new TextDecoder();
    // Re-expose the body as text: bytes are decoded, any other chunk type passes through.
    const stream = new ReadableStream({
      async start(controller) {
        try {
          while (true) {
            const { done, value } = await reader.read();
            if (done) {
              // Flush any incomplete multi-byte sequence buffered by `stream: true`.
              const tail = decoder.decode();
              if (tail) {
                controller.enqueue(tail);
              }
              controller.close();
              break;
            }
            if (value instanceof Uint8Array) {
              controller.enqueue(decoder.decode(value, { stream: true }));
            } else {
              // Some fetch implementations may yield strings or other chunk types; pass them through.
              controller.enqueue(value);
            }
          }
        } catch (error) {
          controller.error(error);
        }
      },
      cancel(reason) {
        // Cancel (not merely unlock) the source so the HTTP body is actually closed.
        return reader.cancel(reason);
      },
    });
    return stream as unknown as ChatStream;
  }

  /**
   * Returns the current usage counters.
   */
  getChatUsage() {
    return {
      prompt_tokens: this.prompt_tokens,
      total_tokens: this.total_tokens,
      completion_tokens: this.completion_tokens,
    };
  }

  /**
   * Builds the default JSON + Bearer-auth headers, with `headers` merged on top.
   */
  getHeaders(headers?: Record<string, string>) {
    return {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${this.apiKey}`,
      ...headers,
    };
  }

  /**
   * Creates embeddings (POST `/embeddings`); internal helper.
   * Adds the reported prompt/total tokens to the instance counters.
   * @param text One input string or a batch of strings.
   * @param options Extra request fields; `model` falls back to `this.model`
   *   when missing or falsy.
   * @throws Error when the response status is not ok.
   */
  async generateEmbeddingCore(text: string | string[], options?: EmbeddingMessage): Promise<EmbeddingMessageComplete> {
    const embeddingModel = options?.model || this.model;
    const requestBody = {
      input: text,
      encoding_format: 'float',
      ...options,
      // Set after the spread so an explicit `model: undefined` cannot clobber the fallback.
      model: embeddingModel,
    };
    const response = await this.post('/embeddings', { data: requestBody });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Embedding API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }
    const res = (await response.json()) as EmbeddingMessageComplete;
    this.prompt_tokens += res.usage?.prompt_tokens ?? 0;
    this.total_tokens += res.usage?.total_tokens ?? 0;
    return res;
  }
}