35 lines
1.1 KiB
TypeScript
35 lines
1.1 KiB
TypeScript
import { encoding_for_model, get_encoding } from 'tiktoken';
|
||
|
||
|
||
const MODEL_TO_ENCODING = {
|
||
'gpt-4': 'cl100k_base',
|
||
'gpt-4-turbo': 'cl100k_base',
|
||
'gpt-3.5-turbo': 'cl100k_base',
|
||
'text-embedding-ada-002': 'cl100k_base',
|
||
'text-davinci-002': 'p50k_base',
|
||
'text-davinci-003': 'p50k_base',
|
||
} as const;
|
||
|
||
export function numTokensFromString(text: string, model: keyof typeof MODEL_TO_ENCODING = 'gpt-3.5-turbo'): number {
|
||
try {
|
||
// 对于特定模型使用专门的编码器
|
||
const encoder = encoding_for_model(model);
|
||
const tokens = encoder.encode(text);
|
||
const tokenCount = tokens.length;
|
||
encoder.free(); // 释放编码器
|
||
return tokenCount;
|
||
} catch (error) {
|
||
try {
|
||
// 如果模型特定的编码器失败,尝试使用基础编码器
|
||
const encoder = get_encoding(MODEL_TO_ENCODING[model]);
|
||
const tokens = encoder.encode(text);
|
||
const tokenCount = tokens.length;
|
||
encoder.free(); // 释放编码器
|
||
return tokenCount;
|
||
} catch (error) {
|
||
// 如果编码失败,使用一个粗略的估计:平均每个字符0.25个token
|
||
return Math.ceil(text.length * 0.25);
|
||
}
|
||
}
|
||
}
|