Add the rust indications.

This commit is contained in:
laurent 2025-06-17 08:37:44 +02:00
parent e53c94d6b6
commit 581e1d338b
3 changed files with 217 additions and 4 deletions

View File

@ -3,18 +3,58 @@ Delayed Streams Modeling (DSM) is a flexible formulation for streaming, multimod
## Speech To Text
### PyTorch implementation
### Leaderboard model
The leaderboard model handles english only, it has ~2.6B parameters.
#### PyTorch implementation
[[Hugging Face]](https://huggingface.co/kyutai/stt)
```bash
python -m moshi.run_inference --hf-repo kyutai/stt input.mp3
# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
python -m moshi.run_inference --hf-repo kyutai/stt bria.mp3
```
### MLX implementation
#### MLX implementation
[[Hugging Face]](https://huggingface.co/kyutai/stt-mlx)
```bash
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-mlx ~/tmp/bria-24khz.mp3 --temp 0
# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
python -m moshi_mlx.run_inference --hf-repo kyutai/stt-mlx bria.mp3 --temp 0
```
#### Rust implementation
[[Hugging Face]](https://huggingface.co/kyutai/stt-candle)
The Rust implementation provides a server that can process multiple streaming
queries in parallel. Dependening on the amount of memory on your GPU, you may
have to adjust the batch size from the config file. For a L40S GPU, a batch size
of 64 works well.
In order to run the server, install the `moshi-server` crate via the following
command. The server code can be found in the
[kyutai-labs/moshi](https://github.com/kyutai-labs/moshi/tree/main/rust/moshi-server)
repository.
```bash
cargo install --features cuda moshi-server
```
Then the server can be started via the following command using the config file
from this repository.
```bash
moshi-server worker --config configs/config-stt-hf.toml
```
Once the server has started you can run a streaming inference with the following
script.
```bash
# wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
uv run scripts/asr-streaming-query.py bria.mp3
```
The script simulates some real-time processing of the audio. Faster processing
can be triggered by setting the real-time factor, e.g. `--rtf 500` will process
the data as fast as possible.
## License
The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.

View File

@ -0,0 +1,42 @@
static_dir = "./static/"
log_dir = "$HOME/tmp/tts-logs"
instance_name = "tts"
authorized_ids = ["open_token"]
[modules.asr]
path = "/api/asr-streaming"
type = "BatchedAsr"
lm_model_file = "hf://kyutai/stt-candle/model.safetensors"
text_tokenizer_file = "hf://kyutai/stt-candle/tokenizer_en_audio_4000.model"
audio_tokenizer_file = "hf://kyutai/stt-candle/mimi-pytorch-e351c8d8@125.safetensors"
asr_delay_in_tokens = 6
batch_size = 16
conditioning_learnt_padding = true
temperature = 0
[modules.asr.model]
audio_vocab_size = 2049
text_in_vocab_size = 4001
text_out_vocab_size = 4000
audio_codebooks = 32
[modules.asr.model.transformer]
d_model = 2048
num_heads = 32
num_layers = 48
dim_feedforward = 8192
causal = true
norm_first = true
bias_ff = false
bias_attn = false
context = 375
max_period = 100000
use_conv_block = false
use_conv_bias = true
gating = "silu"
norm = "RmsNorm"
positional_embedding = "Rope"
conv_layout = false
conv_kernel_size = 3
kv_repeat = 1
max_seq_len = 40960

View File

@ -0,0 +1,131 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "msgpack",
# "numpy",
# "sphn",
# "websockets",
# ]
# ///
import argparse
import asyncio
import json
import msgpack
import sphn
import struct
import time
import numpy as np
import websockets
# Desired audio properties
TARGET_SAMPLE_RATE = 24000
TARGET_CHANNELS = 1 # Mono
HEADERS = {"kyutai-api-key": "open_token"}
all_text = []
transcript = []
finished = False
def load_and_process_audio(file_path):
"""Load an MP3 file, resample to 24kHz, convert to mono, and extract PCM float32 data."""
pcm_data, _ = sphn.read(file_path, sample_rate=TARGET_SAMPLE_RATE)
return pcm_data[0]
async def receive_messages(websocket):
global all_text
global transcript
global finished
try:
async for message in websocket:
data = msgpack.unpackb(message, raw=False)
if data["type"] == "Step":
continue
print("received:", data)
if data["type"] == "Word":
all_text.append(data["text"])
transcript.append({
"speaker": "SPEAKER_00",
"text": data["text"],
"timestamp": [data["start_time"], data["start_time"]],
})
if data["type"] == "EndWord":
if len(transcript) > 0:
transcript[-1]["timestamp"][1] = data["stop_time"]
if data["type"] == "Marker":
print("Received marker, stopping stream.")
break
except websockets.ConnectionClosed:
print("Connection closed while receiving messages.")
finished = True
async def send_messages(websocket, rtf: float):
global finished
audio_data = load_and_process_audio(args.in_file)
try:
# Start with a second of silence
chunk = { "type": "Audio", "pcm": [0.0] * 24000 }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
chunk_size = 1920 # Send data in chunks
start_time = time.time()
for i in range(0, len(audio_data), chunk_size):
chunk = { "type": "Audio", "pcm": [float(x) for x in audio_data[i : i + chunk_size]] }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
expected_send_time = start_time + (i + 1) / 24000 / rtf
current_time = time.time()
if current_time < expected_send_time:
await asyncio.sleep(expected_send_time - current_time)
else:
await asyncio.sleep(0.001)
chunk = { "type": "Audio", "pcm": [0.0] * 1920 * 5 }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
msg = msgpack.packb({"type": "Marker", "id": 0}, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
for _ in range(35):
chunk = { "type": "Audio", "pcm": [0.0] * 1920 }
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
await websocket.send(msg)
while True:
if finished:
break
await asyncio.sleep(1.0)
# Keep the connection alive as there is a 20s timeout on the rust side.
await websocket.ping()
except websockets.ConnectionClosed:
print("Connection closed while sending messages.")
async def stream_audio(url: str, rtf: float):
"""Stream audio data to a WebSocket server."""
async with websockets.connect(url, additional_headers=HEADERS) as websocket:
send_task = asyncio.create_task(send_messages(websocket, rtf))
receive_task = asyncio.create_task(receive_messages(websocket))
await asyncio.gather(send_task, receive_task)
print("exiting")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file")
parser.add_argument("--transcript")
parser.add_argument(
"--url",
help="The url of the server to which to send the audio",
default="ws://127.0.0.1:8080",
)
parser.add_argument("--rtf", type=float, default=1.01)
args = parser.parse_args()
url = f"{args.url}/api/asr-streaming"
asyncio.run(stream_audio(url, args.rtf))
print(" ".join(all_text))
if args.transcript is not None:
with open(args.transcript, "w") as fobj:
json.dump({"transcript": transcript}, fobj, indent=4)