diff --git a/.gitignore b/.gitignore index 013ebc7..ba90038 100644 --- a/.gitignore +++ b/.gitignore @@ -192,5 +192,3 @@ cython_debug/ # refer to https://docs.cursor.com/context/ignore-files .cursorignore .cursorindexingignore -bria.mp3 -sample_fr_hibiki_crepes.mp3 diff --git a/README.md b/README.md index 5b79d1a..b546f3b 100644 --- a/README.md +++ b/README.md @@ -48,12 +48,6 @@ Here is how to choose which one to use: MLX is Apple's ML framework that allows you to use hardware acceleration on Apple silicon. If you want to run the model on a Mac or an iPhone, choose the MLX implementation. -You can retrieve the sample files used in the following snippets via: -```bash -wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3 -wget https://github.com/kyutai-labs/moshi/raw/refs/heads/main/data/sample_fr_hibiki_crepes.mp3 -``` - ### PyTorch implementation Hugging Face @@ -70,12 +64,12 @@ This requires the [moshi package](https://pypi.org/project/moshi/) with version 0.2.6 or later, which can be installed via pip. ```bash -python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3 +python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3 ``` If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly: ```bash -uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria.mp3 +uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en audio/bria.mp3 ``` Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs: @@ -84,7 +78,7 @@ Additionally, we provide two scripts that highlight different usage scenarios. T uv run \ scripts/transcribe_from_file_via_pytorch.py \ --hf-repo kyutai/stt-2.6b-en \ - --file bria.mp3 + --file audio/bria.mp3 ``` The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics: @@ -130,7 +124,7 @@ uv run scripts/transcribe_from_mic_via_rust_server.py We also provide a script for transcribing from an audio file. ```bash -uv run scripts/transcribe_from_file_via_rust_server.py bria.mp3 +uv run scripts/transcribe_from_file_via_rust_server.py audio/bria.mp3 ``` The script limits the decoding speed to simulates real-time processing of the audio. @@ -147,7 +141,7 @@ A standalone Rust example script is provided in the `stt-rs` directory in this r This can be used as follows: ```bash cd stt-rs -cargo run --features cuda -r -- bria.mp3 +cargo run --features cuda -r -- audio/bria.mp3 ``` You can get the timestamps by adding the `--timestamps` flag, and see the output of the semantic VAD by adding the `--vad` flag. @@ -164,12 +158,12 @@ This requires the [moshi-mlx package](https://pypi.org/project/moshi-mlx/) with version 0.2.6 or later, which can be installed via pip. ```bash -python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0 +python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0 ``` If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step and run directly: ```bash -uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx bria.mp3 --temp 0 +uvx --with moshi-mlx python -m moshi_mlx.run_inference --hf-repo kyutai/stt-2.6b-en-mlx audio/bria.mp3 --temp 0 ``` It will install the moshi package in a temporary environment and run the speech-to-text. diff --git a/audio/bria.mp3 b/audio/bria.mp3 new file mode 100644 index 0000000..17f5f80 Binary files /dev/null and b/audio/bria.mp3 differ diff --git a/audio/sample_fr_hibiki_crepes.mp3 b/audio/sample_fr_hibiki_crepes.mp3 new file mode 100644 index 0000000..064483e Binary files /dev/null and b/audio/sample_fr_hibiki_crepes.mp3 differ diff --git a/configs/config-stt-en-hf.toml b/configs/config-stt-en-hf.toml index 5eaf7c6..75382f8 100644 --- a/configs/config-stt-en-hf.toml +++ b/configs/config-stt-en-hf.toml @@ -1,7 +1,7 @@ static_dir = "./static/" log_dir = "$HOME/tmp/tts-logs" instance_name = "tts" -authorized_ids = ["open_token"] +authorized_ids = ["public_token"] [modules.asr] path = "/api/asr-streaming" diff --git a/configs/config-stt-en_fr-hf.toml b/configs/config-stt-en_fr-hf.toml index e2fdf96..15a880f 100644 --- a/configs/config-stt-en_fr-hf.toml +++ b/configs/config-stt-en_fr-hf.toml @@ -1,7 +1,7 @@ static_dir = "./static/" log_dir = "$HOME/tmp/tts-logs" instance_name = "tts" -authorized_ids = ["open_token"] +authorized_ids = ["public_token"] [modules.asr] path = "/api/asr-streaming" diff --git a/scripts/transcribe_from_file_via_rust_server.py b/scripts/transcribe_from_file_via_rust_server.py index 3e90b79..9333ca9 100644 --- a/scripts/transcribe_from_file_via_rust_server.py +++ b/scripts/transcribe_from_file_via_rust_server.py @@ -9,129 +9,127 @@ # /// import argparse import asyncio -import json import time import msgpack +import numpy as np import sphn import websockets -# Desired audio properties -TARGET_SAMPLE_RATE = 24000 -TARGET_CHANNELS = 1 # Mono -HEADERS = {"kyutai-api-key": "open_token"} -all_text = [] -transcript = [] -finished = False +SAMPLE_RATE = 24000 +FRAME_SIZE = 1920 # Send data in chunks def load_and_process_audio(file_path): """Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data.""" - pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE) + pcm_data, _ = sphn.read(file_path, sample_rate=SAMPLE_RATE) return pcm_data[0] async def receive_messages(websocket): - global all_text - global transcript - global finished - try: - async for message in websocket: - data = msgpack.unpackb(message, raw=False) - if data["type"] == "Step": - continue - print("received:", data) - if data["type"] == "Word": - all_text.append(data["text"]) - transcript.append( - { - "speaker": "SPEAKER_00", - "text": data["text"], - "timestamp": [data["start_time"], data["start_time"]], - } - ) - if data["type"] == "EndWord": - if len(transcript) > 0: - transcript[-1]["timestamp"][1] = data["stop_time"] - if data["type"] == "Marker": - print("Received marker, stopping stream.") - break - except websockets.ConnectionClosed: - print("Connection closed while receiving messages.") - finished = True + transcript = [] + + async for message in websocket: + data = msgpack.unpackb(message, raw=False) + if data["type"] == "Step": + # This message contains the signal from the semantic VAD, and tells us how + # much audio the server has already processed. We don't use either here. + continue + if data["type"] == "Word": + print(data["text"], end=" ", flush=True) + transcript.append( + { + "text": data["text"], + "timestamp": [data["start_time"], data["start_time"]], + } + ) + if data["type"] == "EndWord": + if len(transcript) > 0: + transcript[-1]["timestamp"][1] = data["stop_time"] + if data["type"] == "Marker": + # Received marker, stopping stream + break + + return transcript async def send_messages(websocket, rtf: float): - global finished audio_data = load_and_process_audio(args.in_file) - try: - # Start with a second of silence. - # This is needed for the 2.6B model for technical reasons. - chunk = {"type": "Audio", "pcm": [0.0] * 24000} - msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True) - await websocket.send(msg) - chunk_size = 1920 # Send data in chunks - start_time = time.time() - for i in range(0, len(audio_data), chunk_size): - chunk = { - "type": "Audio", - "pcm": [float(x) for x in audio_data[i : i + chunk_size]], - } - msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True) - await websocket.send(msg) - expected_send_time = start_time + (i + 1) / 24000 / rtf - current_time = time.time() - if current_time < expected_send_time: - await asyncio.sleep(expected_send_time - current_time) - else: - await asyncio.sleep(0.001) - chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5} - msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True) - await websocket.send(msg) - msg = msgpack.packb( - {"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True + async def send_audio(audio: np.ndarray): + await websocket.send( + msgpack.packb( + {"type": "Audio", "pcm": [float(x) for x in audio]}, + use_single_float=True, + ) ) - await websocket.send(msg) - for _ in range(35): - chunk = {"type": "Audio", "pcm": [0.0] * 1920} - msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True) - await websocket.send(msg) - while True: - if finished: - break - await asyncio.sleep(1.0) - # Keep the connection alive as there is a 20s timeout on the rust side. - await websocket.ping() - except websockets.ConnectionClosed: - print("Connection closed while sending messages.") + + # Start with a second of silence. + # This is needed for the 2.6B model for technical reasons. + await send_audio([0.0] * SAMPLE_RATE) + + start_time = time.time() + for i in range(0, len(audio_data), FRAME_SIZE): + await send_audio(audio_data[i : i + FRAME_SIZE]) + + expected_send_time = start_time + (i + 1) / SAMPLE_RATE / rtf + current_time = time.time() + if current_time < expected_send_time: + await asyncio.sleep(expected_send_time - current_time) + else: + await asyncio.sleep(0.001) + + for _ in range(5): + await send_audio([0.0] * SAMPLE_RATE) + + # Send a marker to indicate the end of the stream. + await websocket.send( + msgpack.packb({"type": "Marker", "id": 0}, use_single_float=True) + ) + + # We'll get back the marker once the corresponding audio has been transcribed, + # accounting for the delay of the model. That's why we need to send some silence + # after the marker, because the model will not return the marker immediately. + for _ in range(35): + await send_audio([0.0] * SAMPLE_RATE) -async def stream_audio(url: str, rtf: float): +async def stream_audio(url: str, api_key: str, rtf: float): """Stream audio data to a WebSocket server.""" + headers = {"kyutai-api-key": api_key} - async with websockets.connect(url, additional_headers=HEADERS) as websocket: + # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL + async with websockets.connect(url, additional_headers=headers) as websocket: send_task = asyncio.create_task(send_messages(websocket, rtf)) receive_task = asyncio.create_task(receive_messages(websocket)) - await asyncio.gather(send_task, receive_task) - print("exiting") + _, transcript = await asyncio.gather(send_task, receive_task) + + return transcript if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("in_file") - parser.add_argument("--transcript") parser.add_argument( "--url", help="The url of the server to which to send the audio", default="ws://127.0.0.1:8080", ) - parser.add_argument("--rtf", type=float, default=1.01) + parser.add_argument("--api-key", default="public_token") + parser.add_argument( + "--rtf", + type=float, + default=1.01, + help="The real-time factor of how fast to feed in the audio.", + ) args = parser.parse_args() url = f"{args.url}/api/asr-streaming" - asyncio.run(stream_audio(url, args.rtf)) - print(" ".join(all_text)) - if args.transcript is not None: - with open(args.transcript, "w") as fobj: - json.dump({"transcript": transcript}, fobj, indent=4) + transcript = asyncio.run(stream_audio(url, args.api_key, args.rtf)) + + print() + print() + for word in transcript: + print( + f"{word['timestamp'][0]:7.2f} -{word['timestamp'][1]:7.2f} {word['text']}" + ) diff --git a/scripts/transcribe_from_mic_via_rust_server.py b/scripts/transcribe_from_mic_via_rust_server.py index 85bad93..3e5bf04 100644 --- a/scripts/transcribe_from_mic_via_rust_server.py +++ b/scripts/transcribe_from_mic_via_rust_server.py @@ -16,43 +16,59 @@ import numpy as np import sounddevice as sd import websockets -# Desired audio properties -TARGET_SAMPLE_RATE = 24000 -TARGET_CHANNELS = 1 # Mono -audio_queue = asyncio.Queue() +SAMPLE_RATE = 24000 + +# The VAD has several prediction heads, each of which tries to determine whether there +# has been a pause of a given length. The lengths are 0.5, 1.0, 2.0, and 3.0 seconds. +# Lower indices predict pauses more aggressively. In Unmute, we use 2.0 seconds = index 2. +PAUSE_PREDICTION_HEAD_INDEX = 2 -async def receive_messages(websocket): +async def receive_messages(websocket, show_vad: bool = False): """Receive and process messages from the WebSocket server.""" try: + speech_started = False async for message in websocket: data = msgpack.unpackb(message, raw=False) - if data["type"] == "Word": + + # The Step message only gets sent if the model has semantic VAD available + if data["type"] == "Step" and show_vad: + pause_prediction = data["prs"][PAUSE_PREDICTION_HEAD_INDEX] + if pause_prediction > 0.5 and speech_started: + print("| ", end="", flush=True) + speech_started = False + + elif data["type"] == "Word": print(data["text"], end=" ", flush=True) + speech_started = True except websockets.ConnectionClosed: print("Connection closed while receiving messages.") -async def send_messages(websocket): +async def send_messages(websocket, audio_queue): """Send audio data from microphone to WebSocket server.""" try: # Start by draining the queue to avoid lags while not audio_queue.empty(): await audio_queue.get() + print("Starting the transcription") + while True: audio_data = await audio_queue.get() chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data]} msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True) await websocket.send(msg) + except websockets.ConnectionClosed: print("Connection closed while sending messages.") -async def stream_audio(url: str, api_key: str): +async def stream_audio(url: str, api_key: str, show_vad: bool): """Stream audio data to a WebSocket server.""" print("Starting microphone recording...") print("Press Ctrl+C to stop recording") + audio_queue = asyncio.Queue() loop = asyncio.get_event_loop() @@ -63,16 +79,19 @@ async def stream_audio(url: str, api_key: str): # Start audio stream with sd.InputStream( - samplerate=TARGET_SAMPLE_RATE, - channels=TARGET_CHANNELS, + samplerate=SAMPLE_RATE, + channels=1, dtype="float32", callback=audio_callback, blocksize=1920, # 80ms blocks ): headers = {"kyutai-api-key": api_key} + # Instead of using the header, you can authenticate by adding `?auth_id={api_key}` to the URL async with websockets.connect(url, additional_headers=headers) as websocket: - send_task = asyncio.create_task(send_messages(websocket)) - receive_task = asyncio.create_task(receive_messages(websocket)) + send_task = asyncio.create_task(send_messages(websocket, audio_queue)) + receive_task = asyncio.create_task( + receive_messages(websocket, show_vad=show_vad) + ) await asyncio.gather(send_task, receive_task) @@ -83,18 +102,23 @@ if __name__ == "__main__": help="The URL of the server to which to send the audio", default="ws://127.0.0.1:8080", ) - parser.add_argument("--api-key", default="open_token") + parser.add_argument("--api-key", default="public_token") parser.add_argument( "--list-devices", action="store_true", help="List available audio devices" ) parser.add_argument( "--device", type=int, help="Input device ID (use --list-devices to see options)" ) + parser.add_argument( + "--show-vad", + action="store_true", + help="Visualize the predictions of the semantic voice activity detector with a '|' symbol", + ) args = parser.parse_args() def handle_sigint(signum, frame): - print("Interrupted by user") + print("Interrupted by user") # Don't complain about KeyboardInterrupt exit(0) signal.signal(signal.SIGINT, handle_sigint) @@ -108,4 +132,4 @@ if __name__ == "__main__": sd.default.device[0] = args.device # Set input device url = f"{args.url}/api/asr-streaming" - asyncio.run(stream_audio(url, args.api_key)) + asyncio.run(stream_audio(url, args.api_key, args.show_vad))