This commit is contained in:
Václav Volhejn 2025-06-25 17:50:53 +02:00
parent 8c18018b6f
commit 15c8c305ad

View File

@ -9,43 +9,38 @@
# /// # ///
import argparse import argparse
import asyncio import asyncio
import json
import time import time
import msgpack import msgpack
import numpy as np
import sphn import sphn
import websockets import websockets
# Audio/stream parameters expected by the server.
SAMPLE_RATE = 24000  # samples per second the server consumes
TARGET_CHANNELS = 1  # Mono
FRAME_SIZE = 1920  # Send data in chunks
HEADERS = {"kyutai-api-key": "open_token"}
def load_and_process_audio(file_path):
    """Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
    # sphn resamples on load; the first channel is the mono signal we stream.
    samples, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
    return samples[0]
async def receive_messages(websocket):
    """Consume server messages until our Marker is echoed; return the transcript.

    Words are printed incrementally as they arrive. Each transcript entry is
    a dict ``{"text": ..., "timestamp": [start, stop]}``; the stop time is
    patched in when the word-end message for it arrives.
    """
    transcript = []

    async for message in websocket:
        data = msgpack.unpackb(message, raw=False)
        kind = data["type"]
        # A message has exactly one type, so dispatch with elif instead of
        # re-testing every branch after a match.
        if kind == "Step":
            # This message contains the signal from the semantic VAD, and tells us how
            # much audio the server has already processed. We don't use either here.
            continue
        elif kind == "Word":
            # New word: print it live and open its time interval (stop == start
            # until the matching end-of-word message arrives).
            print(data["text"], end=" ", flush=True)
            transcript.append(
                {
                    "text": data["text"],
                    "timestamp": [data["start_time"], data["start_time"]],
                }
            )
        # NOTE(review): this branch's type string was hidden in the diff hunk
        # gap; "EndWord" inferred from the stop_time handling — confirm.
        elif kind == "EndWord":
            # Close the interval of the most recent word.
            if transcript:
                transcript[-1]["timestamp"][1] = data["stop_time"]
        elif kind == "Marker":
            # The server echoed our end-of-stream marker: transcription is done.
            break

    return transcript
async def send_messages(websocket, rtf: float):
    """Stream the input audio file to the server, then pad so the marker returns.

    Reads the file named by the module-level ``args.in_file``, sends it in
    FRAME_SIZE chunks paced at ``rtf`` times real time, then sends an
    end-of-stream Marker followed by trailing silence so the model has enough
    audio to flush its delayed output.
    """
    audio_data = load_and_process_audio(args.in_file)

    # No annotation on `audio`: numpy is no longer imported in this module,
    # so an `np.ndarray` annotation would raise NameError when this inner def
    # executes. Any sequence of floats works.
    async def send_audio(audio):
        await websocket.send(
            msgpack.packb(
                {"type": "Audio", "pcm": [float(x) for x in audio]},
                use_single_float=True,
            )
        )

    # Start with a second of silence.
    # This is needed for the 2.6B model for technical reasons.
    await send_audio([0.0] * SAMPLE_RATE)

    start_time = time.time()
    for i in range(0, len(audio_data), FRAME_SIZE):
        await send_audio(audio_data[i : i + FRAME_SIZE])

        # Throttle so audio is sent at `rtf` times real time instead of all
        # at once; `i + 1` samples have been sent after this chunk's start.
        expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
        current_time = time.time()
        if current_time < expected_send_time:
            await asyncio.sleep(expected_send_time - current_time)
        else:
            # Always yield to the event loop so the receiver can run.
            await asyncio.sleep(0.001)

    for _ in range(5):
        await send_audio([0.0] * SAMPLE_RATE)

    # Send a marker to indicate the end of the stream.
    await websocket.send(
        msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
    )

    # We'll get back the marker once the corresponding audio has been transcribed,
    # accounting for the delay of the model. That's why we need to send some silence
    # after the marker, because the model will not return the marker immediately.
    for _ in range(35):
        await send_audio([0.0] * SAMPLE_RATE)
async def stream_audio(url: str, api_key: str, rtf: float):
    """Stream audio data to a WebSocket server."""
    headers = {"kyutai-api-key": api_key}
    async with websockets.connect(url, additional_headers=headers) as websocket:
        sender = asyncio.create_task(send_messages(websocket, rtf))
        receiver = asyncio.create_task(receive_messages(websocket))
        results = await asyncio.gather(sender, receiver)

    # send_messages returns None; the transcript comes from the receiver task.
    return results[1]
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("in_file")
    arg_parser.add_argument(
        "--url",
        help="The url of the server to which to send the audio",
        default="ws://127.0.0.1:8080",
    )
    arg_parser.add_argument("--api-key", default="public_token")
    arg_parser.add_argument("--rtf", type=float, default=1.01)
    # `args` stays module-global: send_messages reads args.in_file.
    args = arg_parser.parse_args()

    url = f"{args.url}/api/asr-streaming"
    transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))

    # Words were already printed incrementally while streaming; add some
    # spacing, then dump the full timestamped transcript.
    print()
    print()
    print(transcript)