From 581e1d338b7385d1e72fdcc49b7df113ebfb1590 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 17 Jun 2025 08:37:44 +0200
Subject: [PATCH] Add the Rust indications.

---
 README.md                      |  48 +++++++++++-
 configs/config-stt-hf.toml     |  42 +++++++++++
 scripts/asr-streaming-query.py | 131 +++++++++++++++++++++++++++++++++
 3 files changed, 217 insertions(+), 4 deletions(-)
 create mode 100644 configs/config-stt-hf.toml
 create mode 100644 scripts/asr-streaming-query.py

diff --git a/README.md b/README.md
index 2edcacd..f836108 100644
--- a/README.md
+++ b/README.md
@@ -3,18 +3,58 @@ Delayed Streams Modeling (DSM) is a flexible formulation for streaming, multimod
 
 ## Speech To Text
 
-### PyTorch implementation
+### Leaderboard model
+The leaderboard model handles English only and has ~2.6B parameters.
+
+#### PyTorch implementation
+[[Hugging Face]](https://huggingface.co/kyutai/stt)
 ```bash
-python -m moshi.run_inference --hf-repo kyutai/stt input.mp3
+# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
+python -m moshi.run_inference --hf-repo kyutai/stt bria.mp3
 ```
 
-### MLX implementation
+#### MLX implementation
+[[Hugging Face]](https://huggingface.co/kyutai/stt-mlx)
 ```bash
-python -m moshi_mlx.run_inference --hf-repo kyutai/stt-mlx ~/tmp/bria-24khz.mp3 --temp 0
+# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
+python -m moshi_mlx.run_inference --hf-repo kyutai/stt-mlx bria.mp3 --temp 0
 ```
 
+#### Rust implementation
+[[Hugging Face]](https://huggingface.co/kyutai/stt-candle)
+
+The Rust implementation provides a server that can process multiple streaming
+queries in parallel. Depending on the amount of memory on your GPU, you may
+have to adjust the batch size in the config file. For an L40S GPU, a batch
+size of 64 works well.
+
+To run the server, install the `moshi-server` crate via the following
+command. The server code can be found in the
+[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
+repository.
+```bash
+cargo install --features cuda moshi-server
+```
+
+The server can then be started with the following command, using the config
+file from this repository.
+```bash
+moshi-server worker --config configs/config-stt-hf.toml
+```
+
+Once the server has started, you can run a streaming inference with the
+following script.
+```bash
+# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
+uv run scripts/asr-streaming-query.py bria.mp3
+```
+
+The script simulates real-time processing of the audio. Faster processing can
+be triggered by setting the real-time factor, e.g. `--rtf 500` will process
+the data as fast as possible.
+
 ## License
 
 The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.
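As a small illustration of the real-time factor mentioned above (not part of the original patch; the flag and its default of 1.01 come from `scripts/asr-streaming-query.py` added below, and the server is assumed to be running locally):

```bash
# Push the audio to the local server as fast as it is accepted instead of
# simulating real-time streaming (the script's default is --rtf 1.01).
uv run scripts/asr-streaming-query.py bria.mp3 --rtf 500
```
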
diff --git a/configs/config-stt-hf.toml b/configs/config-stt-hf.toml
new file mode 100644
index 0000000..99c539d
--- /dev/null
+++ b/configs/config-stt-hf.toml
@@ -0,0 +1,42 @@
+static_dir = "./static/"
+log_dir = "$HOME/tmp/tts-logs"
+instance_name = "tts"
+authorized_ids = ["open_token"]
+
+[modules.asr]
+path = "/api/asr-streaming"
+type = "BatchedAsr"
+lm_model_file = "hf://kyutai/stt-candle/model.safetensors"
+text_tokenizer_file = "hf://kyutai/stt-candle/tokenizer_en_audio_4000.model"
+audio_tokenizer_file = "hf://kyutai/stt-candle/mimi-pytorch-e351c8d8@125.safetensors"
+asr_delay_in_tokens = 6
+batch_size = 16
+conditioning_learnt_padding = true
+temperature = 0
+
+[modules.asr.model]
+audio_vocab_size = 2049
+text_in_vocab_size = 4001
+text_out_vocab_size = 4000
+audio_codebooks = 32
+
+[modules.asr.model.transformer]
+d_model = 2048
+num_heads = 32
+num_layers = 48
+dim_feedforward = 8192
+causal = true
+norm_first = true
+bias_ff = false
+bias_attn = false
+context = 375
+max_period = 100000
+use_conv_block = false
+use_conv_bias = true
+gating = "silu"
+norm = "RmsNorm"
+positional_embedding = "Rope"
+conv_layout = false
+conv_kernel_size = 3
+kv_repeat = 1
+max_seq_len = 40960
diff --git a/scripts/asr-streaming-query.py b/scripts/asr-streaming-query.py
new file mode 100644
index 0000000..780abd2
--- /dev/null
+++ b/scripts/asr-streaming-query.py
@@ -0,0 +1,131 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "msgpack",
+#     "numpy",
+#     "sphn",
+#     "websockets",
+# ]
+# ///
+import argparse
+import asyncio
+import json
+import msgpack
+import sphn
+import struct
+import time
+
+import numpy as np
+import websockets
+
+# Desired audio properties
+TARGET_SAMPLE_RATE = 24000
+TARGET_CHANNELS = 1  # Mono
+HEADERS = {"kyutai-api-key": "open_token"}
+all_text = []
+transcript = []
+finished = False
+
+
+def load_and_process_audio(file_path):
+    """Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
+    pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
+    return pcm_data[0]
+
+
+async def receive_messages(websocket):
+    global all_text
+    global transcript
+    global finished
+    try:
+        async for message in websocket:
+            data = msgpack.unpackb(message, raw=False)
+            if data["type"] == "Step":
+                continue
+            print("received:", data)
+            if data["type"] == "Word":
+                all_text.append(data["text"])
+                transcript.append({
+                    "speaker": "SPEAKER_00",
+                    "text": data["text"],
+                    "timestamp": [data["start_time"], data["start_time"]],
+                })
+            if data["type"] == "EndWord":
+                if len(transcript) > 0:
+                    transcript[-1]["timestamp"][1] = data["stop_time"]
+            if data["type"] == "Marker":
+                print("Received marker, stopping stream.")
+                break
+    except websockets.ConnectionClosed:
+        print("Connection closed while receiving messages.")
+    finished = True
+
+
+async def send_messages(websocket, rtf: float):
+    global finished
+    audio_data = load_and_process_audio(args.in_file)
+    try:
+        # Start with a second of silence
+        chunk = {"type": "Audio", "pcm": [0.0] * 24000}
+        msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
+        await websocket.send(msg)
+
+        chunk_size = 1920  # Send data in chunks
+        start_time = time.time()
+        for i in range(0, len(audio_data), chunk_size):
+            chunk = {"type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]]}
+            msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
+            await websocket.send(msg)
+            expected_send_time = start_time + (i + 1) / 24000 / rtf
+            current_time = time.time()
+            if current_time < expected_send_time:
+                await asyncio.sleep(expected_send_time - current_time)
+            else:
+                await asyncio.sleep(0.001)
+        chunk = {"type": "Audio", "pcm": [0.0] * 1920 * 5}
+        msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
+        await websocket.send(msg)
+        msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
+        await websocket.send(msg)
+        for _ in range(35):
+            chunk = {"type": "Audio", "pcm": [0.0] * 1920}
+            msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
+            await websocket.send(msg)
+        while True:
+            if finished:
+                break
+            await asyncio.sleep(1.0)
+            # Keep the connection alive as there is a 20s timeout on the rust side.
+            await websocket.ping()
+    except websockets.ConnectionClosed:
+        print("Connection closed while sending messages.")
+
+
+async def stream_audio(url: str, rtf: float):
+    """Stream audio data to a WebSocket server."""
+
+    async with websockets.connect(url, additional_headers=HEADERS) as websocket:
+        send_task = asyncio.create_task(send_messages(websocket, rtf))
+        receive_task = asyncio.create_task(receive_messages(websocket))
+        await asyncio.gather(send_task, receive_task)
+    print("exiting")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("in_file")
+    parser.add_argument("--transcript")
+    parser.add_argument(
+        "--url",
+        help="The url of the server to which to send the audio",
+        default="ws://127.0.0.1:8080",
+    )
+    parser.add_argument("--rtf", type=float, default=1.01)
+    args = parser.parse_args()
+
+    url = f"{args.url}/api/asr-streaming"
+    asyncio.run(stream_audio(url, args.rtf))
+    print(" ".join(all_text))
+    if args.transcript is not None:
+        with open(args.transcript, "w") as fobj:
+            json.dump({"transcript": transcript}, fobj, indent=4)
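
A usage sketch for the script above, based on its argument parser: `--url` selects a server other than the default `ws://127.0.0.1:8080` (the script appends the `/api/asr-streaming` path itself), and `--transcript` writes the collected words with their timestamps as JSON of the form `{"transcript": [{"speaker": ..., "text": ..., "timestamp": [start, stop]}, ...]}`. The host name and output file below are placeholders.

```bash
# Stream to a remote moshi-server instance ("gpu-host" is a placeholder) and
# save the word-level transcript with timestamps to a JSON file.
uv run scripts/asr-streaming-query.py bria.mp3 --url ws://gpu-host:8080 --transcript bria.json
```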