Allow visualizing VAD

Václav Volhejn 2025-06-25 19:08:35 +02:00
parent b036dc2f68
commit 3e4384492e


@@ -16,18 +16,31 @@ import numpy as np
 import sounddevice as sd
 import websockets
 
-# Desired audio properties
-TARGET_SAMPLE_RATE = 24000
-TARGET_CHANNELS = 1  # Mono
+SAMPLE_RATE = 24000
+
+# The VAD has several prediction heads, each of which tries to determine whether there
+# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds.
+# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2.
+PAUSE_PREDICTION_HEAD_INDEX = 2
 
 
-async def receive_messages(websocket):
+async def receive_messages(websocket, show_vad: bool = False):
     """Receive and process messages from the WebSocket server."""
     try:
+        speech_started = False
         async for message in websocket:
             data = msgpack.unpackb(message, raw=False)
-            if data["type"] == "Word":
+            # The Step message only gets sent if the model has semantic VAD available
+            if data["type"] == "Step" and show_vad:
+                pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX]
+                if pause_prediction > 0.5 and speech_started:
+                    print("| ", end="", flush=True)
+                    speech_started = False
+            elif data["type"] == "Word":
                 print(data["text"], end=" ", flush=True)
+                speech_started = True
     except websockets.ConnectionClosed:
         print("Connection closed while receiving messages.")
@@ -38,17 +51,20 @@ async def send_messages(websocket, audio_queue):
         # Start by draining the queue to avoid lags
         while not audio_queue.empty():
             await audio_queue.get()
 
         print("Starting the transcription")
         while True:
             audio_data = await audio_queue.get()
             chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]}
             msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
             await websocket.send(msg)
     except websockets.ConnectionClosed:
         print("Connection closed while sending messages.")
 
 
-async def stream_audio(url: str, api_key: str):
+async def stream_audio(url: str, api_key: str, show_vad: bool):
     """Stream audio data to a WebSocket server."""
     print("Starting microphone recording...")
     print("Press Ctrl+C to stop recording")
@@ -63,8 +79,8 @@ async def stream_audio(url: str, api_key: str):
     # Start audio stream
     with sd.InputStream(
-        samplerate=TARGET_SAMPLE_RATE,
-        channels=TARGET_CHANNELS,
+        samplerate=SAMPLE_RATE,
+        channels=1,
         dtype="float32",
         callback=audio_callback,
         blocksize=1920,  # 80ms blocks
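The "80ms blocks" comment is consistent with the constants: 1920 samples / 24000 samples per second = 0.08 s, so each callback delivers 80 ms of mono float32 audio.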
@@ -73,7 +89,9 @@ async def stream_audio(url: str, api_key: str):
         # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL
         async with websockets.connect(url, additional_headers=headers) as websocket:
             send_task = asyncio.create_task(send_messages(websocket, audio_queue))
-            receive_task = asyncio.create_task(receive_messages(websocket))
+            receive_task = asyncio.create_task(
+                receive_messages(websocket, show_vad=show_vad)
+            )
             await asyncio.gather(send_task, receive_task)
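As the comment notes, the API key can also travel in the URL rather than in a header. A sketch of that variant (connect_with_query_auth is a hypothetical helper; only the `?auth_id={api_key}` parameter comes from the comment above):

import websockets

async def connect_with_query_auth(url: str, api_key: str):
    # Append the key as a query parameter instead of passing extra headers.
    async with websockets.connect(f"{url}?auth_id={api_key}") as websocket:
        ...  # stream audio as in stream_audio above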
@@ -91,11 +109,16 @@ if __name__ == "__main__":
     parser.add_argument(
         "--device", type=int, help="Input device ID (use --list-devices to see options)"
     )
+    parser.add_argument(
+        "--show-vad",
+        action="store_true",
+        help="Visualize the predictions of the semantic voice activity detector with a '|' symbol",
+    )
     args = parser.parse_args()
 
     def handle_sigint(signum, frame):
-        print("Interrupted by user")
+        print("Interrupted by user")  # Don't complain about KeyboardInterrupt
         exit(0)
 
     signal.signal(signal.SIGINT, handle_sigint)
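Usage of the new flag, assuming the script is saved as transcribe_from_mic.py (name hypothetical) and that the other flags match the args attributes used below:

python transcribe_from_mic.py --url <server-url> --api-key <key> --show-vad

Words then print as before, and a "| " marker appears whenever the selected VAD head crosses 0.5 while speech is in progress.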
@@ -109,4 +132,4 @@ if __name__ == "__main__":
         sd.default.device[0] = args.device  # Set input device
 
     url = f"{args.url}/api/asr-streaming"
-    asyncio.run(stream_audio(url, args.api_key))
+    asyncio.run(stream_audio(url, args.api_key, args.show_vad))