import { OpenAI } from 'openai';
import type {
  BaseChatInterface,
  ChatMessageComplete,
  ChatMessage,
  ChatMessageOptions,
  BaseChatUsageInterface,
  ChatStream,
  EmbeddingMessage,
  EmbeddingMessageComplete,
} from './type.ts';

export type BaseChatOptions<T = Record<string, any>> = {
  /**
   * Default baseURL
   */
  baseURL?: string;
  /**
   * Default model
   */
  model?: string;
  /**
   * Default apiKey
   */
  apiKey: string;
  /**
   * Whether the client is running in a browser
   */
  isBrowser?: boolean;
  /**
   * Whether to stream output; defaults to false
   */
  stream?: boolean;
} & T;

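// Example options object (values are illustrative assumptions, not defaults shipped by
// this module); the generic parameter lets callers append extra fields:
// const options: BaseChatOptions<{ temperature?: number }> = {
//   apiKey: 'sk-xxxx',
//   baseURL: 'https://api.openai.com/v1',
//   model: 'gpt-4o-mini',
//   temperature: 0.7,
// };
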
export const getIsBrowser = () => {
  try {
    // Check whether a window object exists
    return typeof window !== 'undefined' && typeof window.document !== 'undefined';
  } catch (e) {
    return false;
  }
};

export class BaseChat implements BaseChatInterface, BaseChatUsageInterface {
  /**
   * Default baseURL
   */
  baseURL: string;
  /**
   * Default model
   */
  model: string;
  /**
   * Default apiKey
   */
  apiKey: string;
  /**
   * Whether the client is running in a browser
   */
  isBrowser: boolean;
  /**
   * OpenAI instance
   */
  openai: OpenAI;

  prompt_tokens: number = 0;
  total_tokens: number = 0;
  completion_tokens: number = 0;
  responseText: string = '';

  constructor(options: BaseChatOptions) {
    this.baseURL = options.baseURL;
    this.model = options.model;
    this.apiKey = options.apiKey;
    // @ts-ignore
    const DEFAULT_IS_BROWSER = getIsBrowser();
    this.isBrowser = options.isBrowser ?? DEFAULT_IS_BROWSER;
    // this.openai = new OpenAI({
    //   apiKey: this.apiKey,
    //   baseURL: this.baseURL,
    //   dangerouslyAllowBrowser: options?.dangerouslyAllowBrowser ?? this.isBrowser,
    // });
  }
  post(url = '', opts: { headers?: Record<string, string>, data?: any } = {}) {
    let _url = url.startsWith('http') ? url : this.baseURL + url;
    return fetch(_url, {
      method: 'POST',
      ...opts,
      headers: {
        'Content-Type': 'application/json',
        Authorization: `Bearer ${this.apiKey}`,
        ...opts.headers,
      },
      body: opts?.data ? JSON.stringify(opts.data) : undefined,
    });
  }
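
  // URL resolution used by the post/get helpers (the host below is an illustrative
  // assumption): relative paths are prefixed with this.baseURL, absolute URLs pass through:
  //   this.post('/chat/completions', { data })              -> POST `${this.baseURL}/chat/completions`
  //   this.post('https://other.host/v1/embeddings', { data }) -> absolute URL used as-is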
  async get<T = any>(url = '', opts: { headers?: Record<string, string> } = {}): Promise<T> {
    let _url = url.startsWith('http') ? url : this.baseURL + url;
    return fetch(_url, {
      method: 'GET',
      ...opts,
      headers: {
        'Content-Type': 'application/json',
        Authorization: `Bearer ${this.apiKey}`,
        ...opts.headers,
      },
    }).then((res) => res.json());
  }
  /**
   * Chat completion (non-streaming)
   */
  async chat(messages: ChatMessage[], options?: ChatMessageOptions): Promise<ChatMessageComplete> {
    const requestBody = {
      model: this.model,
      messages,
      ...options,
      stream: false,
    };

    const response = await this.post(`${this.baseURL}/chat/completions`, { data: requestBody });

    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Chat API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }

    const res = await response.json() as ChatMessageComplete;

    this.prompt_tokens = res.usage?.prompt_tokens ?? 0;
    this.total_tokens = res.usage?.total_tokens ?? 0;
    this.completion_tokens = res.usage?.completion_tokens ?? 0;
    this.responseText = res.choices[0]?.message?.content || '';
    return res;
  }
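
  // Usage sketch, where `client` is a hypothetical BaseChat instance and the option
  // shown is an illustrative assumption:
  //   const res = await client.chat([{ role: 'user', content: 'Hello' }], { temperature: 0.7 });
  //   console.log(res.choices[0]?.message?.content, client.getChatUsage());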
  async chatStream(messages: ChatMessage[], options?: ChatMessageOptions) {
    if (options?.response_format) {
      throw new Error('response_format is not supported in stream mode');
    }

    const requestBody = {
      model: this.model,
      messages,
      ...options,
      stream: true,
    };

    const response = await this.post(`${this.baseURL}/chat/completions`, { data: requestBody });

    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Chat Stream API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }

    const decoder = new TextDecoder();
    const reader = response.body?.getReader();

    if (!reader) {
      throw new Error('Response body is not readable');
    }

    // Create a new ReadableStream that decodes the incoming data with the decoder
    const stream = new ReadableStream({
      async start(controller) {
        try {
          while (true) {
            const { done, value } = await reader.read();
            if (done) {
              controller.close();
              break;
            }
            // Check the chunk type: decode Uint8Array chunks, pass anything else through as-is
            if (typeof value === 'string') {
              controller.enqueue(value);
            } else if (value instanceof Uint8Array) {
              const text = decoder.decode(value, { stream: true });
              controller.enqueue(text);
            } else {
              controller.enqueue(value);
            }
          }
        } catch (error) {
          controller.error(error);
        }
      },
      cancel() {
        reader.releaseLock();
      },
    });

    return stream as unknown as ChatStream;
  }
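
  // Consuming the stream (illustrative sketch; `client` is a hypothetical BaseChat
  // instance, and how chunks map to tokens depends on the ChatStream type and the
  // upstream SSE format, which are assumptions here):
  //   const stream = await client.chatStream([{ role: 'user', content: 'Hi' }]);
  //   const reader = (stream as unknown as ReadableStream<string>).getReader();
  //   while (true) {
  //     const { done, value } = await reader.read();
  //     if (done) break;
  //     process.stdout.write(value);
  //   }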

  /**
   * Get chat usage statistics
   * @returns
   */
  getChatUsage() {
    return {
      prompt_tokens: this.prompt_tokens,
      total_tokens: this.total_tokens,
      completion_tokens: this.completion_tokens,
    };
  }

  getHeaders(headers?: Record<string, string>) {
    return {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${this.apiKey}`,
      ...headers,
    };
  }
  /**
   * Generate embeddings (internal)
   * @param text
   * @returns
   */
  async generateEmbeddingCore(text: string | string[], options?: EmbeddingMessage): Promise<EmbeddingMessageComplete> {
    const embeddingModel = options?.model || this.model;

    const requestBody = {
      model: embeddingModel,
      input: text,
      encoding_format: 'float',
      ...options,
    };

    const response = await this.post(`${this.baseURL}/embeddings`, { data: requestBody });

    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Embedding API request failed: ${response.status} ${response.statusText} - ${errorText}`);
    }

    const res = await response.json() as EmbeddingMessageComplete;

    this.prompt_tokens += res.usage.prompt_tokens;
    this.total_tokens += res.usage.total_tokens;
    return res;
  }
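
  // Embedding usage sketch (the model name and response shape follow an OpenAI-style
  // API and are assumptions here; `client` is a hypothetical BaseChat instance):
  //   const emb = await client.generateEmbeddingCore('hello world', { model: 'text-embedding-3-small' });
  //   const vector = emb.data?.[0]?.embedding; // number[]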
}
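
// End-to-end usage sketch for this module (not part of the library; the import path,
// baseURL, model, and message shape below are assumptions for an OpenAI-compatible endpoint):
//
// import { BaseChat } from './base.ts'; // hypothetical path
//
// const client = new BaseChat({
//   apiKey: process.env.OPENAI_API_KEY ?? '',
//   baseURL: 'https://api.openai.com/v1',
//   model: 'gpt-4o-mini',
// });
//
// const res = await client.chat([{ role: 'user', content: 'Say hi' }]);
// console.log(res.choices[0]?.message?.content);
// console.log(client.getChatUsage());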