// video-tools/src/asr/provider/volcengine/asr-ws-big-model-client.ts


import * as fs from 'fs/promises';
import * as path from 'path';
import * as zlib from 'zlib';
import { promisify } from 'util';
import { nanoid } from 'nanoid';
import { VolcEngineBase, uuid } from './base.ts';
// Promisify zlib.gzip for async compression of outgoing payloads
const gzipPromise = promisify(zlib.gzip);
// Protocol constants
const PROTOCOL_VERSION = 0b0001;
const DEFAULT_HEADER_SIZE = 0b0001;
// Message Type
const FULL_CLIENT_REQUEST = 0b0001;
const AUDIO_ONLY_REQUEST = 0b0010;
const FULL_SERVER_RESPONSE = 0b1001;
const SERVER_ACK = 0b1011;
const SERVER_ERROR_RESPONSE = 0b1111;
// Message Type Specific Flags
const NO_SEQUENCE = 0b0000; // frame carries no sequence number
const POS_SEQUENCE = 0b0001; // frame carries a positive sequence number
const NEG_SEQUENCE = 0b0010; // last frame, no sequence number
const NEG_WITH_SEQUENCE = 0b0011; // last frame, carries a (negative) sequence number
// Message Serialization
const NO_SERIALIZATION = 0b0000;
const JSON_SERIALIZATION = 0b0001;
// Message Compression
const NO_COMPRESSION = 0b0000;
const GZIP_COMPRESSION = 0b0001;
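/*
* Frame layout assumed by this client (a sketch inferred from generateHeader and
* parseResponse below; consult the VolcEngine binary protocol docs for the
* authoritative definition):
*
*   byte 0: protocol version (high nibble) | header size in 4-byte units (low nibble)
*   byte 1: message type (high nibble)     | message-type-specific flags (low nibble)
*   byte 2: serialization (high nibble)    | compression (low nibble)
*   byte 3: reserved
*   [4 bytes]  optional signed sequence number (present when flag bit 0 is set)
*   [4 bytes]  payload size, big-endian
*   [N bytes]  payload (gzip-compressed JSON or audio bytes)
*/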
/**
* Generate header for the WebSocket request
*/
function generateHeader(
messageType = FULL_CLIENT_REQUEST,
messageTypeSpecificFlags = NO_SEQUENCE,
serialMethod = JSON_SERIALIZATION,
compressionType = GZIP_COMPRESSION,
reservedData = 0x00,
): Buffer {
const header = Buffer.alloc(4);
header[0] = (PROTOCOL_VERSION << 4) | DEFAULT_HEADER_SIZE;
header[1] = (messageType << 4) | messageTypeSpecificFlags;
header[2] = (serialMethod << 4) | compressionType;
header[3] = reservedData;
return header;
}
/**
* Generate the sequence part of the request
*/
function generateBeforePayload(sequence: number): Buffer {
const beforePayload = Buffer.alloc(4);
beforePayload.writeInt32BE(sequence);
return beforePayload;
}
/**
* Parse a binary frame received from the WebSocket server.
* The returned object always has `isLastPackage`; depending on the frame it may also
* carry `payloadSequence`, `seq` (for server ACKs), `code` (for error responses), and
* `payloadMsg`/`payloadSize` once the payload is decompressed and, for JSON frames, parsed.
*/
function parseResponse(res: Buffer): any {
const protocolVersion = res[0] >> 4;
const headerSize = res[0] & 0x0f;
const messageType = res[1] >> 4;
const messageTypeSpecificFlags = res[1] & 0x0f;
const serializationMethod = res[2] >> 4;
const messageCompression = res[2] & 0x0f;
const reserved = res[3];
const headerExtensions = res.slice(4, headerSize * 4);
const payload = res.slice(headerSize * 4);
const result: any = {
isLastPackage: false,
};
let payloadMsg = null;
let payloadSize = 0;
let offset = 0;
if (messageTypeSpecificFlags & 0x01) {
// receive frame with sequence
const seq = payload.readInt32BE(0);
result.payloadSequence = seq;
offset += 4;
}
if (messageTypeSpecificFlags & 0x02) {
// receive last package
result.isLastPackage = true;
}
const remainingPayload = payload.slice(offset);
if (messageType === FULL_SERVER_RESPONSE) {
payloadSize = remainingPayload.readInt32BE(0);
payloadMsg = remainingPayload.slice(4);
} else if (messageType === SERVER_ACK) {
const seq = remainingPayload.readInt32BE(0);
result.seq = seq;
if (remainingPayload.length >= 8) {
payloadSize = remainingPayload.readUInt32BE(4);
payloadMsg = remainingPayload.slice(8);
}
} else if (messageType === SERVER_ERROR_RESPONSE) {
const code = remainingPayload.readUInt32BE(0);
result.code = code;
payloadSize = remainingPayload.readUInt32BE(4);
payloadMsg = remainingPayload.slice(8);
}
if (!payloadMsg) {
return result;
}
if (messageCompression === GZIP_COMPRESSION) {
try {
const decompressed = zlib.gunzipSync(payloadMsg);
payloadMsg = decompressed;
} catch (error) {
console.error('Error decompressing payload:', error);
}
}
if (serializationMethod === JSON_SERIALIZATION) {
try {
payloadMsg = JSON.parse(payloadMsg.toString('utf-8'));
} catch (error) {
console.error('Error parsing JSON payload:', error);
}
} else if (serializationMethod !== NO_SERIALIZATION) {
payloadMsg = payloadMsg.toString('utf-8');
}
result.payloadMsg = payloadMsg;
result.payloadSize = payloadSize;
return result;
}
/**
* Read WAV file information
*/
async function readWavInfo(data: Buffer): Promise<{
channels: number;
sampleWidth: number;
sampleRate: number;
frames: number;
audioData: Buffer;
}> {
// This is a simplified WAV parser - in production you should use a proper library
if (data.length < 44) {
throw new Error('Invalid WAV file: too short');
}
// Check WAV header
if (data.slice(0, 4).toString() !== 'RIFF' || data.slice(8, 12).toString() !== 'WAVE') {
throw new Error('Invalid WAV file: not a WAV format');
}
// Parse header information
const channels = data.readUInt16LE(22);
const sampleRate = data.readUInt32LE(24);
const bitsPerSample = data.readUInt16LE(34);
const sampleWidth = bitsPerSample / 8;
// Find data chunk
let offset = 12; // Start after "WAVE"
let dataSize = 0;
let audioData: Buffer = Buffer.alloc(0);
while (offset < data.length) {
const chunkType = data.slice(offset, offset + 4).toString();
const chunkSize = data.readUInt32LE(offset + 4);
if (chunkType === 'data') {
dataSize = chunkSize;
audioData = data.slice(offset + 8, offset + 8 + chunkSize);
break;
}
// RIFF chunks are word-aligned; skip the padding byte after odd-sized chunks
offset += 8 + chunkSize + (chunkSize % 2);
}
const frames = dataSize / (channels * sampleWidth);
return {
channels,
sampleWidth,
sampleRate,
frames,
audioData,
};
}
/**
* Check if data is a valid WAV file
*/
function judgeWav(data: Buffer): boolean {
if (data.length < 44) {
return false;
}
return data.slice(0, 4).toString() === 'RIFF' && data.slice(8, 12).toString() === 'WAVE';
}
/**
* Slice data into chunks
*/
function* sliceData(data: Buffer, chunkSize: number): Generator<[Buffer, boolean]> {
const dataLen = data.length;
let offset = 0;
while (offset + chunkSize < dataLen) {
yield [data.slice(offset, offset + chunkSize), false];
offset += chunkSize;
}
yield [data.slice(offset, dataLen), true];
}
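// Usage sketch: `for (const [chunk, isLast] of sliceData(pcmBuffer, 3200)) { ... }`.
// The final chunk may be shorter than chunkSize and is the only one flagged with `true`.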
interface AsrClientOptions {
segDuration?: number;
wsUrl?: string;
uid?: string;
format?: string;
rate?: number;
bits?: number;
channel?: number;
codec?: string;
authMethod?: string;
hotWords?: string[];
streaming?: boolean;
mp3SegSize?: number;
resourceId?: string;
token?: string;
appid?: string;
}
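// Example options (a sketch; the appid/token values are placeholders, not real credentials,
// and the numeric values simply restate the constructor defaults below):
// const exampleOptions: AsrClientOptions = {
//   appid: '<your-app-id>',
//   token: '<your-access-token>',
//   format: 'wav',
//   rate: 16000,
//   bits: 16,
//   channel: 1,
//   segDuration: 100, // ms of audio per uploaded segment
//   streaming: true,  // pace uploads to simulate real-time capture
// };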
interface AudioItem {
id: string | number;
path: string;
}
/**
* ASR WebSocket Client
*/
export class AsrWsClient extends VolcEngineBase {
private audioPath: string;
private successCode: number = 1000;
private segDuration: number;
private format: string;
private rate: number;
private bits: number;
private channel: number;
private codec: string;
private authMethod: string;
private hotWords: string[] | null;
private streaming: boolean;
private mp3SegSize: number;
private reqEvent: number = 1;
private uid: string;
private seq: number = 1;
private hasSendFullClientRequest: boolean = false;
constructor(audioPath: string, options: AsrClientOptions = {}) {
super({
url: options.wsUrl || 'wss://openspeech.bytedance.com/api/v3/sauc/bigmodel',
onConnect: () => this.onWsConnect(),
wsOptions: {
headers: {
'X-Api-Resource-Id': options.resourceId || 'volc.bigasr.sauc.duration',
'X-Api-Access-Key': options.token || '',
'X-Api-App-Key': options.appid || '',
'X-Api-Request-Id': uuid(),
},
},
});
this.audioPath = audioPath;
this.segDuration = options.segDuration || 100;
this.uid = options.uid || 'test';
this.format = options.format || 'wav';
this.rate = options.rate || 16000;
this.bits = options.bits || 16;
this.channel = options.channel || 1;
this.codec = options.codec || 'raw';
this.authMethod = options.authMethod || 'none';
this.hotWords = options.hotWords || null;
this.streaming = options.streaming !== undefined ? options.streaming : true;
this.mp3SegSize = options.mp3SegSize || 1000;
}
private onWsConnect() {
console.log('ASR WebSocket connected');
}
/**
* Construct the JSON parameters for the initial full client request.
* The request id parameter is currently unused; the payload only carries
* user, audio, and request settings.
*/
private constructRequest(_reqId: string): any {
return {
user: {
uid: this.uid,
},
audio: {
format: this.format,
sample_rate: this.rate,
bits: this.bits,
channel: this.channel,
codec: this.codec,
},
request: {
model_name: 'bigmodel',
enable_punc: true,
},
};
}
private async sendFullClientRequest() {
if (this.hasSendFullClientRequest) {
return;
}
this.seq = 1;
const seq = this.seq;
const reqId = nanoid();
const requestParams = this.constructRequest(reqId);
// Prepare and send initial request
const payloadStr = JSON.stringify(requestParams);
const compressedPayload = await gzipPromise(Buffer.from(payloadStr));
// Payload size (big-endian) sits between the sequence number and the compressed payload
const payloadSizeBuf = Buffer.alloc(4);
payloadSizeBuf.writeUInt32BE(compressedPayload.length, 0);
const fullClientRequest = Buffer.concat([
generateHeader(FULL_CLIENT_REQUEST, POS_SEQUENCE),
generateBeforePayload(seq),
payloadSizeBuf,
compressedPayload,
]);
// Send initial request
(this as any).ws.send(fullClientRequest);
this.hasSendFullClientRequest = true;
}
/**
* Send the full client request (once), then stream the audio to the server in
* segments and resolve with the response to the final (last-package) segment.
*/
private async segmentDataProcessor(audioData: Buffer, segmentSize: number): Promise<any> {
await this.sendFullClientRequest();
// Wait for response
const result = await new Promise<any>((resolve, reject) => {
const onMessage = async (event: MessageEvent) => {
try {
const response = parseResponse(Buffer.from(event.data as ArrayBuffer));
console.log('Initial response:', response);
// Process audio chunks
for (const [chunk, last] of sliceData(audioData, segmentSize)) {
this.seq += 1;
if (last) {
// The last chunk is marked by negating the sequence number
this.seq = -this.seq;
}
const seq = this.seq;
const start = Date.now();
const compressedChunk = await gzipPromise(chunk);
const messageType = AUDIO_ONLY_REQUEST;
const flags = last ? NEG_WITH_SEQUENCE : POS_SEQUENCE;
const chunkSizeBuf = Buffer.alloc(4);
chunkSizeBuf.writeUInt32BE(compressedChunk.length, 0);
const audioRequest = Buffer.concat([generateHeader(messageType, flags), generateBeforePayload(seq), chunkSizeBuf, compressedChunk]);
// Send audio chunk
(this as any).ws.send(audioRequest);
// Wait for the server's response to this chunk
const chunkResponse = await new Promise<any>((resolveChunk) => {
const onChunkMessage = (chunkEvent: MessageEvent) => {
const parsed = parseResponse(Buffer.from(chunkEvent.data as ArrayBuffer));
console.log(`Seq ${seq} response:`, parsed);
resolveChunk(parsed);
};
// { once: true } removes the listener after the first message, so no manual cleanup is needed
(this as any).ws.addEventListener('message', onChunkMessage, { once: true });
});
// If streaming, add delay to simulate real-time
if (this.streaming) {
const elapsed = Date.now() - start;
const sleepTime = Math.max(0, this.segDuration - elapsed);
await new Promise((r) => setTimeout(r, sleepTime));
}
// If this is the last chunk, resolve with final result
if (last) {
resolve(chunkResponse);
break;
}
}
} catch (error) {
console.error('Error processing response:', error);
reject(error);
}
};
(this as any).ws.addEventListener('message', onMessage, { once: true });
(this as any).ws.addEventListener(
'error',
(error) => {
console.error('WebSocket error:', error);
reject(error);
},
{ once: true },
);
});
return result;
}
/**
* Execute ASR on the audio file
*/
public async execute(): Promise<any> {
try {
const data = await fs.readFile(this.audioPath);
if (this.format === 'mp3') {
const segmentSize = this.mp3SegSize;
return await this.segmentDataProcessor(data, segmentSize);
}
if (this.format === 'wav') {
const wavInfo = await readWavInfo(data);
const sizePerSec = wavInfo.channels * wavInfo.sampleWidth * wavInfo.sampleRate;
// e.g. 16 kHz, 16-bit mono with segDuration = 100 ms => 3200 bytes per segment
const segmentSize = Math.floor((sizePerSec * this.segDuration) / 1000);
return await this.segmentDataProcessor(data, segmentSize);
}
if (this.format === 'pcm') {
// Assumes 16-bit (2-byte) samples; the 500 divisor means each chunk carries two segDuration windows of audio
const segmentSize = Math.floor((this.rate * 2 * this.channel * this.segDuration) / 500);
return await this.segmentDataProcessor(data, segmentSize);
}
throw new Error('Unsupported format');
} catch (error) {
console.error('Error executing ASR:', error);
throw error;
}
}
/**
* Send OPUS data for processing
*/
public async sendOpusData(audioData: Buffer): Promise<any> {
const segmentSize = Math.floor((this.rate * 2 * this.channel * this.segDuration) / 500);
return await this.segmentDataProcessor(audioData, segmentSize);
}
}
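// Usage sketch for AsrWsClient (the audio path and credentials are placeholders; as in
// executeOne below, the WebSocket opened by the base class may need a moment to connect
// before execute() is called):
// const client = new AsrWsClient('videos/asr_example.wav', {
//   appid: '<your-app-id>',
//   token: '<your-access-token>',
//   format: 'wav',
// });
// const transcript = await client.execute();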
/**
* Execute ASR on a single audio file
*/
export async function executeOne(audioItem: AudioItem, options: AsrClientOptions = {}): Promise<any> {
if (audioItem.id == null || !audioItem.path) {
throw new Error('Audio item must have id and path properties');
}
const audioId = audioItem.id;
const audioPath = path.resolve(process.cwd(), audioItem.path);
const asrClient = new AsrWsClient(audioPath, options);
// Give the WebSocket opened by the base class a moment to connect before streaming
await new Promise((resolve) => setTimeout(resolve, 2000));
const result = await asrClient.execute();
return {
id: audioId,
path: audioPath,
result,
};
}
/**
* Test stream processing
*/
export const testStream = async () => {
console.log('Testing streaming ASR');
const audioPath = 'videos/asr_example.wav';
return executeOne({
id: 1,
path: audioPath,
})
.then((result) => {
console.log('====end test=====');
console.log(result);
return result;
})
.catch((error) => {
console.error('Test error:', error);
return '';
});
};
/**
* Handle audio data directly
*/
export async function handleAudioData(audioData: Buffer, options: AsrClientOptions = {}): Promise<any> {
const asrClient = new AsrWsClient('', options);
return await asrClient.sendOpusData(audioData);
}
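// Usage sketch for handleAudioData (the credentials and the source of `pcmBuffer` are
// assumptions for illustration, not something this module provides):
// const pcmBuffer = await fs.readFile('capture.pcm');
// const result = await handleAudioData(pcmBuffer, {
//   appid: '<your-app-id>',
//   token: '<your-access-token>',
//   format: 'pcm',
//   rate: 16000,
//   channel: 1,
// });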