From 3e4384492e4fc4ea91d6e46a05c846e2151b93ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=A1clav=20Volhejn?= Date: Wed, 25 Jun 2025 19:08:35 +0200 Subject: [PATCH] Allow visualizing VAD --- .../transcribe_from_mic_via_rust_server.py | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/scripts/transcribe_from_mic_via_rust_server.py b/scripts/transcribe_from_mic_via_rust_server.py index 56b09fa..3e5bf04 100644 --- a/scripts/transcribe_from_mic_via_rust_server.py +++ b/scripts/transcribe_from_mic_via_rust_server.py @@ -16,18 +16,31 @@ import numpy as np import sounddevice as sd import websockets -# Desired audio properties -TARGET_SAMPLE_RATE = 24000 -TARGET_CHANNELS = 1 # Mono +SAMPLE_RATE = 24000 + +# The VAD has several prediction heads, each of which tries to determine whether there +# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds. +# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2. +PAUSE_PREDICTION_HEAD_INDEX = 2 -async def receive_messages(websocket): +async def receive_messages(websocket, show_vad: bool = False): """Receive and process messages from the WebSocket server.""" try: + speech_started = False async for message in websocket: data = msgpack.unpackb(message, raw=False) - if data["type"] == "Word": + + # The Step message only gets sent if the model has semantic VAD available + if data["type"] == "Step" and show_vad: + pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX] + if pause_prediction > 0.5 and speech_started: + print("| ", end="", flush=True) + speech_started = False + + elif data["type"] == "Word": print(data["text"], end=" ", flush=True) + speech_started = True except websockets.ConnectionClosed: print("Connection closed while receiving messages.") @@ -38,17 +51,20 @@ async def send_messages(websocket, audio_queue): # Start by draining the queue to avoid lags while not audio_queue.empty(): await audio_queue.get() + print("Starting the transcription") + while True: audio_data = await audio_queue.get() chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]} msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True) await websocket.send(msg) + except websockets.ConnectionClosed: print("Connection closed while sending messages.") -async def stream_audio(url: str, api_key: str): +async def stream_audio(url: str, api_key: str, show_vad: bool): """Stream audio data to a WebSocket server.""" print("Starting microphone recording...") print("Press Ctrl+C to stop recording") @@ -63,8 +79,8 @@ async def stream_audio(url: str, api_key: str): # Start audio stream with sd.InputStream( - samplerate=TARGET_SAMPLE_RATE, - channels=TARGET_CHANNELS, + samplerate=SAMPLE_RATE, + channels=1, dtype="float32", callback=audio_callback, blocksize=1920, # 80ms blocks @@ -73,7 +89,9 @@ async def stream_audio(url: str, api_key: str): # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL async with websockets.connect(url, additional_headers=headers) as websocket: send_task = asyncio.create_task(send_messages(websocket, audio_queue)) - receive_task = asyncio.create_task(receive_messages(websocket)) + receive_task = asyncio.create_task( + receive_messages(websocket, show_vad=show_vad) + ) await asyncio.gather(send_task, receive_task) @@ -91,11 +109,16 @@ if __name__ == "__main__": parser.add_argument( "--device", type=int, help="Input device ID (use --list-devices to see options)" ) + parser.add_argument( + "--show-vad", + action="store_true", + help="Visualize the predictions of the semantic voice activity detector with a '|' symbol", + ) args = parser.parse_args() def handle_sigint(signum, frame): - print("Interrupted by user") + print("Interrupted by user") # Don't complain about KeyboardInterrupt exit(0) signal.signal(signal.SIGINT, handle_sigint) @@ -109,4 +132,4 @@ if __name__ == "__main__": sd.default.device[0] = args.device # Set input device url = f"{args.url}/api/asr-streaming" - asyncio.run(stream_audio(url, args.api_key)) + asyncio.run(stream_audio(url, args.api_key, args.show_vad))