diff --git a/package.json b/package.json index f4bd2a2..40d80d9 100644 --- a/package.json +++ b/package.json @@ -44,6 +44,7 @@ "dayjs": "^1.11.13", "formidable": "^3.5.2", "lodash-es": "^4.17.21", + "nanoid": "^5.1.5", "node-record-lpcm16": "^1.0.1" }, "devDependencies": { @@ -63,6 +64,7 @@ "commander": "^13.1.0", "concurrently": "^9.1.2", "cross-env": "^7.0.3", + "dotenv": "^16.5.0", "inquire": "^0.4.8", "ioredis": "^5.6.0", "jsrepo": "^1.45.3", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index b2fbc9f..df1e834 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -10,16 +10,16 @@ importers: dependencies: '@kevisual/code-center-module': specifier: 0.0.18 - version: 0.0.18(@kevisual/auth@1.0.5)(@kevisual/router@0.0.10)(@kevisual/use-config@1.0.10(dotenv@16.4.7))(ioredis@5.6.0)(pg@8.14.1)(sequelize@6.37.7(pg@8.14.1)) + version: 0.0.18(@kevisual/auth@1.0.5)(@kevisual/router@0.0.10)(@kevisual/use-config@1.0.10(dotenv@16.5.0))(ioredis@5.6.0)(pg@8.14.1)(sequelize@6.37.7(pg@8.14.1)) '@kevisual/mark': specifier: 0.0.7 - version: 0.0.7(dotenv@16.4.7)(esbuild@0.25.0) + version: 0.0.7(dotenv@16.5.0)(esbuild@0.25.0) '@kevisual/router': specifier: 0.0.10 version: 0.0.10 '@kevisual/use-config': specifier: ^1.0.10 - version: 1.0.10(dotenv@16.4.7) + version: 1.0.10(dotenv@16.5.0) cookie: specifier: ^1.0.2 version: 1.0.2 @@ -32,6 +32,9 @@ importers: lodash-es: specifier: ^4.17.21 version: 4.17.21 + nanoid: + specifier: ^5.1.5 + version: 5.1.5 node-record-lpcm16: specifier: ^1.0.1 version: 1.0.1 @@ -81,6 +84,9 @@ importers: cross-env: specifier: ^7.0.3 version: 7.0.3 + dotenv: + specifier: ^16.5.0 + version: 16.5.0 inquire: specifier: ^0.4.8 version: 0.4.8 @@ -1347,8 +1353,8 @@ packages: resolution: {integrity: sha512-1gxPBJpI/pcjQhKgIU91II6Wkay+dLcN3M6rf2uwP8hRur3HtQXjVrdAK3sjC0piaEuxzMwjXChcETiJl47lAQ==} engines: {node: '>=18'} - dotenv@16.4.7: - resolution: {integrity: sha512-47qPchRCykZC03FhkYAhrvwU4xDBFIj1QPqaarj6mdM/hgUzfPHcpkHJOn3mJAufFeeAxAzeGsr5X0M4k6fLZQ==} + 
dotenv@16.5.0: + resolution: {integrity: sha512-m/C+AwOAr9/W1UOIZUo232ejMNnJAJtYQjUbHoNTBNTJSvqzzDh7vnrei3o3r3m9blf6ZoDkvcw0VmozNRFJxg==} engines: {node: '>=12'} dotignore@0.1.2: @@ -2092,11 +2098,6 @@ packages: engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} hasBin: true - nanoid@5.1.2: - resolution: {integrity: sha512-b+CiXQCNMUGe0Ri64S9SXFcP9hogjAJ2Rd6GdVxhPLRm7mhGaM7VgOvCAJ1ZshfHbqVDI3uqTI5C8/GaKuLI7g==} - engines: {node: ^18 || >=20} - hasBin: true - nanoid@5.1.5: resolution: {integrity: sha512-Ir/+ZpE9fDsNH0hQ3C68uyThDXzYcim2EqcZ8zn8Chtt1iylPT9xXJB0kPCnqzgcEGikO9RxSrh63MsmVCU7Fw==} engines: {node: ^18 || >=20} @@ -3324,11 +3325,11 @@ snapshots: '@kevisual/auth@1.0.5': {} - '@kevisual/code-center-module@0.0.18(@kevisual/auth@1.0.5)(@kevisual/router@0.0.10)(@kevisual/use-config@1.0.10(dotenv@16.4.7))(ioredis@5.6.0)(pg@8.14.1)(sequelize@6.37.7(pg@8.14.1))': + '@kevisual/code-center-module@0.0.18(@kevisual/auth@1.0.5)(@kevisual/router@0.0.10)(@kevisual/use-config@1.0.10(dotenv@16.5.0))(ioredis@5.6.0)(pg@8.14.1)(sequelize@6.37.7(pg@8.14.1))': dependencies: '@kevisual/auth': 1.0.5 '@kevisual/router': 0.0.10 - '@kevisual/use-config': 1.0.10(dotenv@16.4.7) + '@kevisual/use-config': 1.0.10(dotenv@16.5.0) ioredis: 5.6.0 nanoid: 5.1.5 pg: 8.14.1 @@ -3344,14 +3345,14 @@ snapshots: dependencies: eventemitter3: 5.0.1 - '@kevisual/mark@0.0.7(dotenv@16.4.7)(esbuild@0.25.0)': + '@kevisual/mark@0.0.7(dotenv@16.5.0)(esbuild@0.25.0)': dependencies: '@kevisual/auth': 1.0.5 '@kevisual/rollup-tools': 0.0.1(esbuild@0.25.0) '@kevisual/router': 0.0.7 - '@kevisual/use-config': 1.0.10(dotenv@16.4.7) + '@kevisual/use-config': 1.0.10(dotenv@16.5.0) cookie: 1.0.2 - nanoid: 5.1.2 + nanoid: 5.1.5 pg: 8.14.1 sequelize: 6.37.7(pg@8.14.1) transitivePeerDependencies: @@ -3413,10 +3414,10 @@ snapshots: '@kevisual/types@0.0.6': {} - '@kevisual/use-config@1.0.10(dotenv@16.4.7)': + '@kevisual/use-config@1.0.10(dotenv@16.5.0)': dependencies: '@kevisual/load': 0.0.4 - dotenv: 16.4.7 + dotenv: 
16.5.0 '@ljharb/resumer@0.1.3': dependencies: @@ -4361,7 +4362,7 @@ snapshots: dependencies: type-fest: 4.37.0 - dotenv@16.4.7: {} + dotenv@16.5.0: {} dotignore@0.1.2: dependencies: @@ -5227,8 +5228,6 @@ snapshots: nanoid@3.3.9: {} - nanoid@5.1.2: {} - nanoid@5.1.5: {} needle@2.4.0: diff --git a/src/asr/provider/funasr/ws.ts b/src/asr/provider/funasr/ws.ts index a835ffd..8ad63f2 100644 --- a/src/asr/provider/funasr/ws.ts +++ b/src/asr/provider/funasr/ws.ts @@ -1,4 +1,5 @@ -import WebSocket from 'ws'; +// import WebSocket from 'ws'; +import { initWs } from '../../ws/index.ts'; type VideoWSOptions = { url?: string; @@ -27,15 +28,17 @@ export class VideoWS { isFile?: boolean; onConnect?: () => void; constructor(options?: VideoWSOptions) { - this.ws = - options?.ws || - new WebSocket(options.url, { - rejectUnauthorized: false, - }); this.itn = options?.itn || false; this.mode = options?.mode || 'online'; this.isFile = options?.isFile || false; - + this.initWs(options); + } + async initWs(options: VideoWSOptions) { + if (options?.ws) { + this.ws = options.ws; + } else { + this.ws = await initWs(options.url); + } this.onConnect = options?.onConnect || (() => {}); this.ws.onopen = this.onOpen.bind(this); this.ws.onmessage = this.onMessage.bind(this); diff --git a/src/asr/provider/volcengine/asr-ws-big-model-client.ts b/src/asr/provider/volcengine/asr-ws-big-model-client.ts new file mode 100644 index 0000000..99587db --- /dev/null +++ b/src/asr/provider/volcengine/asr-ws-big-model-client.ts @@ -0,0 +1,529 @@ +import * as fs from 'fs/promises'; +import * as path from 'path'; +import * as zlib from 'zlib'; +import * as util from 'util'; +import { Readable } from 'stream'; +import { promisify } from 'util'; +import { nanoid } from 'nanoid'; +import { VolcEngineBase, uuid } from './base.ts'; + +// Promisify zlib methods +const gzipPromise = promisify(zlib.gzip); +const gunzipPromise = promisify(zlib.gunzip); + +// Protocol constants +const PROTOCOL_VERSION = 0b0001; +const 
DEFAULT_HEADER_SIZE = 0b0001; + +// Message Type +const FULL_CLIENT_REQUEST = 0b0001; +const AUDIO_ONLY_REQUEST = 0b0010; +const FULL_SERVER_RESPONSE = 0b1001; +const SERVER_ACK = 0b1011; +const SERVER_ERROR_RESPONSE = 0b1111; + +// Message Type Specific Flags +const NO_SEQUENCE = 0b0000; // no check sequence +const POS_SEQUENCE = 0b0001; +const NEG_SEQUENCE = 0b0010; +const NEG_WITH_SEQUENCE = 0b0011; +const NEG_SEQUENCE_1 = 0b0011; + +// Message Serialization +const NO_SERIALIZATION = 0b0000; +const JSON_SERIALIZATION = 0b0001; + +// Message Compression +const NO_COMPRESSION = 0b0000; +const GZIP_COMPRESSION = 0b0001; + +/** + * Generate header for the WebSocket request + */ +function generateHeader( + messageType = FULL_CLIENT_REQUEST, + messageTypeSpecificFlags = NO_SEQUENCE, + serialMethod = JSON_SERIALIZATION, + compressionType = GZIP_COMPRESSION, + reservedData = 0x00, +): Buffer { + const header = Buffer.alloc(4); + const headerSize = 1; + header[0] = (PROTOCOL_VERSION << 4) | headerSize; + header[1] = (messageType << 4) | messageTypeSpecificFlags; + header[2] = (serialMethod << 4) | compressionType; + header[3] = reservedData; + return header; +} + +/** + * Generate the sequence part of the request + */ +function generateBeforePayload(sequence: number): Buffer { + const beforePayload = Buffer.alloc(4); + beforePayload.writeInt32BE(sequence); + return beforePayload; +} + +/** + * Parse response from the WebSocket server + */ +function parseResponse(res: Buffer): any { + const protocolVersion = res[0] >> 4; + const headerSize = res[0] & 0x0f; + const messageType = res[1] >> 4; + const messageTypeSpecificFlags = res[1] & 0x0f; + const serializationMethod = res[2] >> 4; + const messageCompression = res[2] & 0x0f; + const reserved = res[3]; + const headerExtensions = res.slice(4, headerSize * 4); + const payload = res.slice(headerSize * 4); + + const result: any = { + isLastPackage: false, + }; + + let payloadMsg = null; + let payloadSize = 0; + let offset = 0; 
+ + if (messageTypeSpecificFlags & 0x01) { + // receive frame with sequence + const seq = payload.readInt32BE(0); + result.payloadSequence = seq; + offset += 4; + } + + if (messageTypeSpecificFlags & 0x02) { + // receive last package + result.isLastPackage = true; + } + + const remainingPayload = payload.slice(offset); + + if (messageType === FULL_SERVER_RESPONSE) { + payloadSize = remainingPayload.readInt32BE(0); + payloadMsg = remainingPayload.slice(4); + } else if (messageType === SERVER_ACK) { + const seq = remainingPayload.readInt32BE(0); + result.seq = seq; + if (remainingPayload.length >= 8) { + payloadSize = remainingPayload.readUInt32BE(4); + payloadMsg = remainingPayload.slice(8); + } + } else if (messageType === SERVER_ERROR_RESPONSE) { + const code = remainingPayload.readUInt32BE(0); + result.code = code; + payloadSize = remainingPayload.readUInt32BE(4); + payloadMsg = remainingPayload.slice(8); + } + + if (!payloadMsg) { + return result; + } + + if (messageCompression === GZIP_COMPRESSION) { + try { + const decompressed = zlib.gunzipSync(payloadMsg); + payloadMsg = decompressed; + } catch (error) { + console.error('Error decompressing payload:', error); + } + } + + if (serializationMethod === JSON_SERIALIZATION) { + try { + payloadMsg = JSON.parse(payloadMsg.toString('utf-8')); + } catch (error) { + console.error('Error parsing JSON payload:', error); + } + } else if (serializationMethod !== NO_SERIALIZATION) { + payloadMsg = payloadMsg.toString('utf-8'); + } + + result.payloadMsg = payloadMsg; + result.payloadSize = payloadSize; + return result; +} + +/** + * Read WAV file information + */ +async function readWavInfo(data: Buffer): Promise<{ + channels: number; + sampleWidth: number; + sampleRate: number; + frames: number; + audioData: Buffer; +}> { + // This is a simplified WAV parser - in production you should use a proper library + if (data.length < 44) { + throw new Error('Invalid WAV file: too short'); + } + + // Check WAV header + if 
(data.slice(0, 4).toString() !== 'RIFF' || data.slice(8, 12).toString() !== 'WAVE') { + throw new Error('Invalid WAV file: not a WAV format'); + } + + // Parse header information + const channels = data.readUInt16LE(22); + const sampleRate = data.readUInt32LE(24); + const bitsPerSample = data.readUInt16LE(34); + const sampleWidth = bitsPerSample / 8; + + // Find data chunk + let offset = 12; // Start after "WAVE" + let dataSize = 0; + let audioData: Buffer = Buffer.alloc(0); + + while (offset < data.length) { + const chunkType = data.slice(offset, offset + 4).toString(); + const chunkSize = data.readUInt32LE(offset + 4); + + if (chunkType === 'data') { + dataSize = chunkSize; + audioData = data.slice(offset + 8, offset + 8 + chunkSize); + break; + } + + offset += 8 + chunkSize; + } + + const frames = dataSize / (channels * sampleWidth); + + return { + channels, + sampleWidth, + sampleRate, + frames, + audioData, + }; +} + +/** + * Check if data is a valid WAV file + */ +function judgeWav(data: Buffer): boolean { + if (data.length < 44) { + return false; + } + return data.slice(0, 4).toString() === 'RIFF' && data.slice(8, 12).toString() === 'WAVE'; +} + +/** + * Slice data into chunks + */ +function* sliceData(data: Buffer, chunkSize: number): Generator<[Buffer, boolean]> { + const dataLen = data.length; + let offset = 0; + + while (offset + chunkSize < dataLen) { + yield [data.slice(offset, offset + chunkSize), false]; + offset += chunkSize; + } + + yield [data.slice(offset, dataLen), true]; +} + +interface AsrClientOptions { + segDuration?: number; + wsUrl?: string; + uid?: string; + format?: string; + rate?: number; + bits?: number; + channel?: number; + codec?: string; + authMethod?: string; + hotWords?: string[]; + streaming?: boolean; + mp3SegSize?: number; + + resourceId?: string; + token?: string; + appid?: string; +} + +interface AudioItem { + id: string | number; + path: string; +} + +/** + * ASR WebSocket Client + */ +export class AsrWsClient extends 
VolcEngineBase { + private audioPath: string; + private successCode: number = 1000; + private segDuration: number; + private format: string; + private rate: number; + private bits: number; + private channel: number; + private codec: string; + private authMethod: string; + private hotWords: string[] | null; + private streaming: boolean; + private mp3SegSize: number; + private reqEvent: number = 1; + private uid: string; + private seq: number = 1; + private hasSendFullClientRequest: boolean = false; + + constructor(audioPath: string, options: AsrClientOptions = {}) { + super({ + url: options.wsUrl || 'wss://openspeech.bytedance.com/api/v3/sauc/bigmodel', + onConnect: () => this.onWsConnect(), + wsOptions: { + headers: { + 'X-Api-Resource-Id': options.resourceId || 'volc.bigasr.sauc.duration', + 'X-Api-Access-Key': options.token || '', + 'X-Api-App-Key': options.appid || '', + 'X-Api-Request-Id': uuid(), + }, + }, + }); + + this.audioPath = audioPath; + this.segDuration = options.segDuration || 100; + this.uid = options.uid || 'test'; + this.format = options.format || 'wav'; + this.rate = options.rate || 16000; + this.bits = options.bits || 16; + this.channel = options.channel || 1; + this.codec = options.codec || 'raw'; + this.authMethod = options.authMethod || 'none'; + this.hotWords = options.hotWords || null; + this.streaming = options.streaming !== undefined ? 
options.streaming : true; + this.mp3SegSize = options.mp3SegSize || 1000; + } + + private onWsConnect() { + console.log('ASR WebSocket connected'); + } + + /** + * Construct request parameters + */ + private constructRequest(reqId: string, data?: any): any { + return { + user: { + uid: this.uid, + }, + audio: { + format: this.format, + sample_rate: this.rate, + bits: this.bits, + channel: this.channel, + codec: this.codec, + }, + request: { + model_name: 'bigmodel', + enable_punc: true, + }, + }; + } + private async sendFullClientRequest() { + if (this.hasSendFullClientRequest) { + return; + } + this.seq = 1; + const seq = this.seq; + const reqId = nanoid(); + const requestParams = this.constructRequest(reqId); + + // Prepare and send initial request + const payloadStr = JSON.stringify(requestParams); + const compressedPayload = await gzipPromise(Buffer.from(payloadStr)); + + const fullClientRequest = Buffer.concat([ + generateHeader(FULL_CLIENT_REQUEST, POS_SEQUENCE), + generateBeforePayload(seq), + Buffer.alloc(4), + compressedPayload, + ]); + + // Set payload size + fullClientRequest.writeUInt32BE(compressedPayload.length, 8); + + // Send initial request + (this as any).ws.send(fullClientRequest); + this.hasSendFullClientRequest = true; + } + /** + * Process audio data in segments + */ + private async segmentDataProcessor(audioData: Buffer, segmentSize: number): Promise { + await this.sendFullClientRequest(); + const that = this; + // Wait for response + const result = await new Promise((resolve, reject) => { + const onMessage = async (event: MessageEvent) => { + try { + const response = parseResponse(Buffer.from(event.data as ArrayBuffer)); + console.log('Initial response:', response); + + // Process audio chunks + for (const [chunk, last] of sliceData(audioData, segmentSize)) { + that.seq += 1; + if (last) { + that.seq = -that.seq; + } + const seq = that.seq; + + const start = Date.now(); + const compressedChunk = await gzipPromise(chunk); + + const 
messageType = AUDIO_ONLY_REQUEST; + const flags = last ? NEG_WITH_SEQUENCE : POS_SEQUENCE; + + const audioRequest = Buffer.concat([generateHeader(messageType, flags), generateBeforePayload(seq), Buffer.alloc(4), compressedChunk]); + + // Set payload size + audioRequest.writeUInt32BE(compressedChunk.length, 8); + + // Send audio chunk + (this as any).ws.send(audioRequest); + + // Wait for each response + const chunkResponse = await new Promise((resolveChunk) => { + const onChunkMessage = (chunkEvent: MessageEvent) => { + (this as any).ws.removeEventListener('message', onChunkMessage); + const parsed = parseResponse(Buffer.from(chunkEvent.data as ArrayBuffer)); + console.log(`Seq ${seq} response:`, parsed); + resolveChunk(parsed); + }; + + (this as any).ws.addEventListener('message', onChunkMessage, { once: true }); + }); + + // If streaming, add delay to simulate real-time + if (this.streaming) { + const elapsed = Date.now() - start; + const sleepTime = Math.max(0, this.segDuration - elapsed); + await new Promise((r) => setTimeout(r, sleepTime)); + } + + // If this is the last chunk, resolve with final result + if (last) { + resolve(chunkResponse); + break; + } + } + + (this as any).ws.removeEventListener('message', onMessage); + } catch (error) { + console.error('Error processing response:', error); + reject(error); + } + }; + + (this as any).ws.addEventListener('message', onMessage, { once: true }); + + (this as any).ws.addEventListener( + 'error', + (error) => { + console.error('WebSocket error:', error); + reject(error); + }, + { once: true }, + ); + }); + + return result; + } + + /** + * Execute ASR on the audio file + */ + public async execute(): Promise { + try { + const data = await fs.readFile(this.audioPath); + + if (this.format === 'mp3') { + const segmentSize = this.mp3SegSize; + return await this.segmentDataProcessor(data, segmentSize); + } + + if (this.format === 'wav') { + const wavInfo = await readWavInfo(data); + const sizePerSec = wavInfo.channels 
* wavInfo.sampleWidth * wavInfo.sampleRate; + const segmentSize = Math.floor((sizePerSec * this.segDuration) / 1000); + // 3200 + return await this.segmentDataProcessor(data, segmentSize); + } + + if (this.format === 'pcm') { + const segmentSize = Math.floor((this.rate * 2 * this.channel * this.segDuration) / 500); + return await this.segmentDataProcessor(data, segmentSize); + } + + throw new Error('Unsupported format'); + } catch (error) { + console.error('Error executing ASR:', error); + throw error; + } + } + + /** + * Send OPUS data for processing + */ + public async sendOpusData(audioData: Buffer): Promise { + const segmentSize = Math.floor((this.rate * 2 * this.channel * this.segDuration) / 500); + return await this.segmentDataProcessor(audioData, segmentSize); + } +} + +/** + * Execute ASR on a single audio file + */ +export async function executeOne(audioItem: AudioItem, options: AsrClientOptions = {}): Promise { + if (!audioItem.id || !audioItem.path) { + throw new Error('Audio item must have id and path properties'); + } + + const audioId = audioItem.id; + const audioPath = path.resolve(process.cwd(), audioItem.path); + + const asrClient = new AsrWsClient(audioPath, options); + await new Promise((resolve) => setTimeout(resolve, 2000)); + + return asrClient.execute().then((result) => { + return { + id: audioId, + path: audioPath, + result: result, + }; + }); +} + +/** + * Test stream processing + */ +export const testStream = async () => { + console.log('测试流式'); + const audioPath = 'videos/asr_example.wav'; + + const res = await executeOne({ + id: 1, + path: audioPath, + }) + .then((result) => { + console.log('====end test====='); + console.log(result); + return result; + }) + .catch((error) => { + console.error('Test error:', error); + return ''; + }); +}; + +/** + * Handle audio data directly + */ +export async function handleAudioData(audioData: Buffer, options: AsrClientOptions = {}): Promise { + const asrClient = new AsrWsClient('', options); + return 
await asrClient.sendOpusData(audioData); +} diff --git a/src/asr/provider/volcengine/asr-ws-client.ts b/src/asr/provider/volcengine/asr-ws-client.ts new file mode 100644 index 0000000..6dedd31 --- /dev/null +++ b/src/asr/provider/volcengine/asr-ws-client.ts @@ -0,0 +1,492 @@ +import * as fs from 'fs/promises'; +import * as zlib from 'zlib'; +import { promisify } from 'util'; +import { VolcEngineBase, uuid } from './base.ts'; + +// Promisify zlib methods +const gzipPromise = promisify(zlib.gzip); +const gunzipPromise = promisify(zlib.gunzip); + +// Protocol constants +const PROTOCOL_VERSION = 0b0001; +const DEFAULT_HEADER_SIZE = 0b0001; + +// Message Type +const CLIENT_FULL_REQUEST = 0b0001; +const CLIENT_AUDIO_ONLY_REQUEST = 0b0010; +const SERVER_FULL_RESPONSE = 0b1001; +const SERVER_ACK = 0b1011; +const SERVER_ERROR_RESPONSE = 0b1111; + +// Message Type Specific Flags +const NO_SEQUENCE = 0b0000; // no check sequence +const POS_SEQUENCE = 0b0001; +const NEG_SEQUENCE = 0b0010; +const NEG_SEQUENCE_1 = 0b0011; + +// Message Serialization +const NO_SERIALIZATION = 0b0000; +const JSON_SERIALIZATION = 0b0001; +const THRIFT = 0b0011; +const CUSTOM_TYPE = 0b1111; + +// Message Compression +const NO_COMPRESSION = 0b0000; +const GZIP = 0b0001; +const CUSTOM_COMPRESSION = 0b1111; + +/** + * Generate header for WebSocket requests + */ +function generateHeader( + version = PROTOCOL_VERSION, + messageType = CLIENT_FULL_REQUEST, + messageTypeSpecificFlags = NO_SEQUENCE, + serialMethod = JSON_SERIALIZATION, + compressionType = GZIP, + reservedData = 0x00, +): Buffer { + const header = Buffer.alloc(4); + const headerSize = 1; + header[0] = (version << 4) | headerSize; + header[1] = (messageType << 4) | messageTypeSpecificFlags; + header[2] = (serialMethod << 4) | compressionType; + header[3] = reservedData; + return header; +} + +/** + * Generate full default header for client request + */ +function generateFullDefaultHeader(): Buffer { + return generateHeader(); +} + +/** + * 
Generate audio default header for client request + */ +function generateAudioDefaultHeader(): Buffer { + return generateHeader(PROTOCOL_VERSION, CLIENT_AUDIO_ONLY_REQUEST); +} + +/** + * Generate last audio default header for client request + */ +function generateLastAudioDefaultHeader(): Buffer { + return generateHeader(PROTOCOL_VERSION, CLIENT_AUDIO_ONLY_REQUEST, NEG_SEQUENCE); +} + +/** + * Parse response from the WebSocket server + */ +function parseResponse(res: Buffer): any { + const protocolVersion = res[0] >> 4; + const headerSize = res[0] & 0x0f; + const messageType = res[1] >> 4; + const messageTypeSpecificFlags = res[1] & 0x0f; + const serializationMethod = res[2] >> 4; + const messageCompression = res[2] & 0x0f; + const reserved = res[3]; + const headerExtensions = res.slice(4, headerSize * 4); + const payload = res.slice(headerSize * 4); + + const result: any = {}; + let payloadMsg = null; + let payloadSize = 0; + + if (messageType === SERVER_FULL_RESPONSE) { + payloadSize = payload.readInt32BE(0); + payloadMsg = payload.slice(4); + } else if (messageType === SERVER_ACK) { + const seq = payload.readInt32BE(0); + result.seq = seq; + if (payload.length >= 8) { + payloadSize = payload.readUInt32BE(4); + payloadMsg = payload.slice(8); + } + } else if (messageType === SERVER_ERROR_RESPONSE) { + const code = payload.readUInt32BE(0); + result.code = code; + payloadSize = payload.readUInt32BE(4); + payloadMsg = payload.slice(8); + } + + if (!payloadMsg) { + return result; + } + + if (messageCompression === GZIP) { + try { + payloadMsg = zlib.gunzipSync(payloadMsg); + } catch (error) { + console.error('Error decompressing payload:', error); + } + } + + if (serializationMethod === JSON_SERIALIZATION) { + try { + payloadMsg = JSON.parse(payloadMsg.toString('utf-8')); + } catch (error) { + console.error('Error parsing JSON payload:', error); + } + } else if (serializationMethod !== NO_SERIALIZATION) { + payloadMsg = payloadMsg.toString('utf-8'); + } + + 
result.payloadMsg = payloadMsg; + result.payloadSize = payloadSize; + return result; +} + +/** + * Read WAV file information + */ +async function readWavInfo(data: Buffer): Promise<{ + channels: number; + sampleWidth: number; + sampleRate: number; + frames: number; + wavBytes: Buffer; +}> { + // Simple WAV parser - in production you should use a proper library + if (data.length < 44) { + throw new Error('Invalid WAV file: too short'); + } + + // Check WAV header + if (data.slice(0, 4).toString() !== 'RIFF' || data.slice(8, 12).toString() !== 'WAVE') { + throw new Error('Invalid WAV file: not a WAV format'); + } + + // Parse header information + const channels = data.readUInt16LE(22); + const sampleRate = data.readUInt32LE(24); + const bitsPerSample = data.readUInt16LE(34); + const sampleWidth = bitsPerSample / 8; + + // Find data chunk + let offset = 12; // Start after "WAVE" + let dataSize = 0; + let wavBytes: Buffer = Buffer.alloc(0); + + while (offset < data.length) { + const chunkType = data.slice(offset, offset + 4).toString(); + const chunkSize = data.readUInt32LE(offset + 4); + + if (chunkType === 'data') { + dataSize = chunkSize; + wavBytes = data.slice(offset + 8, offset + 8 + chunkSize); + break; + } + + offset += 8 + chunkSize; + } + + const frames = dataSize / (channels * sampleWidth); + + return { + channels, + sampleWidth, + sampleRate, + frames, + wavBytes, + }; +} + +/** + * Generator to slice data into chunks + */ +function* sliceData(data: Buffer, chunkSize: number): Generator<[Buffer, boolean]> { + const dataLen = data.length; + let offset = 0; + + while (offset + chunkSize < dataLen) { + yield [data.slice(offset, offset + chunkSize), false]; + offset += chunkSize; + } + + yield [data.slice(offset, dataLen), true]; +} + +enum AudioType { + LOCAL = 1, // 使用本地音频文件 +} + +interface AsrClientOptions { + segDuration?: number; + nbest?: number; + appid?: string; + token?: string; + wsUrl?: string; + uid?: string; + workflow?: string; + showLanguage?: 
boolean; + showUtterances?: boolean; + resultType?: string; + format?: string; + sampleRate?: number; + language?: string; + bits?: number; + channel?: number; + codec?: string; + audioType?: AudioType; + mp3SegSize?: number; + cluster?: string; +} + +interface AudioItem { + id: string | number; + path: string; +} + +export class AsrWsClient extends VolcEngineBase { + private audioPath: string; + private cluster: string; + private successCode: number = 1000; + private segDuration: number; + private nbest: number; + private appid: string; + private token: string; + private uid: string; + private workflow: string; + private showLanguage: boolean; + private showUtterances: boolean; + private resultType: string; + private format: string; + private rate: number; + private language: string; + private bits: number; + private channel: number; + private codec: string; + private audioType: AudioType; + private mp3SegSize: number; + + constructor(audioPath: string, cluster: string, options: AsrClientOptions = {}) { + super({ + url: options.wsUrl || 'wss://openspeech.bytedance.com/api/v2/asr', + onConnect: () => this.onWsConnect(), + enabled: false, + wsOptions: { + headers: { + Authorization: `Bearer; ${options.token}`, + }, + }, + }); + this.audioPath = audioPath; + this.cluster = cluster; + this.segDuration = options.segDuration || 15000; + this.nbest = options.nbest || 1; + this.appid = options.appid || ''; + this.token = options.token || ''; + this.uid = options.uid || 'test'; + this.workflow = options.workflow || 'audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate'; + this.showLanguage = options.showLanguage || false; + this.showUtterances = options.showUtterances || false; + this.resultType = options.resultType || 'full'; + this.format = options.format || 'wav'; + this.rate = options.sampleRate || 16000; + this.language = options.language || 'zh-CN'; + this.bits = options.bits || 16; + this.channel = options.channel || 1; + this.codec = options.codec || 'raw'; 
+ this.audioType = options.audioType || AudioType.LOCAL; + this.mp3SegSize = options.mp3SegSize || 10000; + } + + private onWsConnect() { + console.log('ASR WebSocket connected'); + } + + /** + * Construct request parameters + */ + private constructRequest(reqId: string): any { + return { + app: { + appid: this.appid, + cluster: this.cluster, + token: this.token, + }, + user: { + uid: this.uid, + }, + request: { + reqid: reqId, + nbest: this.nbest, + workflow: this.workflow, + show_language: this.showLanguage, + show_utterances: this.showUtterances, + result_type: this.resultType, + sequence: 1, + }, + audio: { + format: this.format, + rate: this.rate, + language: this.language, + bits: this.bits, + channel: this.channel, + codec: this.codec, + }, + }; + } + /** + * Generate headers for authentication + */ + private tokenAuth(): Record { + return { Authorization: `Bearer; ${this.token}` }; + } + + /** + * Process audio data in segments + */ + async segmentDataProcessor(wavData: Buffer, segmentSize: number): Promise { + const reqId = uuid(); + + // Construct full client request and compress + const requestParams = this.constructRequest(reqId); + const payloadBytes = Buffer.from(JSON.stringify(requestParams)); + const compressedPayload = await gzipPromise(payloadBytes); + + // Create full client request + const fullClientRequest = Buffer.concat([ + generateFullDefaultHeader(), + Buffer.alloc(4), // payload size placeholder + compressedPayload, + ]); + + // Set payload size + fullClientRequest.writeInt32BE(compressedPayload.length, 4); + + return new Promise(async (resolve, reject) => { + try { + this.ws.send(fullClientRequest); + const onMessage = async (event: MessageEvent) => { + const res = parseResponse(Buffer.from(event.data as ArrayBuffer)); + + if ('payloadMsg' in res && res.payloadMsg?.code !== this.successCode) { + resolve(res); + return; + } + + let seq = 1; + let lastMessage = null; + + for (const [chunk, last] of sliceData(wavData, segmentSize)) { + // 
Compress chunk if needed + const compressedChunk = await gzipPromise(chunk); + + // Create audio-only request + const audioOnlyHeader = last ? generateLastAudioDefaultHeader() : generateAudioDefaultHeader(); + + const audioOnlyRequest = Buffer.concat([ + audioOnlyHeader, + Buffer.alloc(4), // payload size placeholder + compressedChunk, + ]); + + // Set payload size + audioOnlyRequest.writeInt32BE(compressedChunk.length, 4); + + // Send audio data + this.ws.send(audioOnlyRequest); + + // Wait for response + const response = await new Promise((resolveChunk) => { + const messageHandler = (messageEvent: MessageEvent) => { + const result = parseResponse(Buffer.from(messageEvent.data as ArrayBuffer)); + this.ws.removeEventListener('message', messageHandler); + resolveChunk(result); + }; + + this.ws.addEventListener('message', messageHandler); + }); + + if ('payloadMsg' in response && response.payloadMsg?.code !== this.successCode) { + resolve(response); + return; + } + + lastMessage = response; + + if (last) { + break; + } + + seq++; + } + + resolve(lastMessage); + }; + this.ws.addEventListener('message', onMessage, { once: true }); + } catch (error) { + reject(error); + } + }); + } + + /** + * Execute ASR on the audio file + */ + async execute(): Promise { + try { + const data = await fs.readFile(this.audioPath); + + if (this.format === 'mp3') { + return await this.segmentDataProcessor(data, this.mp3SegSize); + } + + if (this.format !== 'wav') { + throw new Error('Format should be wav or mp3'); + } + const wavInfo = await readWavInfo(data); + const sizePerSec = wavInfo.channels * wavInfo.sampleWidth * wavInfo.sampleRate; + const segmentSize = Math.floor((sizePerSec * this.segDuration) / 1000); + + return await this.segmentDataProcessor(data, segmentSize); + } catch (error) { + console.error('Error executing ASR:', error); + throw error; + } + } +} + +/** + * Execute ASR on a single audio file + */ +export async function executeOne(audioItem: AudioItem, cluster: string, 
options: AsrClientOptions = {}): Promise { + if (!('id' in audioItem) || !('path' in audioItem)) { + throw new Error('Audio item must have id and path properties'); + } + + const audioId = audioItem.id; + const audioPath = audioItem.path; + const audioType = AudioType.LOCAL; + + const asrClient = new AsrWsClient(audioPath, cluster, { + ...options, + audioType, + }); + + const result = await asrClient.execute(); + return { id: audioId, path: audioPath, result }; +} + +/** + * Test function + */ +export async function testOne(audioPath: string, cluster: string, appid: string, token: string, audioFormat: string): Promise { + const result = await executeOne( + { + id: 1, + path: audioPath, + }, + cluster, + { + appid, + token, + format: audioFormat, + }, + ); + + console.log(result); +} diff --git a/src/asr/provider/volcengine/base.ts b/src/asr/provider/volcengine/base.ts new file mode 100644 index 0000000..260bb7e --- /dev/null +++ b/src/asr/provider/volcengine/base.ts @@ -0,0 +1,36 @@ +import { initWs } from '../../ws/index.ts'; +import { WSServer } from '../ws-server.ts'; +import { nanoid } from 'nanoid'; + +export const uuid = () => nanoid(16); + +type VolcEngineBaseOptions = { + url?: string; + ws?: WebSocket; + enabled?: boolean; + onConnect?: () => void; + wsOptions?: { + headers?: { + 'X-Api-Resource-Id'?: string; + 'X-Api-Access-Key'?: string; + 'X-Api-App-Key'?: string; + 'X-Api-Request-Id'?: string; + Authorization?: string; + }; + }; +}; +export class VolcEngineBase extends WSServer { + constructor(opts: VolcEngineBaseOptions) { + super({ + url: opts.url, + ws: opts.ws, + onConnect: opts.onConnect, + wsOptions: opts.wsOptions, + enabled: opts.enabled, + }); + } + async onOpen() { + console.log('VolcEngineBase onOpen'); + // 发送认证信息 + } +} diff --git a/src/asr/provider/volcengine/test/asr-bigmodel.ts b/src/asr/provider/volcengine/test/asr-bigmodel.ts new file mode 100644 index 0000000..8fb7578 --- /dev/null +++ 
b/src/asr/provider/volcengine/test/asr-bigmodel.ts
@@ -0,0 +1,28 @@
+import { AsrWsClient, testStream } from '../asr-ws-big-model-client.ts';
+import { audioPath, config } from './common.ts';
+
+// const asr = new AsrWsClient('videos/asr_example.wav');
+
+// tsx src/asr/provider/volcengine/test/asr-bigmodel.ts
+const main = async () => {
+  const audioId = '123';
+  const asrClient = new AsrWsClient(audioPath, {
+    appid: config.APP_ID,
+    token: config.TOKEN,
+    streaming: false,
+  });
+  await new Promise((resolve) => setTimeout(resolve, 2000));
+
+  return asrClient.execute().then((result) => {
+    return {
+      id: audioId,
+      path: audioPath,
+      result: result,
+    };
+  });
+};
+const main2 = async () => {
+  testStream();
+};
+
+main();
diff --git a/src/asr/provider/volcengine/test/asr.ts b/src/asr/provider/volcengine/test/asr.ts
new file mode 100644
index 0000000..f73a5a6
--- /dev/null
+++ b/src/asr/provider/volcengine/test/asr.ts
@@ -0,0 +1,16 @@
+import { AsrWsClient } from '../asr-ws-client.ts';
+
+import { audioPath, config, sleep } from './common.ts';
+
+const asr = new AsrWsClient(audioPath, 'volcengine_input_common', {
+  appid: config.APP_ID,
+  token: config.TOKEN,
+});
+// tsx src/asr/provider/volcengine/test/asr.ts
+const main = async () => {
+  await sleep(1000);
+  const result = await asr.execute();
+  console.log('result', JSON.stringify(result, null, 2));
+};
+
+main();
diff --git a/src/asr/provider/volcengine/test/common.ts b/src/asr/provider/volcengine/test/common.ts
new file mode 100644
index 0000000..02afb97
--- /dev/null
+++ b/src/asr/provider/volcengine/test/common.ts
@@ -0,0 +1,18 @@
+import path from 'path';
+import dotenv from 'dotenv';
+
+/**
+ * Key/value pairs parsed from `.env` in the current working directory.
+ * `parsed` is undefined when the file cannot be read, so fall back to an empty
+ * object: a missing key then reads as `undefined` instead of throwing on access.
+ */
+export const config =
+  dotenv.config({
+    path: path.join(process.cwd(), '.env'),
+  }).parsed ?? {};
+
+/** Path of the sample recording used by the manual ASR test scripts, resolved from the working directory. */
+export const audioPath = path.join(process.cwd(), 'videos/asr_example.wav');
+
+/** Resolve after `ms` milliseconds. */
+export const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
diff --git a/src/asr/provider/ws-server.ts
b/src/asr/provider/ws-server.ts
new file mode 100644
index 0000000..5f13c8d
--- /dev/null
+++ b/src/asr/provider/ws-server.ts
@@ -0,0 +1,65 @@
+import { initWs } from '../ws/index.ts';
+import type { ClientOptions } from 'ws';
+type WSSOptions = {
+  url: string;
+  ws?: WebSocket;
+  onConnect?: () => void;
+  wsOptions?: ClientOptions;
+  enabled?: boolean;
+};
+/**
+ * Owns a WebSocket and routes its open/message/error/close events to the
+ * overridable `onOpen` / `onMessage` / `onError` / `onClose` methods.
+ */
+export class WSServer {
+  ws: WebSocket;
+  onConnect?: () => void;
+  /**
+   * Kicks off the asynchronous setup. A constructor cannot await, so `this.ws`
+   * may still be unset right after `new`. A failed setup is logged here rather
+   * than left as an unhandled promise rejection.
+   */
+  constructor(opts: WSSOptions) {
+    this.initWs(opts).catch((error) => {
+      console.error('WSS init error', error);
+    });
+  }
+  /**
+   * Adopt `opts.ws` or, when enabled, connect to `opts.url`; then bind the event handlers.
+   * `opts.enabled` defaults to true; `false` skips connecting when no `ws` is supplied.
+   */
+  async initWs(opts: WSSOptions) {
+    // `??`, not `||`: with `||` an explicit `enabled: false` was coerced back to true.
+    const enabled = opts.enabled ?? true;
+    if (opts.ws) {
+      this.ws = opts.ws;
+    } else if (enabled) {
+      this.ws = await initWs(opts.url, opts.wsOptions);
+    }
+    this.onConnect = opts?.onConnect || (() => {});
+    if (!this.ws) {
+      // Disabled and no socket supplied: there is nothing to attach handlers to.
+      return;
+    }
+    this.ws.onopen = this.onOpen.bind(this);
+    this.ws.onmessage = this.onMessage.bind(this);
+    this.ws.onerror = this.onError.bind(this);
+    this.ws.onclose = this.onClose.bind(this);
+  }
+  /** Bound to the socket's `open` event; invokes the `onConnect` callback. */
+  async onOpen() {
+    this.onConnect?.();
+  }
+  /** Bound to the socket's `message` event; does nothing by default. */
+  async onMessage(event: MessageEvent) {
+    // console.log('WSS onMessage', event);
+  }
+  /** Bound to the socket's `error` event; logs a fixed marker. */
+  async onError(event: Event) {
+    console.error('WSS onError');
+  }
+  /** Bound to the socket's `close` event; logs a fixed marker. */
+  async onClose(event: CloseEvent) {
+    console.error('WSS onClose');
+  }
+}
diff --git a/src/asr/ws/browser.ts b/src/asr/ws/browser.ts
new file mode 100644
index 0000000..45c77e5
--- /dev/null
+++ b/src/asr/ws/browser.ts
@@ -0,0 +1,18 @@
+// @ts-nocheck
+// https://github.com/maxogden/websocket-stream/blob/48dc3ddf943e5ada668c31ccd94e9186f02fafbd/ws-fallback.js
+
+let ws: typeof WebSocket;
+
+if (typeof WebSocket !== 'undefined') {
+  ws = WebSocket;
+} else if (typeof MozWebSocket !== 'undefined') {
+  ws = MozWebSocket;
+} else if (typeof global !== 'undefined') {
+  ws = global.WebSocket || global.MozWebSocket;
+} else if (typeof window !== 'undefined') {
+  ws = window.WebSocket || window.MozWebSocket;
+} else if (typeof self !== 'undefined') {
+  ws = self.WebSocket || self.MozWebSocket;
+}
+
+export default ws;
diff --git a/src/asr/ws/index.ts b/src/asr/ws/index.ts new
file mode 100644
index 0000000..51d6ee7
--- /dev/null
+++ b/src/asr/ws/index.ts
@@ -0,0 +1,40 @@
+/**
+ * True when the global `WebSocket` should be used instead of the `ws` package.
+ * - `typeof process === 'undefined'`: optional chaining (`process?.env`) does not
+ *   guard an undeclared identifier and throws a ReferenceError where no
+ *   `process` global exists, so test for it explicitly.
+ * - `process.env.IS_BROWSER`: the key that tsup.config.mjs substitutes at build time.
+ * - `process.env.BROWSER`: still honoured when running un-bundled.
+ */
+const isBrowser =
+  typeof process === 'undefined' || process.env.IS_BROWSER === 'true' || process.env.BROWSER === 'true';
+
+type WebSocketOptions = {
+  /**
+   * Whether to reject untrusted certificates; Node only.
+   */
+  rejectUnauthorized?: boolean;
+  headers?: Record;
+  [key: string]: any;
+};
+/**
+ * Create a WebSocket for `url`: the global `WebSocket` in a browser build,
+ * otherwise the dynamically imported `ws` package.
+ * `options` is only used on the `ws` path; `rejectUnauthorized` defaults to true.
+ */
+export const initWs = async (url: string, options?: WebSocketOptions) => {
+  let ws: WebSocket;
+  if (isBrowser) {
+    ws = new WebSocket(url);
+  } else {
+    const WebSocket = await import('ws').then((module) => module.default);
+    const { rejectUnauthorized, headers, ...rest } = options || {};
+    ws = new WebSocket(url, {
+      // `??`, not `||`: an explicit `false` must be able to switch verification off.
+      rejectUnauthorized: rejectUnauthorized ?? true,
+      headers: headers,
+      ...rest,
+    }) as any;
+  }
+  return ws;
+};
diff --git a/src/asr/ws/node.ts b/src/asr/ws/node.ts
new file mode 100644
index 0000000..63396cf
--- /dev/null
+++ b/src/asr/ws/node.ts
@@ -0,0 +1,3 @@
+import ws from 'ws';
+
+export default ws;
diff --git a/tsup.config.mjs b/tsup.config.mjs
index 3c8d3f1..ce416c5 100644
--- a/tsup.config.mjs
+++ b/tsup.config.mjs
@@ -9,6 +9,9 @@ export default defineConfig({
   outExtension: ({ format }) => ({
     js: format === 'esm' ? '.mjs' : '.js',
   }),
+  define: {
+    'process.env.IS_BROWSER': JSON.stringify(process.env.BROWSER || false),
+  },
   splitting: false,
   sourcemap: false,
   clean: true,