Václav Volhejn 2025-06-25 17:50:53 +02:00
parent 8c18018b6f
commit 15c8c305ad


@@ -9,129 +9,120 @@
 # ///
 import argparse
 import asyncio
-import json
 import time
 
 import msgpack
+import numpy as np
 import sphn
 import websockets
 
-# Desired audio properties
-TARGET_SAMPLE_RATE = 24000
+SAMPLE_RATE = 24000
 TARGET_CHANNELS = 1  # Mono
+FRAME_SIZE = 1920  # Send data in chunks
 HEADERS = {"kyutai-api-key": "open_token"}
 
-all_text = []
-transcript = []
-finished = False
-
 
 def load_and_process_audio(file_path):
     """Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
-    pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
+    pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE)
     return pcm_data[0]
 
 
 async def receive_messages(websocket):
-    global all_text
-    global transcript
-    global finished
-    try:
-        async for message in websocket:
-            data = msgpack.unpackb(message, raw=False)
-            if data["type"] == "Step":
-                continue
-            print("received:", data)
-            if data["type"] == "Word":
-                all_text.append(data["text"])
-                transcript.append(
-                    {
-                        "speaker": "SPEAKER_00",
-                        "text": data["text"],
-                        "timestamp": [data["start_time"], data["start_time"]],
-                    }
-                )
-            if data["type"] == "EndWord":
-                if len(transcript) > 0:
-                    transcript[-1]["timestamp"][1] = data["stop_time"]
-            if data["type"] == "Marker":
-                print("Received marker, stopping stream.")
-                break
-    except websockets.ConnectionClosed:
-        print("Connection closed while receiving messages.")
-    finished = True
+    transcript = []
+
+    async for message in websocket:
+        data = msgpack.unpackb(message, raw=False)
+        if data["type"] == "Step":
+            # This message contains the signal from the semantic VAD, and tells us how
+            # much audio the server has already processed. We don't use either here.
+            continue
+        if data["type"] == "Word":
+            print(data["text"], end=" ", flush=True)
+            transcript.append(
+                {
+                    "text": data["text"],
+                    "timestamp": [data["start_time"], data["start_time"]],
+                }
+            )
+        if data["type"] == "EndWord":
+            if len(transcript) > 0:
+                transcript[-1]["timestamp"][1] = data["stop_time"]
+        if data["type"] == "Marker":
+            # Received marker, stopping stream
+            break
+
+    return transcript
 
 
 async def send_messages(websocket, rtf: float):
-    global finished
     audio_data = load_and_process_audio(args.in_file)
-    try:
-        # Start with a second of silence.
-        # This is needed for the 2.6B model for technical reasons.
-        chunk = {"type": "Audio", "pcm": [0.0] * 24000}
-        msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-        await websocket.send(msg)
-        chunk_size = 1920  # Send data in chunks
-        start_time = time.time()
-        for i in range(0, len(audio_data), chunk_size):
-            chunk = {
-                "type": "Audio",
-                "pcm": [float(x) for x in audio_data[i : i + chunk_size]],
-            }
-            msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-            await websocket.send(msg)
-            expected_send_time = start_time + (i + 1) / 24000 / rtf
-            current_time = time.time()
-            if current_time < expected_send_time:
-                await asyncio.sleep(expected_send_time - current_time)
-            else:
-                await asyncio.sleep(0.001)
-        chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
-        msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-        await websocket.send(msg)
-        msg = msgpack.packb(
-            {"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True
-        )
-        await websocket.send(msg)
-        for _ in range(35):
-            chunk = {"type": "Audio", "pcm": [0.0] * 1920}
-            msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
-            await websocket.send(msg)
-        while True:
-            if finished:
-                break
-            await asyncio.sleep(1.0)
-            # Keep the connection alive as there is a 20s timeout on the rust side.
-            await websocket.ping()
-    except websockets.ConnectionClosed:
-        print("Connection closed while sending messages.")
+
+    async def send_audio(audio: np.ndarray):
+        await websocket.send(
+            msgpack.packb(
+                {"type": "Audio", "pcm": [float(x) for x in audio]},
+                use_single_float=True,
+            )
+        )
+
+    # Start with a second of silence.
+    # This is needed for the 2.6B model for technical reasons.
+    await send_audio([0.0] * SAMPLE_RATE)
+
+    start_time = time.time()
+    for i in range(0, len(audio_data), FRAME_SIZE):
+        await send_audio(audio_data[i : i + FRAME_SIZE])
+
+        expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf
+        current_time = time.time()
+        if current_time < expected_send_time:
+            await asyncio.sleep(expected_send_time - current_time)
+        else:
+            await asyncio.sleep(0.001)
+
+    for _ in range(5):
+        await send_audio([0.0] * SAMPLE_RATE)
+
+    # Send a marker to indicate the end of the stream.
+    await websocket.send(
+        msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True)
+    )
+
+    # We'll get back the marker once the corresponding audio has been transcribed,
+    # accounting for the delay of the model. That's why we need to send some silence
+    # after the marker, because the model will not return the marker immediately.
+    for _ in range(35):
+        await send_audio([0.0] * SAMPLE_RATE)
 
 
-async def stream_audio(url: str, rtf: float):
+async def stream_audio(url: str, api_key: str, rtf: float):
     """Stream audio data to a WebSocket server."""
-    async with websockets.connect(url, additional_headers=HEADERS) as websocket:
+    headers = {"kyutai-api-key": api_key}
+
+    async with websockets.connect(url, additional_headers=headers) as websocket:
         send_task = asyncio.create_task(send_messages(websocket, rtf))
         receive_task = asyncio.create_task(receive_messages(websocket))
-        await asyncio.gather(send_task, receive_task)
-    print("exiting")
+        _, transcript = await asyncio.gather(send_task, receive_task)
+
+    return transcript
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("in_file")
-    parser.add_argument("--transcript")
     parser.add_argument(
         "--url",
         help="The url of the server to which to send the audio",
         default="ws://127.0.0.1:8080",
     )
+    parser.add_argument("--api-key", default="public_token")
     parser.add_argument("--rtf", type=float, default=1.01)
     args = parser.parse_args()
 
     url = f"{args.url}/api/asr-streaming"
-    asyncio.run(stream_audio(url, args.rtf))
-    print(" ".join(all_text))
-    if args.transcript is not None:
-        with open(args.transcript, "w") as fobj:
-            json.dump({"transcript": transcript}, fobj, indent=4)
+    transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf))
+
+    print()
+    print()
+    print(transcript)
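For reference when reading the diff: every WebSocket message this script exchanges, in both directions, is a single msgpack-encoded dict tagged by a "type" field. A minimal round-trip sketch of that framing, using the same msgpack.packb / msgpack.unpackb calls the script uses (one 1920-sample frame is 80 ms of audio at 24 kHz):

import msgpack

# Client -> server messages used above: audio frames and an end-of-stream marker.
audio_msg = {"type": "Audio", "pcm": [0.0] * 1920}
marker_msg = {"type": "Marker", "id": 0}

# use_single_float=True encodes the PCM floats as 32-bit instead of 64-bit,
# roughly halving the size of each audio message.
packed = msgpack.packb(audio_msg, use_single_float=True)

# Server -> client messages ("Step", "Word", "EndWord", "Marker") are decoded
# the same way the script decodes them.
decoded = msgpack.unpackb(packed, raw=False)
assert decoded["type"] == "Audio"
assert len(decoded["pcm"]) == 1920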
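The pacing rule in the rewritten send_messages is the subtle part: after sending the frame that starts at sample i, the client sleeps until start_time + (i + 1) / SAMPLE_RATE / rtf, so rtf acts as a speed multiplier relative to real time. A self-contained sketch of just that arithmetic, with the script's constants and no network involved:

SAMPLE_RATE = 24000
FRAME_SIZE = 1920


def send_offset(i: int, rtf: float) -> float:
    """Seconds after start_time at which the frame starting at sample i is due."""
    return (i + 1) / SAMPLE_RATE / rtf


# At rtf=1.0, consecutive frames are due exactly 80 ms apart: real time.
assert abs(send_offset(FRAME_SIZE, 1.0) - send_offset(0, 1.0) - 0.08) < 1e-9

# The default rtf=1.01 streams about 1% faster than real time, so 60 s of
# audio has all been sent after roughly 59.4 s.
print(send_offset(60 * SAMPLE_RATE - 1, 1.01))  # ~59.41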
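The diff also removes the --transcript flag, so the word/timestamp list is now only printed rather than optionally written to disk. If you still want the JSON file the old version produced, a small wrapper around the new return value of stream_audio is enough; save_transcript below is a hypothetical helper for illustration, not part of the script:

import json


def save_transcript(transcript: list, path: str) -> None:
    # Same {"transcript": [...]} shape the old --transcript option wrote.
    with open(path, "w") as fobj:
        json.dump({"transcript": transcript}, fobj, indent=4)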