init
This commit is contained in:
34
src/provider/utils/token.ts
Normal file
34
src/provider/utils/token.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import { encoding_for_model, get_encoding } from 'tiktoken';
|
||||
|
||||
|
||||
const MODEL_TO_ENCODING = {
|
||||
'gpt-4': 'cl100k_base',
|
||||
'gpt-4-turbo': 'cl100k_base',
|
||||
'gpt-3.5-turbo': 'cl100k_base',
|
||||
'text-embedding-ada-002': 'cl100k_base',
|
||||
'text-davinci-002': 'p50k_base',
|
||||
'text-davinci-003': 'p50k_base',
|
||||
} as const;
|
||||
|
||||
export function numTokensFromString(text: string, model: keyof typeof MODEL_TO_ENCODING = 'gpt-3.5-turbo'): number {
|
||||
try {
|
||||
// 对于特定模型使用专门的编码器
|
||||
const encoder = encoding_for_model(model);
|
||||
const tokens = encoder.encode(text);
|
||||
const tokenCount = tokens.length;
|
||||
encoder.free(); // 释放编码器
|
||||
return tokenCount;
|
||||
} catch (error) {
|
||||
try {
|
||||
// 如果模型特定的编码器失败,尝试使用基础编码器
|
||||
const encoder = get_encoding(MODEL_TO_ENCODING[model]);
|
||||
const tokens = encoder.encode(text);
|
||||
const tokenCount = tokens.length;
|
||||
encoder.free(); // 释放编码器
|
||||
return tokenCount;
|
||||
} catch (error) {
|
||||
// 如果编码失败,使用一个粗略的估计:平均每个字符0.25个token
|
||||
return Math.ceil(text.length * 0.25);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user